commit b4244090aeed5b4e28aa0f05caafd33d563b9ab0 Author: 张仪 Date: Wed Aug 21 22:15:12 2024 +0800 first commit diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 00000000..775478a1 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,29 @@ +Dockerfile +**/publish.py +my +.git +.refresh +__pycache__ +.ipynb_checkpoints/ +.vscode/ +__res/ +perf.data +perf.data.old +*.swp +*.ipynb +*.pdf +*.zip +*.tgz +test.py +extern/mkl/mkldnn_lnx*/* +data/ +build/ +venv/ +*.md +!*.src.md +!README.md +!README.cn.md +python/jittor.egg-info +dist/ +!doc/source/* +__data__ \ No newline at end of file diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 00000000..8386a510 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,57 @@ +# This is a basic workflow to help you get started with Actions + +name: CI + +# Controls when the action will run. Triggers the workflow on push or pull request +# events but only for the master branch +on: [ push ] +# push: +# branches: [ master ] +# pull_request: +# branches: [ master ] + +# A workflow run is made up of one or more jobs that can run sequentially or in parallel +jobs: + test_clang_8_cuda_10: + # The type of runner that the job will run on + runs-on: self-hosted + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + + - name: test + run: | + export cache_name=github_${GITHUB_REF##*/} + export cc_path="clang++-8" + export cc_flags=" -g " + export log_sync=0 + export log_v=0 + export PYTHONIOENCODING=utf8 + export PYTHONPATH=`pwd`/python + export nvcc_path=/usr/local/cuda/bin/nvcc + python3.7 -c "import jittor" + python3.7 -m jittor.test -v + + test_gcc: + # The type of runner that the job will run on + runs-on: self-hosted + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + + - name: test + run: | + export cache_name=github_${GITHUB_REF##*/} + export cc_path="g++" + export cc_flags=" -g " + export log_sync=0 + export log_v=0 + export PYTHONIOENCODING=utf8 + export PYTHONPATH=`pwd`/python + export nvcc_path= + python3.7 -c "import jittor" + python3.7 -m jittor.test -v \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..375ae76f --- /dev/null +++ b/.gitignore @@ -0,0 +1,30 @@ +my +.refresh +.DS_Store +__pycache__ +.ipynb_checkpoints/ +.vscode/ +__res/ +perf.data +perf.data.old +*.swp +*.ipynb +*.pdf +*.zip +*.tgz +*.obj +test.py +extern/mkl/mkldnn_lnx*/* +data/ +build/ +venv/ +*.md +!*.src.md +!README.md +!README.cn.md +!CHANGELOG.md +python/jittor.egg-info +dist/ +!doc/source/* +core +__data__ diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100644 index 00000000..05dd4ad8 --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,46 @@ +test_clang_8_cuda_10: + tags: + - clang + - cuda + script: + - export cache_name=$CI_COMMIT_REF_NAME + - export cc_path="clang-8" + - export cc_flags=" -g " + - export log_sync=0 + - export log_v=0 + - export PYTHONIOENCODING=utf8 + - export PYTHONPATH=`pwd`/python + - export nvcc_path=/usr/local/cuda/bin/nvcc + - python3.7 -c "import jittor" + - python3.7 -m jittor.test -v + +# test_icc_19: +# tags: +# - icc +# script: +# - export cache_name=$CI_COMMIT_REF_NAME +# - export cc_path="/opt/intel/system_studio_2019/bin/icc" +# - export cc_flags=" -g " +# - export log_sync=0 +# - export log_v=0 +# - export PYTHONIOENCODING=utf8 +# - export PYTHONPATH=`pwd`/python +# - export LD_LIBRARY_PATH="/opt/intel/system_studio_2019/compilers_and_libraries/linux/lib/intel64" +# - python3.7 -c "import jittor" +# - python3.7 -m jittor.test -v + +test_g++: + tags: + - gcc + script: + - export cache_name=$CI_COMMIT_REF_NAME + - export cc_path="g++" + - export cc_flags=" -g " + - export log_sync=0 + - export log_v=0 + - export PYTHONIOENCODING=utf8 + - export PYTHONPATH=`pwd`/python + - export nvcc_path= + - python3.7 -c "import jittor" + - python3.7 -m jittor.test -v + diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..bdb86cc3 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,75 @@ +# CHANGELOG + +### 计图 1.1.5.5 + +* 新增numpy code算子,现在可以直接使用numpy来自定义算子了,使用用例: + +```python +import jittor as jt + +def forward_code(np, data): + a = data["inputs"][0] + b = data["outputs"][0] + np.add(a,a,out=b) + +def backward_code(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + +a = jt.random((5,1)) +b = jt.numpy_code( + a.shape, + a.dtype, + [a], + forward_code, + [backward_code], +) +``` + +* 新增 Function 模块,用户可以自定义反向传播了,使用用例: + +```python +import jittor as jt +from jittor import Function + +class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x +a = jt.array(3.0) +b = jt.array(4.0) +func = MyFunc() +c,d = func(a, b) +da, db = jt.grad(c+d*3, [a, b]) +assert da.data == 4 +assert db.data == 9 +``` + +* 新增 no_grad scope, 在这个scope中创建的所有变量都会停止梯度: + +```python +import jittor as jt + +with jt.no_grad(): + ... +``` + +* 新增 bmm(batch matrix multiply) 支持: + +```python +import jittor as jt +from jittor import nn + +batch, n, m, k = 100, 5, 6, 7 + +a = jt.random((batch, n, m)) +b = jt.random((batch, m, k)) +c = nn.bmm(a, b) +``` + +* 修复 unsqueeze \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..2e067946 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,50 @@ +# docker build commands +ARG FROM_IMAGE=ubuntu:18.04 + +FROM ${FROM_IMAGE} + +RUN apt update && apt install ca-certificates -y + +# change tsinghua mirror +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +RUN apt update && apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential openssh-server -y + +WORKDIR /usr/src/jittor + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +ENV PYTHONIOENCODING utf8 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install \ + numpy \ + tqdm \ + pillow \ + astunparse \ + notebook + +RUN pip3 install matplotlib + +RUN apt install openmpi-bin openmpi-common libopenmpi-dev -y + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example + +RUN pip3 uninstall jittor -y + +COPY . . + +RUN pip3 install . --timeout 100 + +RUN python3.7 -m jittor.test.test_example + +CMD python3.7 -m jittor.notebook --allow-root --ip=0.0.0.0 \ No newline at end of file diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 00000000..f6436a57 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,203 @@ +Copyright (c) 2023 Jittor. All Rights Reserved + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright (c) 2023 Jittor. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..a0d50c8e --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,5 @@ +exclude __data__ +exclude __pycache__ +prune **/__data__/ +prune **/__pycache__ +prune *.pyc \ No newline at end of file diff --git a/README.cn.md b/README.cn.md new file mode 100644 index 00000000..27947fcc --- /dev/null +++ b/README.cn.md @@ -0,0 +1,422 @@ +# Jittor: 即时编译深度学习框架 + +![Jittor Logo](https://cg.cs.tsinghua.edu.cn/jittor/favicon_package_v0/JittorLogo_Final1220.svg) + + +[快速开始](#快速开始) | [安装](#安装) | [教程](#教程) | [English](./README.md) + + +Jittor 是一个基于即时编译和元算子的高性能深度学习框架,整个框架在即时编译的同时,还集成了强大的Op编译器和调优器,为您的模型生成定制化的高性能代码。Jittor还包含了丰富的高性能模型库,涵盖范围包括:图像识别,检测,分割,生成,可微渲染,几何学习,强化学习等等。 + + +Jittor前端语言为Python。前端使用了模块化和动态图执行的设计,这是目前最主流的深度学习框架接口设计。后端则使用高性能语言编写,如CUDA,C++。 + + +相关链接: +* [Jittor官网](https://cg.cs.tsinghua.edu.cn/jittor/) +* [Jittor教程](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/) +* [Jittor模型库](https://cg.cs.tsinghua.edu.cn/jittor/resources/) +* [Jittor文档](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html) +* [Github](https://github.com/jittor/jittor), [GitLink](https://www.gitlink.org.cn/jittor/jittor), [Gitee](https://gitee.com/jittor/jittor) +* [Jittor 论坛](https://discuss.jittor.org/) +* [Jittor 精选仓库](https://github.com/Jittor/jittor/blob/master/AWESOME-JITTOR-LIST.md) +* 即时通信: QQ Group(761222083) + + + +下面的代码演示了如何一步一步使用Python代码,从头对一个双层神经网络建模。 + +```python +import jittor as jt +from jittor import Module +from jittor import nn +import numpy as np + +class Model(Module): + def __init__(self): + self.layer1 = nn.Linear(1, 10) + self.relu = nn.Relu() + self.layer2 = nn.Linear(10, 1) + def execute (self,x) : + x = self.layer1(x) + x = self.relu(x) + x = self.layer2(x) + return x + +def get_data(n): # generate random data for training test. + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + +learning_rate = 0.1 +batch_size = 50 +n = 1000 + +model = Model() +optim = nn.SGD(model.parameters(), learning_rate) + +for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x) + dy = pred_y - y + loss = dy * dy + loss_mean = loss.mean() + optim.step(loss_mean) + print(f"step {i}, loss = {loss_mean.data.sum()}") +``` + + + +## 大纲 + +- [快速开始](#快速开始) +- [安装](#安装) +- [教程](#教程) +- [贡献](#贡献) +- [团队](#团队) +- [版权声明](#版权声明) + + +## 快速开始 + + +我们提供了一些jupyter notebooks来帮助您快速入门Jittor。 + +- [示例:模型定义与训练][1] +- [基础:Op, Var][2] +- [元算子:通过元算子实现自己的卷积层][3] + + +## 安装 + +Jittor框架对环境要求如下: + + +| OS | CPU | Python | Compiler | (Optional) GPU platform | +|--------------------------------------------------------|-------------------------------------|--------|--------------|---------------------------------------------| +| Linux
(Ubuntu, CentOS, Arch,
UOS, KylinOS, ...) | x86
x86_64
ARM
loongson | >= 3.7 | g++ >=5.4 | Nvidia CUDA >= 10.0, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar)
or [AMD ROCm](https://docs.amd.com/) >= 4.0
or [Hygon DCU DTK](https://tycloud.hpccube.com/doc/1.0.6/11277/general-handbook/software-tutorial/jittor.html) >= 22.04 | +| macOS
(>= 10.14 Mojave) | intel
Apple Silicon | >= 3.7 | clang >= 8.0 | - | +| Windows 10 & 11 | x86_64 | [>= 3.8](https://www.python.org/downloads/windows/) | - | Nvidia CUDA >= 10.2 [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-windows) | + +Jittor 提供了三种安装方法:pip、docker和手动安装: + + +## Pip 安装 + + +下面将展示Ubuntu的安装命令,如果您在使用其他Linux操作系统(如CentOS), 请安装好依赖(Python>=3.7, g++>=5.4)或者使用**docker安装**, 如果您已经装好编译器和对应版本的Python,我们强烈推荐您使用这种方法 +(如果无法访问github, 可以通过Jittor主页下载): + +```bash +sudo apt install python3.7-dev libomp-dev +python3.7 -m pip install jittor +# or install from github(latest version) +# python3.7 -m pip install git+https://github.com/Jittor/jittor.git +python3.7 -m jittor.test.test_example +``` + +如果测试运行通过,恭喜你已经安装完成. +jittor会自动在路径中寻找合适的编译器, 如果您希望手动指定编译器, 请使用环境变量 `cc_path` 和 `nvcc_path`(可选). + +### macOS 安装 + + +macOS 请使用 [homebrew](https://brew.sh) 安装额外的依赖。 + + +```bash +brew install libomp +``` + +之后您可以通过 pip 安装 jittor,并测试是否可以成功运行。 + + +```bash +python3.7 -m pip install jittor +python3.7 -m jittor.test.test_example +``` + +目前在 macOS 中,jittor 只支持 CPU 计算。 + + +### Windows安装 + + +Windows 请准备好Python>=3.8,安装方法如下(conda安装需要额外命令): + +Windows user please prepare Python>=3.8, install instructions are list below(conda needs extra instructions): + +```bash +# check your python version(>=3.8) +python --version +python -m pip install jittor +# if conda is used +conda install pywin32 +``` + +Windows 下,jittor会自动检测显卡并安装对应的 CUDA, 请确保您的NVIDIA驱动支持CUDA 10.2 以上,您还可以使用如下命令手动为Jittor安装CUDA: + + +```bash +python -m jittor_utils.install_cuda +``` + + + +## Docker 安装 + +我们提供了Docker安装方式,免去您配置环境,Docker安装方法如下: + + +``` +# CPU only(Linux) +docker run -it --network host jittor/jittor +# CPU and CUDA(Linux) +docker run -it --network host --gpus all jittor/jittor-cuda +# CPU only(Mac and Windows) +docker run -it -p 8888:8888 jittor/jittor +``` + +关于Docker安装的详细教程,可以参考[Windows/Mac/Linux通过Docker安装计图](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/) + +## 手动安装 + + +我们将逐步演示如何在Ubuntu 16.04中安装Jittor,其他Linux发行版可能可以使用类似的命令。 + + +### 步骤一:选择您的后端编译器 + +```bash +# g++ +sudo apt install g++ build-essential libomp-dev + +# OR clang++-8 +wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install_llvm.sh > /tmp/llvm.sh +bash /tmp/llvm.sh 8 +``` + +### 步骤二:安装Python和python-dev + + +Jittor需要python的版本>=3.7。 + +```bash +sudo apt install python3.7 python3.7-dev +``` + + +### 步骤三:运行Jittor + + +整个框架是即时编译的。 让我们通过pip安装jittor + +```bash +git clone https://github.com/Jittor/jittor.git +sudo pip3.7 install ./jittor +export cc_path="clang++-8" +# if other compiler is used, change cc_path +# export cc_path="g++" +# export cc_path="icc" + +# run a simple test +python3.7 -m jittor.test.test_example +``` + +如果通过了测试,那么您的Jittor已经准备就绪。 + + +### 可选步骤四:启用CUDA + + +在Jittor中使用CUDA非常简单,只需设置环境值`nvcc_path` + +```bash +# replace this var with your nvcc location +export nvcc_path="/usr/local/cuda/bin/nvcc" +# run a simple cuda test +python3.7 -m jittor.test.test_cuda +``` + +如果测试通过,则可以通过设置`use_cuda`标识符在Jittor中启用CUDA。 + +```python +import jittor as jt +jt.flags.use_cuda = 1 +``` + + +### 可选步骤五:测试训练Resnet18 + + +要检查Jittor的完整性,您可以运行Resnet18训练测试。需要注意的是,这个测试需要6G显存。 + +```bash +python3.7 -m jittor.test.test_resnet +``` + +如果这些测试失败,请为我们报告错误,我们十分欢迎您为Jittor做出贡献^ _ ^ + + +## 教程 + + +在教程部分,我们将简要解释Jittor的基本概念。 + + +要使用Jittor训练模型,您需要了解两个主要概念: + +* Var:Jittor的基本数据类型 +* Operations:Jittor的算子与numpy类似 + + +### 数据类型 + + +首先,让我们开始使用Var。Var是jittor的基本数据类型,为了运算更加高效Jittor中的计算过程是异步的。 如果要访问数据,可以使用`Var.data`进行同步数据访问。 + +```python +import jittor as jt +a = jt.float32([1,2,3]) +print (a) +print (a.data) +# Output: float32[3,] +# Output: [ 1. 2. 3.] +``` + + +此外我们可以给变量起一个名字。 + +```python +a.name('a') +print(a.name()) +# Output: a +``` + + +### 数据运算 + + + Jittor的算子与numpy类似。 让我们尝试一些运算, 我们通过Op`jt.float32`创建Var `a`和`b`,并将它们相加。 输出这些变量相关信息,可以看出它们具有相同的形状和类型。 + +```python +import jittor as jt +a = jt.float32([1,2,3]) +b = jt.float32([4,5,6]) +c = a*b +print(a,b,c) +print(type(a), type(b), type(c)) +# Output: float32[3,] float32[3,] float32[3,] +# Output: +``` + +除此之外,我们使用的所有算子`jt.xxx(Var,...)`都具有别名`Var.xxx(...)`。 例如: + +```python +c.max() # alias of jt.max(c) +c.add(a) # alias of jt.add(c, a) +c.min(keepdims=True) # alias of jt.min(c, keepdims=True) +``` + + +如果您想知道Jittor支持的所有运算,可以运行`help(jt.ops)`。 您在`jt.ops.xxx`中找到的所有运算都可以通过别名`jt.xxx`。 + +```python +help(jt.ops) +# Output: +# abs(x: core.Var) -> core.Var +# add(x: core.Var, y: core.Var) -> core.Var +# array(data: array) -> core.Var +# binary(x: core.Var, y: core.Var, op: str) -> core.Var +# ...... +``` + +### 更多教程 + + +如果您想进一步了解Jittor,请查看以下notebooks: + +* 快速开始 + * [示例:模型定义与训练][1] + * [基本概念:Op, Var][2] + * [元算子:通过元算子实现自己的卷积层][3] +* 进阶 + * [自定义算子:使用C ++和CUDA编写您的算子,并其进行即时编译][4] + * [性能分析器:分析您的模型][5] + * Jtune:性能调优工具 + + + +[1]: python/jittor/notebook/example.src.md "示例" +[2]: python/jittor/notebook/basics.src.md "基本概念" +[3]: python/jittor/notebook/meta_op.src.md "元算子" +[4]: python/jittor/notebook/custom_op.src.md "自定义算子" +[5]: python/jittor/notebook/profiler.src.md "性能分析器" + + +这些notebooks可以通过python3.7 -m jittor.notebook在您自己的计算机中运行。 + + +## 贡献 + + +Jittor还很年轻。 它可能存在错误和问题。 请在我们的错误跟踪系统中报告它们。 我们欢迎您为Jittor做出贡献。 此外,如果您对Jittor有任何想法,请告诉我们。 + +您可以用以下方式帮助Jittor: + +* 在论文中引用 Jittor +* 向身边的好朋友推荐 Jittor +* 贡献代码 +* 贡献教程和文档 +* 提出issue +* 回答 jittor 相关问题 +* 点亮小星星 +* 持续关注 jittor +* …… + + + + +## 联系我们 + +官方主页: http://cg.cs.tsinghua.edu.cn/jittor/ + +电子邮件:jittor@qq.com + +提出issue:https://github.com/Jittor/jittor/issues + + + + + +QQ 群:761222083 + + + +## 团队 + + +Jittor目前由[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们! + + +## 引用 + +``` +@article{hu2020jittor, + title={Jittor: a novel deep learning framework with meta-operators and unified graph execution}, + author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang}, + journal={Science China Information Sciences}, + volume={63}, + number={222103}, + pages={1--21}, + year={2020} +} +``` + + +## 版权声明 + + +如LICENSE.txt文件中所示,Jittor使用Apache 2.0版权协议。 + diff --git a/README.md b/README.md new file mode 100644 index 00000000..90601b43 --- /dev/null +++ b/README.md @@ -0,0 +1,416 @@ +# Jittor: a Just-in-time(JIT) deep learning framework + +![Jittor Logo](https://cg.cs.tsinghua.edu.cn/jittor/favicon_package_v0/JittorLogo_Final1220.svg) + +[Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [简体中文](./README.cn.md) + + +Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators. The whole framework and meta-operators are compiled just-in-time. A powerful op compiler and tuner are integrated into Jittor. It allowed us to generate high-performance code with specialized for your model. Jittor also contains a wealth of high-performance model libraries, including: image recognition, detection, segmentation, generation, differentiable rendering, geometric learning, reinforcement learning, etc. . + + +The front-end language is Python. Module Design and Dynamic Graph Execution is used in the front-end, which is the most popular design for deeplearning framework interface. The back-end is implemented by high performance language, such as CUDA,C++. + + +Related Links: +* [Jittor Website](https://cg.cs.tsinghua.edu.cn/jittor/) +* [Jittor Tutorials](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/) +* [Jittor Models](https://cg.cs.tsinghua.edu.cn/jittor/resources/) +* [Jittor Documents](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html) +* [Github](https://github.com/jittor/jittor), [GitLink](https://www.gitlink.org.cn/jittor/jittor), [Gitee](https://gitee.com/jittor/jittor) +* [Jittor Forum](https://discuss.jittor.org/) +* [Awesome Jittor List](https://github.com/Jittor/jittor/blob/master/AWESOME-JITTOR-LIST.md) +* IM: QQ Group(761222083) + + + +The following example shows how to model a two-layer neural network step by step and train from scratch In a few lines of Python code. + + +```python +import jittor as jt +from jittor import Module +from jittor import nn +import numpy as np + +class Model(Module): + def __init__(self): + self.layer1 = nn.Linear(1, 10) + self.relu = nn.Relu() + self.layer2 = nn.Linear(10, 1) + def execute (self,x) : + x = self.layer1(x) + x = self.relu(x) + x = self.layer2(x) + return x + +def get_data(n): # generate random data for training test. + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + +learning_rate = 0.1 +batch_size = 50 +n = 1000 + +model = Model() +optim = nn.SGD(model.parameters(), learning_rate) + +for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x) + dy = pred_y - y + loss = dy * dy + loss_mean = loss.mean() + optim.step(loss_mean) + print(f"step {i}, loss = {loss_mean.data.sum()}") +``` + +## Contents + +* [Quickstart](#quickstart) +* [Install](#install) +* [Tutorial](#tutorial) +* [Contributing](#contributing) +* [The Team](#theteam) +* [License](#license) + + + +## Quickstart + + +We provide some jupyter notebooks to help you quick start with Jittor. + + +- [Example: Model definition and training][1] +- [Basics: Op, Var][2] +- [Meta-operator: Implement your own convolution with Meta-operator][3] + +## Install + + + +Jittor environment requirements: + +| OS | CPU | Python | Compiler | (Optional) GPU platform | +|--------------------------------------------------------|-------------------------------------|--------|--------------|---------------------------------------------| +| Linux
(Ubuntu, CentOS, Arch,
UOS, KylinOS, ...) | x86
x86_64
ARM
loongson | >= 3.7 | g++ >=5.4 | Nvidia CUDA >= 10.0, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar)
or [AMD ROCm](https://docs.amd.com/) >= 4.0
or [Hygon DCU DTK](https://tycloud.hpccube.com/doc/1.0.6/11277/general-handbook/software-tutorial/jittor.html) >= 22.04 | +| macOS
(>= 10.14 Mojave) | intel
Apple Silicon | >= 3.7 | clang >= 8.0 | - | +| Windows 10 & 11 | x86_64 | [>= 3.8](https://www.python.org/downloads/windows/) | - | Nvidia CUDA >= 10.2 [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-windows) | + + +Jittor offers three ways to install: pip, docker, or manual. + + +## Pip install + + +```bash +sudo apt install python3.7-dev libomp-dev +python3.7 -m pip install jittor +# or install from github(latest version) +# python3.7 -m pip install git+https://github.com/Jittor/jittor.git +python3.7 -m jittor.test.test_example +``` + + + +### macOS install + + +Please first install additional dependencies with [homebrew](https://brew.sh). + +```bash +brew install libomp +``` + + +Then you can install jittor through pip and run the example. + +```bash +python3.7 -m pip install jittor +python3.7 -m jittor.test.test_example +``` + + +Currently jittor only supports CPU on macOS. + + +### Windows install + + + +```bash +# check your python version(>=3.8) +python --version +python -m pip install jittor +# if conda is used +conda install pywin32 +``` + + +In Windows, jittor will automatically detect and install CUDA, please make sure your NVIDIA driver support CUDA 10.2 or above, or you can manually let jittor install CUDA for you: + +```bash +python -m jittor_utils.install_cuda +``` + + +## Docker Install + + + +We provide a Docker installation method to save you from configuring the environment. The Docker installation method is as follows: + +``` +# CPU only(Linux) +docker run -it --network host jittor/jittor +# CPU and CUDA(Linux) +docker run -it --network host --gpus all jittor/jittor-cuda +# CPU only(Mac and Windows) +docker run -it -p 8888:8888 jittor/jittor +``` + + +## manual install + +We will show how to install Jittor in Ubuntu 16.04 step by step, Other Linux distributions may have similar commands. + + +### Step 1: Choose your back-end compiler + + +```bash +# g++ +sudo apt install g++ build-essential libomp-dev + +# OR clang++-8 +wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install_llvm.sh > /tmp/llvm.sh +bash /tmp/llvm.sh 8 +``` +### Step 2: Install Python and python-dev + + +Jittor need python version >= 3.7. + + +```bash +sudo apt install python3.7 python3.7-dev +``` + +### Step 3: Run Jittor + + +The whole framework is compiled Just-in-time. Let's install jittor via pip + + +```bash +git clone https://github.com/Jittor/jittor.git +sudo pip3.7 install ./jittor +export cc_path="clang++-8" +# if other compiler is used, change cc_path +# export cc_path="g++" +# export cc_path="icc" + +# run a simple test +python3.7 -m jittor.test.test_example +``` +if the test is passed, your Jittor is ready. + + +### Optional Step 4: Enable CUDA + + +Using CUDA in Jittor is very simple, Just setup environment value `nvcc_path` + + +```bash +# replace this var with your nvcc location +export nvcc_path="/usr/local/cuda/bin/nvcc" +# run a simple cuda test +python3.7 -m jittor.test.test_cuda +``` +if the test is passed, your can use Jittor with CUDA by setting `use_cuda` flag. + + +```python +import jittor as jt +jt.flags.use_cuda = 1 +``` + +### Optional Step 5: Test Resnet18 training + + +To check the integrity of Jittor, you can run Resnet18 training test. Note: 6G GPU RAM is requires in this test. + + +```bash +python3.7 -m jittor.test.test_resnet +``` +if those tests are failed, please report bugs for us, and feel free to contribute ^_^ + + +## Tutorial + + +In the tutorial section, we will briefly explain the basic concept of Jittor. + + +To train your model with Jittor, there are only three main concepts you need to know: + + +* Var: basic data type of jittor +* Operations: Jittor'op is simular with numpy + +### Var + + +First, let's get started with Var. Var is the basic data type of jittor. Computation process in Jittor is asynchronous for optimization. If you want to access the data, `Var.data` can be used for synchronous data accessing. + + +```python +import jittor as jt +a = jt.float32([1,2,3]) +print (a) +print (a.data) +# Output: float32[3,] +# Output: [ 1. 2. 3.] +``` + +And we can give the variable a name. + + +```python +a.name('a') +print(a.name()) +# Output: a +``` + +### Operations + + +Jittor'op is simular with numpy. Let's try some operations. We create Var `a` and `b` via operation `jt.float32`, and add them. Printing those variables shows they have the same shape and dtype. + + +```python +import jittor as jt +a = jt.float32([1,2,3]) +b = jt.float32([4,5,6]) +c = a*b +print(a,b,c) +print(type(a), type(b), type(c)) +# Output: float32[3,] float32[3,] float32[3,] +# Output: +``` +Beside that, All the operators we used `jt.xxx(Var, ...)` have alias `Var.xxx(...)`. For example: + + +```python +c.max() # alias of jt.max(c) +c.add(a) # alias of jt.add(c, a) +c.min(keepdims=True) # alias of jt.min(c, keepdims=True) +``` + +if you want to know all the operation which Jittor supports. try `help(jt.ops)`. All the operation you found in `jt.ops.xxx`, can be used via alias `jt.xxx`. + + +```python +help(jt.ops) +# Output: +# abs(x: core.Var) -> core.Var +# add(x: core.Var, y: core.Var) -> core.Var +# array(data: array) -> core.Var +# binary(x: core.Var, y: core.Var, op: str) -> core.Var +# ...... +``` +### More + + +If you want to know more about Jittor, please check out the notebooks below: + + +* Quickstart + - [Example: Model definition and training][1] + - [Basics: Op, Var][2] + - [Meta-operator: Implement your own convolution with Meta-operator][3] +* Advanced + - [Custom Op: write your operator with C++ and CUDA and JIT compile it][4] + - [Profiler: Profiling your model][5] + - Jtune: Tool for performance tuning + + + +[1]: python/jittor/notebook/example.src.md "example" +[2]: python/jittor/notebook/basics.src.md "basics" +[3]: python/jittor/notebook/meta_op.src.md "meta_op" +[4]: python/jittor/notebook/custom_op.src.md "custom_op" +[5]: python/jittor/notebook/profiler.src.md "profiler" + +Those notebooks can be started in your own computer by `python3.7 -m jittor.notebook` + + +## Contributing + + +Jittor is still young. It may contain bugs and issues. Please report them in our bug track system. Contributions are welcome. Besides, if you have any ideas about Jittor, please let us know. + + + + +You can help Jittor in the following ways: + +* Citing Jittor in your paper +* recommend Jittor to your friends +* Contributing code +* Contributed tutorials and documentation +* File an issue +* Answer jittor related questions +* Light up the stars +* Keep an eye on jittor +* ...... + +## Contact Us + + + + + +Website: http://cg.cs.tsinghua.edu.cn/jittor/ + +Email: jittor@qq.com + +File an issue: https://github.com/Jittor/jittor/issues + +QQ Group: 836860279 + + + + +## The Team + + +Jittor is currently maintained by the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us! + + +## Citation + + +``` +@article{hu2020jittor, + title={Jittor: a novel deep learning framework with meta-operators and unified graph execution}, + author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang}, + journal={Science China Information Sciences}, + volume={63}, + number={222103}, + pages={1--21}, + year={2020} +} +``` + +## License + + +Jittor is Apache 2.0 licensed, as found in the LICENSE.txt file. + + diff --git a/README.src.md b/README.src.md new file mode 100644 index 00000000..12b81077 --- /dev/null +++ b/README.src.md @@ -0,0 +1,524 @@ +# Jittor: a Just-in-time(JIT) deep learning framework +# Jittor: 即时编译深度学习框架 + +![Jittor Logo](https://cg.cs.tsinghua.edu.cn/jittor/favicon_package_v0/JittorLogo_Final1220.svg) + +[Quickstart](#quickstart) | [Install](#install) | [Tutorial](#tutorial) | [Chinese](./README.cn.md) + +[快速开始](#快速开始) | [安装](#安装) | [教程](#教程) | [English](./README.md) + +Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators. The whole framework and meta-operators are compiled just-in-time. A powerful op compiler and tuner are integrated into Jittor. It allowed us to generate high-performance code with specialized for your model. Jittor also contains a wealth of high-performance model libraries, including: image recognition, detection, segmentation, generation, differentiable rendering, geometric learning, reinforcement learning, etc. . + +Jittor 是一个基于即时编译和元算子的高性能深度学习框架,整个框架在即时编译的同时,还集成了强大的Op编译器和调优器,为您的模型生成定制化的高性能代码。Jittor还包含了丰富的高性能模型库,涵盖范围包括:图像识别,检测,分割,生成,可微渲染,几何学习,强化学习等等。 + +The front-end language is Python. Module Design and Dynamic Graph Execution is used in the front-end, which is the most popular design for deeplearning framework interface. The back-end is implemented by high performance language, such as CUDA,C++. + +Jittor前端语言为Python。前端使用了模块化和动态图执行的设计,这是目前最主流的深度学习框架接口设计。后端则使用高性能语言编写,如CUDA,C++。 + +Related Links: +* [Jittor Website](https://cg.cs.tsinghua.edu.cn/jittor/) +* [Jittor Tutorials](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/) +* [Jittor Models](https://cg.cs.tsinghua.edu.cn/jittor/resources/) +* [Jittor Documents](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html) +* [Github](https://github.com/jittor/jittor), [GitLink](https://www.gitlink.org.cn/jittor/jittor), [Gitee](https://gitee.com/jittor/jittor) +* [Jittor Forum](https://discuss.jittor.org/) +* [Awesome Jittor List](https://github.com/Jittor/jittor/blob/master/AWESOME-JITTOR-LIST.md) +* IM: QQ Group(761222083) + +相关链接: +* [Jittor官网](https://cg.cs.tsinghua.edu.cn/jittor/) +* [Jittor教程](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/) +* [Jittor模型库](https://cg.cs.tsinghua.edu.cn/jittor/resources/) +* [Jittor文档](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html) +* [Github](https://github.com/jittor/jittor), [GitLink](https://www.gitlink.org.cn/jittor/jittor), [Gitee](https://gitee.com/jittor/jittor) +* [Jittor 论坛](https://discuss.jittor.org/) +* [Jittor 精选仓库](https://github.com/Jittor/jittor/blob/master/AWESOME-JITTOR-LIST.md) +* 即时通信: QQ Group(761222083) + + +The following example shows how to model a two-layer neural network step by step and train from scratch In a few lines of Python code. + +下面的代码演示了如何一步一步使用Python代码,从头对一个双层神经网络建模。 + +```python +import jittor as jt +from jittor import Module +from jittor import nn +import numpy as np + +class Model(Module): + def __init__(self): + self.layer1 = nn.Linear(1, 10) + self.relu = nn.Relu() + self.layer2 = nn.Linear(10, 1) + def execute (self,x) : + x = self.layer1(x) + x = self.relu(x) + x = self.layer2(x) + return x + +def get_data(n): # generate random data for training test. + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + +learning_rate = 0.1 +batch_size = 50 +n = 1000 + +model = Model() +optim = nn.SGD(model.parameters(), learning_rate) + +for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x) + dy = pred_y - y + loss = dy * dy + loss_mean = loss.mean() + optim.step(loss_mean) + print(f"step {i}, loss = {loss_mean.data.sum()}") +``` + +## Contents + +* [Quickstart](#quickstart) +* [Install](#install) +* [Tutorial](#tutorial) +* [Contributing](#contributing) +* [The Team](#theteam) +* [License](#license) + +## 大纲 + +- [快速开始](#快速开始) +- [安装](#安装) +- [教程](#教程) +- [贡献](#贡献) +- [团队](#团队) +- [版权声明](#版权声明) + +## Quickstart + +## 快速开始 + +We provide some jupyter notebooks to help you quick start with Jittor. + +我们提供了一些jupyter notebooks来帮助您快速入门Jittor。 + +- [Example: Model definition and training][1] +- [示例:模型定义与训练][1] +- [Basics: Op, Var][2] +- [基础:Op, Var][2] +- [Meta-operator: Implement your own convolution with Meta-operator][3] +- [元算子:通过元算子实现自己的卷积层][3] + +## Install + +## 安装 + +Jittor框架对环境要求如下: + +Jittor environment requirements: + +| OS | CPU | Python | Compiler | (Optional) GPU platform | +|--------------------------------------------------------|-------------------------------------|--------|--------------|---------------------------------------------| +| Linux
(Ubuntu, CentOS, Arch,
UOS, KylinOS, ...) | x86
x86_64
ARM
loongson | >= 3.7 | g++ >=5.4 | Nvidia CUDA >= 10.0, [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar)
or [AMD ROCm](https://docs.amd.com/) >= 4.0
or [Hygon DCU DTK](https://tycloud.hpccube.com/doc/1.0.6/11277/general-handbook/software-tutorial/jittor.html) >= 22.04 | +| macOS
(>= 10.14 Mojave) | intel
Apple Silicon | >= 3.7 | clang >= 8.0 | - | +| Windows 10 & 11 | x86_64 | [>= 3.8](https://www.python.org/downloads/windows/) | - | Nvidia CUDA >= 10.2 [cuDNN](https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#install-windows) | + +Jittor 提供了三种安装方法:pip、docker和手动安装: + +Jittor offers three ways to install: pip, docker, or manual. + +## Pip 安装 + +## Pip install + +下面将展示Ubuntu的安装命令,如果您在使用其他Linux操作系统(如CentOS), 请安装好依赖(Python>=3.7, g++>=5.4)或者使用**docker安装**, 如果您已经装好编译器和对应版本的Python,我们强烈推荐您使用这种方法 +(如果无法访问github, 可以通过Jittor主页下载): + +```bash +sudo apt install python3.7-dev libomp-dev +python3.7 -m pip install jittor +# or install from github(latest version) +# python3.7 -m pip install git+https://github.com/Jittor/jittor.git +python3.7 -m jittor.test.test_example +``` + +如果测试运行通过,恭喜你已经安装完成. +jittor会自动在路径中寻找合适的编译器, 如果您希望手动指定编译器, 请使用环境变量 `cc_path` 和 `nvcc_path`(可选). + +### macOS 安装 + +### macOS install + +macOS 请使用 [homebrew](https://brew.sh) 安装额外的依赖。 + +Please first install additional dependencies with [homebrew](https://brew.sh). + +```bash +brew install libomp +``` + +之后您可以通过 pip 安装 jittor,并测试是否可以成功运行。 + +Then you can install jittor through pip and run the example. + +```bash +python3.7 -m pip install jittor +python3.7 -m jittor.test.test_example +``` + +目前在 macOS 中,jittor 只支持 CPU 计算。 + +Currently jittor only supports CPU on macOS. + +### Windows安装 + +### Windows install + +Windows 请准备好Python>=3.8,安装方法如下(conda安装需要额外命令): + +Windows user please prepare Python>=3.8, install instructions are list below(conda needs extra instructions): + +```bash +# check your python version(>=3.8) +python --version +python -m pip install jittor +# if conda is used +conda install pywin32 +``` + +Windows 下,jittor会自动检测显卡并安装对应的 CUDA, 请确保您的NVIDIA驱动支持CUDA 10.2 以上,您还可以使用如下命令手动为Jittor安装CUDA: + +In Windows, jittor will automatically detect and install CUDA, please make sure your NVIDIA driver support CUDA 10.2 or above, or you can manually let jittor install CUDA for you: + +```bash +python -m jittor_utils.install_cuda +``` + + +## Docker Install + +## Docker 安装 + +我们提供了Docker安装方式,免去您配置环境,Docker安装方法如下: + +We provide a Docker installation method to save you from configuring the environment. The Docker installation method is as follows: + +``` +# CPU only(Linux) +docker run -it --network host jittor/jittor +# CPU and CUDA(Linux) +docker run -it --network host --gpus all jittor/jittor-cuda +# CPU only(Mac and Windows) +docker run -it -p 8888:8888 jittor/jittor +``` + +关于Docker安装的详细教程,可以参考[Windows/Mac/Linux通过Docker安装计图](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/) + +## 手动安装 +## manual install + +We will show how to install Jittor in Ubuntu 16.04 step by step, Other Linux distributions may have similar commands. + +我们将逐步演示如何在Ubuntu 16.04中安装Jittor,其他Linux发行版可能可以使用类似的命令。 + +### Step 1: Choose your back-end compiler + +### 步骤一:选择您的后端编译器 + +```bash +# g++ +sudo apt install g++ build-essential libomp-dev + +# OR clang++-8 +wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install_llvm.sh > /tmp/llvm.sh +bash /tmp/llvm.sh 8 +``` +### Step 2: Install Python and python-dev + +### 步骤二:安装Python和python-dev + +Jittor need python version >= 3.7. + +Jittor需要python的版本>=3.7。 + +```bash +sudo apt install python3.7 python3.7-dev +``` + +### Step 3: Run Jittor + +### 步骤三:运行Jittor + +The whole framework is compiled Just-in-time. Let's install jittor via pip + +整个框架是即时编译的。 让我们通过pip安装jittor + +```bash +git clone https://github.com/Jittor/jittor.git +sudo pip3.7 install ./jittor +export cc_path="clang++-8" +# if other compiler is used, change cc_path +# export cc_path="g++" +# export cc_path="icc" + +# run a simple test +python3.7 -m jittor.test.test_example +``` +if the test is passed, your Jittor is ready. + +如果通过了测试,那么您的Jittor已经准备就绪。 + +### Optional Step 4: Enable CUDA + +### 可选步骤四:启用CUDA + +Using CUDA in Jittor is very simple, Just setup environment value `nvcc_path` + +在Jittor中使用CUDA非常简单,只需设置环境值`nvcc_path` + +```bash +# replace this var with your nvcc location +export nvcc_path="/usr/local/cuda/bin/nvcc" +# run a simple cuda test +python3.7 -m jittor.test.test_cuda +``` +if the test is passed, your can use Jittor with CUDA by setting `use_cuda` flag. + +如果测试通过,则可以通过设置`use_cuda`标识符在Jittor中启用CUDA。 + +```python +import jittor as jt +jt.flags.use_cuda = 1 +``` + +### Optional Step 5: Test Resnet18 training + +### 可选步骤五:测试训练Resnet18 + +To check the integrity of Jittor, you can run Resnet18 training test. Note: 6G GPU RAM is requires in this test. + +要检查Jittor的完整性,您可以运行Resnet18训练测试。需要注意的是,这个测试需要6G显存。 + +```bash +python3.7 -m jittor.test.test_resnet +``` +if those tests are failed, please report bugs for us, and feel free to contribute ^_^ + +如果这些测试失败,请为我们报告错误,我们十分欢迎您为Jittor做出贡献^ _ ^ + +## Tutorial + +## 教程 + +In the tutorial section, we will briefly explain the basic concept of Jittor. + +在教程部分,我们将简要解释Jittor的基本概念。 + +To train your model with Jittor, there are only three main concepts you need to know: + +要使用Jittor训练模型,您需要了解两个主要概念: + +* Var: basic data type of jittor +* Var:Jittor的基本数据类型 +* Operations: Jittor'op is simular with numpy +* Operations:Jittor的算子与numpy类似 + +### Var + +### 数据类型 + +First, let's get started with Var. Var is the basic data type of jittor. Computation process in Jittor is asynchronous for optimization. If you want to access the data, `Var.data` can be used for synchronous data accessing. + +首先,让我们开始使用Var。Var是jittor的基本数据类型,为了运算更加高效Jittor中的计算过程是异步的。 如果要访问数据,可以使用`Var.data`进行同步数据访问。 + +```python +import jittor as jt +a = jt.float32([1,2,3]) +print (a) +print (a.data) +# Output: float32[3,] +# Output: [ 1. 2. 3.] +``` + +And we can give the variable a name. + +此外我们可以给变量起一个名字。 + +```python +a.name('a') +print(a.name()) +# Output: a +``` + +### Operations + +### 数据运算 + +Jittor'op is simular with numpy. Let's try some operations. We create Var `a` and `b` via operation `jt.float32`, and add them. Printing those variables shows they have the same shape and dtype. + + Jittor的算子与numpy类似。 让我们尝试一些运算, 我们通过Op`jt.float32`创建Var `a`和`b`,并将它们相加。 输出这些变量相关信息,可以看出它们具有相同的形状和类型。 + +```python +import jittor as jt +a = jt.float32([1,2,3]) +b = jt.float32([4,5,6]) +c = a*b +print(a,b,c) +print(type(a), type(b), type(c)) +# Output: float32[3,] float32[3,] float32[3,] +# Output: +``` +Beside that, All the operators we used `jt.xxx(Var, ...)` have alias `Var.xxx(...)`. For example: + +除此之外,我们使用的所有算子`jt.xxx(Var,...)`都具有别名`Var.xxx(...)`。 例如: + +```python +c.max() # alias of jt.max(c) +c.add(a) # alias of jt.add(c, a) +c.min(keepdims=True) # alias of jt.min(c, keepdims=True) +``` + +if you want to know all the operation which Jittor supports. try `help(jt.ops)`. All the operation you found in `jt.ops.xxx`, can be used via alias `jt.xxx`. + +如果您想知道Jittor支持的所有运算,可以运行`help(jt.ops)`。 您在`jt.ops.xxx`中找到的所有运算都可以通过别名`jt.xxx`。 + +```python +help(jt.ops) +# Output: +# abs(x: core.Var) -> core.Var +# add(x: core.Var, y: core.Var) -> core.Var +# array(data: array) -> core.Var +# binary(x: core.Var, y: core.Var, op: str) -> core.Var +# ...... +``` +### More + +### 更多教程 + +If you want to know more about Jittor, please check out the notebooks below: + +如果您想进一步了解Jittor,请查看以下notebooks: + +* Quickstart + - [Example: Model definition and training][1] + - [Basics: Op, Var][2] + - [Meta-operator: Implement your own convolution with Meta-operator][3] +* 快速开始 + * [示例:模型定义与训练][1] + * [基本概念:Op, Var][2] + * [元算子:通过元算子实现自己的卷积层][3] +* Advanced + - [Custom Op: write your operator with C++ and CUDA and JIT compile it][4] + - [Profiler: Profiling your model][5] + - Jtune: Tool for performance tuning +* 进阶 + * [自定义算子:使用C ++和CUDA编写您的算子,并其进行即时编译][4] + * [性能分析器:分析您的模型][5] + * Jtune:性能调优工具 + + + +[1]: python/jittor/notebook/example.src.md "example" +[2]: python/jittor/notebook/basics.src.md "basics" +[3]: python/jittor/notebook/meta_op.src.md "meta_op" +[4]: python/jittor/notebook/custom_op.src.md "custom_op" +[5]: python/jittor/notebook/profiler.src.md "profiler" +[1]: python/jittor/notebook/example.src.md "示例" +[2]: python/jittor/notebook/basics.src.md "基本概念" +[3]: python/jittor/notebook/meta_op.src.md "元算子" +[4]: python/jittor/notebook/custom_op.src.md "自定义算子" +[5]: python/jittor/notebook/profiler.src.md "性能分析器" + +Those notebooks can be started in your own computer by `python3.7 -m jittor.notebook` + +这些notebooks可以通过python3.7 -m jittor.notebook在您自己的计算机中运行。 + +## Contributing + +## 贡献 + +Jittor is still young. It may contain bugs and issues. Please report them in our bug track system. Contributions are welcome. Besides, if you have any ideas about Jittor, please let us know. + +Jittor还很年轻。 它可能存在错误和问题。 请在我们的错误跟踪系统中报告它们。 我们欢迎您为Jittor做出贡献。 此外,如果您对Jittor有任何想法,请告诉我们。 + +您可以用以下方式帮助Jittor: + +* 在论文中引用 Jittor +* 向身边的好朋友推荐 Jittor +* 贡献代码 +* 贡献教程和文档 +* 提出issue +* 回答 jittor 相关问题 +* 点亮小星星 +* 持续关注 jittor +* …… + +You can help Jittor in the following ways: + +* Citing Jittor in your paper +* recommend Jittor to your friends +* Contributing code +* Contributed tutorials and documentation +* File an issue +* Answer jittor related questions +* Light up the stars +* Keep an eye on jittor +* ...... + +## Contact Us + +## 联系我们 + +官方主页: http://cg.cs.tsinghua.edu.cn/jittor/ + +电子邮件:jittor@qq.com + +提出issue:https://github.com/Jittor/jittor/issues + +Website: http://cg.cs.tsinghua.edu.cn/jittor/ + +Email: jittor@qq.com + +File an issue: https://github.com/Jittor/jittor/issues + +QQ Group: 761222083 + +QQ 群:761222083 + + + +## The Team + +## 团队 + +Jittor is currently maintained by the [Tsinghua CSCG Group](https://cg.cs.tsinghua.edu.cn/). If you are also interested in Jittor and want to improve it, Please join us! + +Jittor目前由[清华大学计算机图形学组](https://cg.cs.tsinghua.edu.cn/)维护。 如果您也对Jittor感兴趣并希望对其进行改进,请加入我们! + +## Citation + +## 引用 + +``` +@article{hu2020jittor, + title={Jittor: a novel deep learning framework with meta-operators and unified graph execution}, + author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang}, + journal={Science China Information Sciences}, + volume={63}, + number={222103}, + pages={1--21}, + year={2020} +} +``` + +## License + +## 版权声明 + +Jittor is Apache 2.0 licensed, as found in the LICENSE.txt file. + +如LICENSE.txt文件中所示,Jittor使用Apache 2.0版权协议。 diff --git a/doc/Makefile b/doc/Makefile new file mode 100644 index 00000000..d0c3cbf1 --- /dev/null +++ b/doc/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/doc/build_doc.sh b/doc/build_doc.sh new file mode 100755 index 00000000..d8c53288 --- /dev/null +++ b/doc/build_doc.sh @@ -0,0 +1,18 @@ +# sudo python3.7 -m pip install \ +# recommonmark \ +# sphinx sphinx-autobuild sphinx_rtd_theme \ +# sphinx-autobuild \ +# --timeout 100 + + +bpath=$(readlink -f "${BASH_SOURCE[0]}") +bpath=$(dirname "${bpath}") + +jittor_path=$(readlink -f "${bpath}/..") + +echo "[doc path] $bpath" +echo "[jittor path] $jittor_path" + +export PYTHONPATH=$jittor_path/python +cd $bpath +sphinx-autobuild -b html source build -H 0.0.0.0 -p 8890 diff --git a/doc/logo.png b/doc/logo.png new file mode 100644 index 00000000..7bbc7488 Binary files /dev/null and b/doc/logo.png differ diff --git a/doc/source/Jittor性能测试与对比方法.md b/doc/source/Jittor性能测试与对比方法.md new file mode 100644 index 00000000..aa05d000 --- /dev/null +++ b/doc/source/Jittor性能测试与对比方法.md @@ -0,0 +1,176 @@ +Jittor性能测试与对比方法 +===================== + +下面代码以AlexNet为例,用于演示 Jittor 性能测试的正确方法: + +```python +import time +import jittor as jt +from jittor.models import resnet50 +jt.flags.use_cuda = jt.has_cuda + +warmup = 10 +rerun = 100 +batch_size = 8 +data = jt.random((batch_size, 3, 224, 224)) +model = resnet50() +model.eval() + +# 此段代码对jittor进行热身,确保时间测试准确 +jt.sync_all(True) +for i in range(warmup): + pred = model(data) + # sync是把计算图发送到计算设备上 + pred.sync() +# sync_all(true)是把计算图发射到计算设备上,并且同步。 +# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话 +jt.sync_all(True) + +# 开始测试运行时间 +start = time.time() +for i in range(rerun): + pred = model(data) + pred.sync() +jt.sync_all(True) +end = time.time() + +print("Jittor FPS:", (rerun*batch_size)/(end-start)) + +``` + +在这段代码中,我们定义了几个参数`batch_size`, `warmup`, `rerun`, batch_size代表批大小,warmup是用于热身的循环次数,而rerun是用于测速的循环次数,最终输出FPS,对Jittor进行正确测速的关键是 热身部分和同步部分,热身部分确保测试时间稳定,没有包含编译用的时间,而同步部分确保计算完成,因为jittor是一个异步框架,只有同步操作能保证计算完成。 + +以上代码的运行结果如下(RTX Titan,batch 8): + +``` +Compiling Operators(8/8) used: 7.35s eta: 0s +Compiling Operators(13/13) used: 8.36s eta: 0s +Jittor FPS: 908.9853866375396 +``` + +我们还可以使用类似的代码测试 PyTorch的性能: + +```python +import time +import torch +from torchvision.models import resnet50 + +warmup = 10 +rerun = 100 +batch_size = 8 +data = torch.randn((batch_size, 3, 224, 224)).cuda() +model = resnet50() +model.cuda() +model.eval() + +# 此段代码对pytorch进行热身,确保时间测试准确 +torch.cuda.synchronize() +for i in range(warmup): + pred = model(data) +# synchronize用于确保PyTorch计算完成 +torch.cuda.synchronize() + +# 开始测试运行时间 +start = time.time() +for i in range(rerun): + pred = model(data) +torch.cuda.synchronize() +end = time.time() + +print("PyTorch FPS:", (rerun*batch_size)/(end-start)) +``` + + +以上代码的运行结果如下(RTX Titan,batch 8): + +``` +PyTorch FPS: 807.4806873965665 +``` + +我们还可以对这两段代码合并,并对比结果的一致性: + +```python +import time +import jittor as jt +from jittor.models import resnet50 +jt.flags.use_cuda = jt.has_cuda + +warmup = 100 +rerun = 1000 +batch_size = 8 +data = jt.random((batch_size, 3, 224, 224)) +model = resnet50() +model.eval() + +# 此段代码对jittor进行热身,确保时间测试准确 +jt.sync_all(True) +for i in range(warmup): + pred = model(data) + # sync是把计算图发送到计算设备上 + pred.sync() +# sync_all(true)是把计算图发射到计算设备上,并且同步。 +# 只有运行了jt.sync_all(True)才会真正地运行,时间才是有效的,因此执行forward前后都要执行这句话 +jt.sync_all(True) + +# 开始测试运行时间 +start = time.time() +for i in range(rerun): + pred = model(data) + pred.sync() +jt.sync_all(True) +end = time.time() + +print("Jittor FPS:", (rerun*batch_size)/(end-start)) +# 将 jittor 数据和参数导出为 numpy 和 torch 格式 +jittor_data = pred.numpy() +jittor_param = model.state_dict(to="torch") + +import numpy as np +import torch +from torchvision.models import resnet50 +data = torch.Tensor(data.numpy()).cuda() +model = resnet50() +# 加载 jittor 参数 +model.load_state_dict(jittor_param) +model.cuda() +model.eval() + +# 此段代码对pytorch进行热身,确保时间测试准确 +torch.cuda.synchronize() +for i in range(warmup): + pred = model(data) +# synchronize用于确保PyTorch计算完成 +torch.cuda.synchronize() + +# 开始测试运行时间 +start = time.time() +for i in range(rerun): + pred = model(data) +torch.cuda.synchronize() +end = time.time() + +print("PyTorch FPS:", (rerun*batch_size)/(end-start)) +pytorch_data = pred.detach().cpu().numpy() +err = np.mean(np.abs(pytorch_data - jittor_data)) +print("mean error:", err) + +``` + + +以上代码运行结果如下: + +``` +Jittor FPS: 908.9853866375396 +PyTorch FPS: 807.4806873965665 +mean error: 1e-5 +``` + +误差输出为1e-5, 在可接受范围内。正确测速与对比的几大关键点为: + +1. 充分热身,除去框架的准备时间。 +2. 多次运行,确保测试时间稳定。 +3. 加上同步语句,确保测试时间准确。 +4. 保证显存充足,在显存不足时,jittor会调用统一内存来弥补,会产生性能损失,请密切关注`nvidia-smi`的输出结果。 +5. 保证对比模型的一致性,检查输出结果的一致。 + +如果您对测试结果有疑问,或者有优化需求,欢迎随时联系Jittor开发团队。 diff --git a/doc/source/Jittor显存以及内存优化方法.md b/doc/source/Jittor显存以及内存优化方法.md new file mode 100644 index 00000000..12e820c5 --- /dev/null +++ b/doc/source/Jittor显存以及内存优化方法.md @@ -0,0 +1,75 @@ +Jittor显存以及内存优化方法 +===================== + +您可以主要通过两种方法,来改进内存消耗: + +1. 优化消耗内存比较大的变量 +2. 使用Jittor自动交换技术,将变量在显存-内存-硬盘之间交换,降低运行部署门槛。 + +## 优化消耗内存比较大的变量 + +您可以使用jittor的memory profiler,来分析显存消耗较大的代码,并且针对特定代码进行优化。使用方法如下: + +``` +net = jt.models.resnet18() +with jt.flag_scope(trace_py_var=3, profile_memory_enable=1): + imgs = jt.randn((1,3,224,224)) + net(imgs).sync() + jt.get_max_memory_treemap() +``` + +输出如下: +``` + | + ├─./python/jittor/test/test_memory_profiler.py:100(test_sample) + | [19.03 MB; 29.67%] + | ./python/jittor/test/test_memory_profiler.py:100 + | | + | └─./python/jittor/__init__.py:730(__call__) + | [19.03 MB; 29.67%] + | ./python/jittor/__init__.py:730 + | | + | └─./python/jittor/models/resnet.py:152(execute) + | [19.03 MB; 29.67%] + | ./python/jittor/models/resnet.py:152 + | | + | ├─./python/jittor/models/resnet.py:142(_forward_impl) + | | [6.13 MB; 9.55%] + | | ./python/jittor/models/resnet.py:142 + | | | +``` + + +## 使用自动交换技术 + +该技术确保Jittor在显存或者内存不足的情况下,都能以一定速度运行。 + +节省内存方法,请安装Jittor版本大于1.3.7.5,并添加如下环境变量: + +```bash +export JT_SAVE_MEM=1 +# 限制cpu最多使用16G +export cpu_mem_limit=16000000000 +# 限制device内存(如gpu、tpu等)最多使用8G +export device_mem_limit=8000000000 +# windows 用户,请使用powershell +# $env:JT_SAVE_MEM="1" +# $env:cpu_mem_limit="16000000000" +# $env:device_mem_limit="8000000000" +``` +用户可以自由设定cpu和设备内存的使用量,如果不希望对内存进行限制,可以设置为`-1`。 +```bash +# 限制cpu最多使用16G +export cpu_mem_limit=-1 +# 限制device内存(如gpu、tpu等)最多使用8G +export device_mem_limit=-1 +# windows 用户,请使用powershell +# $env:JT_SAVE_MEM="1" +# $env:cpu_mem_limit="-1" +# $env:device_mem_limit="-1" +``` + +如果想要清理磁盘交换文件,可以运行如下命令 +```bash +python3 -m jittor_utils.clean_cache swap +``` \ No newline at end of file diff --git a/doc/source/Jittor调试技巧.md b/doc/source/Jittor调试技巧.md new file mode 100644 index 00000000..1740a204 --- /dev/null +++ b/doc/source/Jittor调试技巧.md @@ -0,0 +1,90 @@ +Jittor调试技巧 +===================== + +该文档包含了几种异常情况的调试方法和技巧。 + +## 爆Nan、Inf + +在模型训练的过程中,可能因为数值不稳定而出现Nan或者Inf,为了帮助您定位出现nan的代码,您可以设置如下环境变量: + +```bash +export JT_CHECK_NAN=1 +export trace_py_var=3 +``` + +其中,环境变量`JT_CHECK_NAN=1`的用途是:当算子的输出出现异常浮点数时,自动报错并停止程序,环境变量`trace_py_var=3`的用途是:输出算子对应的Python代码行数,3代表输出的详细等级,为最高等级。 + +需要注意的是,开启这两个特性之后,jittor速度会大幅下降,并且触发重编译,请不要在训练环境或者生产环境开启该模式,也不要长时间开启该模式。 + +## 错误信息定位不准确 + +Jittor框架默认采用延迟执行(Lazy execution)的方式进行加速,算子的执行和创建是不同步的,这可能导致报错信息定位不准确,您可以手动关闭延迟执行,采取立刻执行(eager execution)的模式,使用如下环境变量即可: + +```bash +export lazy_execution=0 +``` + +或者在python代码中通过flag关闭 +```python +jt.flags.lazy_execution=0 +``` + +## 内存不足 + +当您发现Jittor由于内存相关问题,无法运行时,Jittor会向您报告内存使用情况,内存不足可能有两种情况: + +1. 训练模型过大,一个迭代就崩溃报错。 +2. 多次迭代的过程中,内存占用不断增长,直到最后内存耗尽报错。 + +**对于第一种情况** ,您可能需要调整模型或者数据大小,或者使用[多卡训练](jittor.mpi),此外,您还可以在每个迭代内部,让Jittor强制回收内存: + +```python +for ...: + ... + jt.sync_all() + jt.gc() +``` + +如果您使用到了CUDA和卷积,还有可能是卷积消耗的临时空间过大,在这种情况下,可以关闭cudnn的临时内存申请,请将如下代码插入到最开始: + +```python +jt.cudnn.set_max_workspace_ratio(0.0) +``` + +**对于第二种情况**,可能是存在内存内存泄漏,请检查您是否存在全局变量没有释放,或者全局变量没有停止梯度,导致计算图不断增加,检查方法如下,您可以在每个迭代内部,插入如下调试代码: + +```python +for ...: + ... + jt.sync_all() + jt.display_memory_info() +``` + +Jittor会输出内存消耗,以及计算图的大小`lived_var,lived_op`,以及用户持有的变量数`hold_var`, 如果计算图规模不断增大,请检查代码,或者提交github issue联系我们,并且附上错误日志和代码复现脚本。 + + +## 段错误 + +如果Jittor出现了段错误,建议您将错误提交github issue联系我们,并且附上错误日志和代码复现脚本。您也可以使用如下环境变量对程序以及框架进行诊断: + +```bash +export debug=1 +export gdb_attach=1 +``` + +其中,环境变量`debug=1`代表开启jittor的debug模式,性能会大幅下降,但会保留调试信息,`gdb_attach=1`将会自动将gdb贴在jittor的主进程上,方便您进行单步调试。关于gdb的使用,您可以参考[GDB Cheat Sheet](https://darkdust.net/files/GDB%20Cheat%20Sheet.pdf) + + +## 管理Jittor cache + +Jittor会在`~/.cache/jittor`目录下创建cache, cache里面可能包括 core(内核)、cuda编译器、cuda库、数据集(dataset)、预训练参数等等,在某些情况下cache可能失效,如系统更新、驱动更新等等,这种情况可能需要用户手动清除cache, 清除的方法如下: + +``` +python3 -m jittor_utils.clean_cache all +``` + +以上命令会清除jittor的所有cache,如果您不想全部清除,可以参考命令行帮助: + +``` +python3 -m jittor_utils.clean_cache help +``` \ No newline at end of file diff --git a/doc/source/README.cn.md b/doc/source/README.cn.md new file mode 100644 index 00000000..9c599289 --- /dev/null +++ b/doc/source/README.cn.md @@ -0,0 +1 @@ +../../README.cn.md \ No newline at end of file diff --git a/doc/source/conf.py b/doc/source/conf.py new file mode 100644 index 00000000..b21ecca2 --- /dev/null +++ b/doc/source/conf.py @@ -0,0 +1,101 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys +jittor_path = os.path.abspath('../../python') +print(f"[jittor_path] {jittor_path}") +sys.path.insert(0, jittor_path) + +import jittor + + +# -- Project information ----------------------------------------------------- + +project = 'Jittor' +copyright = '2020, Jittor' +author = 'Jittor' + +# The full version, including alpha/beta/rc tags +release = jittor.__version__ +# fix AttributeError for "typing.get_type_hints(jt.Var)" +jittor.Var.__module__ = "jittor_core" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'zh_CN' + + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + # 'recommonmark', + 'myst_parser', + 'sphinx.ext.autodoc', + # Auto-generate section labels. + 'sphinx.ext.autosectionlabel', + 'sphinx.ext.viewcode', +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'alabaster' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +import sphinx_rtd_theme +html_theme = "sphinx_rtd_theme" +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +source_suffix = { + '.rst': 'restructuredtext', + '.txt': 'markdown', + '.md': 'markdown', +} + +import recommonmark +from recommonmark.transform import AutoStructify + +# At the bottom of conf.py +def setup(app): + app.add_config_value('recommonmark_config', { + # 'url_resolver': lambda url: github_doc_root + url, + 'auto_toc_tree_section': 'Contents', + }, True) + app.add_transform(AutoStructify) + + +# Prefix document path to section labels, otherwise autogenerated labels would look like 'heading' +# rather than 'path/to/file:heading' +autosectionlabel_prefix_document = True diff --git a/doc/source/index.rst b/doc/source/index.rst new file mode 100644 index 00000000..072e00ab --- /dev/null +++ b/doc/source/index.rst @@ -0,0 +1,60 @@ +.. Jittor documentation master file, created by + sphinx-quickstart on Mon May 18 23:05:53 2020. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +欢迎查阅计图文档 +================================== + +.. toctree:: + :maxdepth: 2 + :caption: 内容一览: + + README.cn.md + +.. toctree:: + :maxdepth: 2 + :caption: 模块API: + + jittor + jittor.nn + jittor.models + jittor.optim + jittor.init + jittor.contrib + jittor.dataset + jittor.transform + jittor.mpi + jittor.linalg + jittor.console + jittor.distributions + jittor.attention + jittor.loss3d + + +.. toctree:: + :maxdepth: 2 + :caption: 计图模型库: + + JDet + segmentation-jittor + InstanceSegmentation-jittor + gan-jittor + PointCloudLib + jrender + +.. toctree:: + :maxdepth: 1 + :caption: 其他: + + Jittor调试技巧 + Jittor性能测试与对比方法 + Jittor显存以及内存优化方法 + 教程 + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/doc/source/jittor.attention.md b/doc/source/jittor.attention.md new file mode 100644 index 00000000..fcfc0fe0 --- /dev/null +++ b/doc/source/jittor.attention.md @@ -0,0 +1,10 @@ +jittor.attention +===================== + +这里是Jittor的 注意力 模块的API文档,您可以通过`from jittor import attention`来获取该模块。 + +```eval_rst +.. automodule:: jittor.attention + :members: + :undoc-members: +``` diff --git a/doc/source/jittor.console.md b/doc/source/jittor.console.md new file mode 100644 index 00000000..48abca9e --- /dev/null +++ b/doc/source/jittor.console.md @@ -0,0 +1,237 @@ +jittor.console +===================== + +这里是Jittor的console api文档,console功能主要面向c/c++, 方便c++用户通过console使用jittor,jittor console 优化了 +c++数组和jittor内核之间的数据传输,减少了python额外开销,是通过c++使用jittor的高性能接口。 + +该功能要求 jittor版本大于1.2.2.17, 编译器支持c++17。 + +## 简单教程 + +我们提供了一个完整的教程,用户可以通过如下几行命令编译运行: + +```bash +# 生成c++ example源代码文件 +python3.7 -m jittor_utils.config --cxx-example > example.cc +# 调用g++编译example, 需要g++支持std=c++17 +g++ example.cc $(python3.7 -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o example +# 运行example +./example +``` + +运行结果可能如下: +```bash +hello jt console +1 +hello +1 2 3 4 +jt.Var([[-1 5 4] + [ 3 2 1]], dtype=int32) +2 3 +1 25 16 +9 4 1 +pred.shape 2 1000 +``` + +用户可以打开 example.cc, 修改成所需的应用,接下来我们会为大家讲解 example.cc 中的细节。 + +打开example.cc, 我们可以看到如下代码: + +```cpp +#include +#include + +using namespace std; + +int main() { + ... +} +``` + +这里我们导入了使用 console 所需的头文件 `pyjt/pyjt_console.h` + +接下来是jittor console的实例化, 并且使用python的print输出hello jt console: + +```cpp + jittor::Console console; + // run python code in console + console.run("print('hello jt console', flush=True)"); +``` + +输出结果: + +``` +hello jt console +``` + +注意到这里我们在 python print的时候使用了flush keyword,这是为了让python的输出流和c++的输出流保持一致, +不会错乱。 + +接下来我们调用了 `console.set(name, data)` 和 `console.get(name)` 往 console 里面设置了一个int变量a,并且再从console里面取出来。 + +```cpp + // set a python value: a = 1 + console.set("a", 1); + // get a python value + cout << console.get("a") << endl; +``` + +输出结果: + +``` +1 +``` + +同样的方法,我们还设置了 `string` 和 `vector`, 如下所示 + +```cpp + // set a python string + console.set("b", "hello"); + cout << console.get("b") << endl; + + // set a python array + vector x{1,2,3,4}; + console.set("x", x); + auto x2 = console.get>("x"); + for (auto a : x2) cout << a << " "; cout << endl; +``` + +输出结果: + +``` +hello +1 2 3 4 +``` + +我们还可以往console里面设置jittor变量,这里我们使用了下面几个新的接口: + +1. `jittor::array(shape, data)`: 这个接口创建了一个jittor的array,类型是`T`, 维度大小为`NDIM`, 形状为 `shape`, 注意shape的长度需要和`NDIM`保持一致,最后是传入的数据,可以是一个vector,也可以是一个指针。 +2. `console.set_array(name, arr)`: 往console里面设置该jittor array, 名称为`name`。 +3. `console.get(name)`: 从console里取出一个jittor array,类型为`T`,维度大小为`NDIM`,需要注意的是类型和维度大小必须和console中的变量匹配,否则会抛出异常。 +4. `arr(i,j)`: 对jittor变量取值。 +5. `arr.shape[i]`: 获取jittor变量的维度大小。 + +在这段代码中,我们首先创建了一个2x3的矩阵, 然后修改了矩阵中的值,随即设置到了python console里面,并且取出输出: + +```cpp + // set and get a jittor array + jittor::array arr2({2,3}, {6,5,4,3,2,1}); + arr2(0,0) = -1; + console.set_array("arr2", arr2); + console.run("print(arr2, flush=True); arr3 = arr2**2;"); + auto arr3 = console.get_array("arr3"); + cout << arr3.shape[0] << ' ' << arr3.shape[1] << endl; + for (int i=0; i input({2, 3, 224, 224}); + memset(input.data.get(), 0, input.nbyte()); + console.set_array("input", input); + console.run(R"( +import jittor as jt +from jittor.models import resnet + +model = resnet.resnet18() +pred = model(input) + )"); + auto pred = console.get_array("pred"); + cout << "pred.shape " << pred.shape[0] << ' ' << pred.shape[1] << endl; +``` + +我们输出了取出的变量的形状,结果如下: + +``` +pred.shape 2 1000 +``` + +## jittor array 接口一览 + +`jittor::array` 是 c++和jittor console交互的 array类型,他的定义如下: + +```cpp + +// T: 类型, N: 维度数量 +template +struct array { + +// N维 形状大小 +int64 shape[N]; +// 数据指针 +unique_ptr data; + +// 是否为浮点数 +bool is_float(); +// 是否为无符号类型 +bool is_unsigned(); +// 数组总大小,为shape数组累乘的结果 +int64 size(); +// 数组总比特数 +int64 nbyte(); +// 数据类型的字符串表示 +string dtype(); +// 维度数量, 同 N +int ndim(); + +// array 构造函数,shape为形状,数据未被初始化 +array(const vector& shape); +// array 构造函数,shape为形状,数据从data指针拷贝初始化 +array(const vector& shape, const T* data); +// array 构造函数,shape为形状,数据从data vector拷贝初始化 +array(const vector& shape, const vector& data); + +T& operator()(...); + +}; +``` + +## Console 接口一览 + +console接口主要用于设置变量,取出变量,运行脚本, 三部分构成。 + +```cpp + +struct Console { + +// 运行代码接口 +void run(const string& src); + +// 设置变量名称为s, 值为data +template +void set(const string& s, const T& data); + +// 获取变量名称为s +template +T get(const string& s) + +// 设置 array 变量 +void set_array(const string& s, const array& data); + +// 获取一个jittor array,类型为`T`,维度大小为`NDIM`,需要注意的是类型和维度大小必须和console中的变量匹配,否则会抛出异常。 +void get_array(const string& s); + +}; +``` + +其中 `get`,`set` 支持常见的c++类型有: + +1. int, uint, int64, uint64, float, double +2. string +3. vector +4. map, unordered_map diff --git a/doc/source/jittor.contrib.md b/doc/source/jittor.contrib.md new file mode 100644 index 00000000..7b729737 --- /dev/null +++ b/doc/source/jittor.contrib.md @@ -0,0 +1,10 @@ +jittor.contrib +===================== + +这里是Jittor的贡献代码模块模块的API文档,此模块的代码可能还没有完全成熟,我们将在后续迭代开发中继续完善,您可以通过`from jittor import contrib`来获取该模块。 + +```eval_rst +.. automodule:: jittor.contrib + :members: + :undoc-members: +``` diff --git a/doc/source/jittor.dataset.md b/doc/source/jittor.dataset.md new file mode 100644 index 00000000..a244b101 --- /dev/null +++ b/doc/source/jittor.dataset.md @@ -0,0 +1,11 @@ +jittor.dataset +===================== + +这里是Jittor的数据集模块的API文档,您可以通过`from jittor import dataset`来获取该模块。 + +```eval_rst +.. automodule:: jittor.dataset + :imported-members: + :members: + :undoc-members: +``` diff --git a/doc/source/jittor.distributions.md b/doc/source/jittor.distributions.md new file mode 100644 index 00000000..961fe232 --- /dev/null +++ b/doc/source/jittor.distributions.md @@ -0,0 +1,10 @@ +jittor.distributions +===================== + +这里是Jittor的随机分布模块的API文档,您可以通过`from jittor import distributions`来获取该模块。 + +```eval_rst +.. automodule:: jittor.distributions + :members: + :undoc-members: +``` diff --git a/doc/source/jittor.init.md b/doc/source/jittor.init.md new file mode 100644 index 00000000..289fceea --- /dev/null +++ b/doc/source/jittor.init.md @@ -0,0 +1,10 @@ +jittor.init +===================== + +这里是Jittor的参数初始化模块的API文档,您可以通过`from jittor import init`来获取该模块。 + +```eval_rst +.. automodule:: jittor.init + :members: + :undoc-members: +``` diff --git a/doc/source/jittor.linalg.md b/doc/source/jittor.linalg.md new file mode 100644 index 00000000..cce65a78 --- /dev/null +++ b/doc/source/jittor.linalg.md @@ -0,0 +1,57 @@ +jittor.linalg +===================== + +这里是Jittor的线性代数函数的API文档,您可以通过`from jittor import linalg`来获取该模块。 + +## 基本函数简介 +#### 基本线性代数运算API +- linalg.inv(a) + + 对a进行求逆运算 + +- linalg.pinv(a) + + 对a进行广义求逆运算。该运算不要求原矩阵a可逆。 + +- linalg.slogdet(a) + + 对a求取slogdet。会返回值以及符号。 + +- linalg.det(a) + + 对a求行列式。 + +- linalg.solve(a,b) + + 求解线性方程Ax=b的解。 + +#### 分解API +- linalg.cholesky(a) + + 对a进行cholesky分解。 + +- linalg.qr(a) + + 对a进行qr分解。 + +- linalg.svd + + 对a进行奇异值分解。 +#### 特征值API +- linalg.eig(a) + + 求取a的特征值以及特征向量。 + +- linalg.eigh(a) + + 针对埃尔米特矩阵或者对称矩阵求特征值以及特征向量。 + + +目前的linalg库支持 + +```eval_rst +.. automodule:: jittor.linalg + :members: + :undoc-members: +``` + diff --git a/doc/source/jittor.loss3d.md b/doc/source/jittor.loss3d.md new file mode 100644 index 00000000..2851af13 --- /dev/null +++ b/doc/source/jittor.loss3d.md @@ -0,0 +1,10 @@ +jittor.loss3d +===================== + +这里是Jittor的 3d 损失函数 模块的API文档,您可以通过`from jittor import loss3d`来获取该模块。 + +```eval_rst +.. automodule:: jittor.loss3d + :members: chamfer_loss, ChamferLoss, earth_mover_distance, EarthMoverDistance + :undoc-members: +``` diff --git a/doc/source/jittor.md b/doc/source/jittor.md new file mode 100644 index 00000000..a7c72e38 --- /dev/null +++ b/doc/source/jittor.md @@ -0,0 +1,54 @@ +jittor +===================== + +## jittor + +这里是Jittor主模块的API文档,您可以通过`import jittor`来获取该模块。 + +```eval_rst +.. automodule:: jittor + :members: + :undoc-members: +``` + +## jittor.core + +以下为Jittor的内核API,内核API可以通过`jittor.core.XXX`或者`jittor.XXX`直接访问。 + + +```eval_rst +.. automodule:: jittor_core + :imported-members: + :members: + :undoc-members: +``` + +## jittor.ops + +这里是Jittor的基础算子模块的API文档,该API可以通过`jittor.ops.XXX`或者`jittor.XXX`直接访问。 + +```eval_rst +.. automodule:: jittor_core.ops + :members: + :undoc-members: +``` + +## jittor.Var + +这里是Jittor的基础变量类的API文档。该API可以通过`my_jittor_var.XXX`直接访问。 + +```eval_rst +.. automodule:: jittor_core.Var + :members: + :undoc-members: +``` + +## jittor.Misc + +这里是Jittor的基础算子模块的API文档,该API可以通过`jittor.misc.XXX`或者`jittor.XXX`直接访问。 + +```eval_rst +.. automodule:: jittor.misc + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/doc/source/jittor.models.md b/doc/source/jittor.models.md new file mode 100644 index 00000000..0881dc08 --- /dev/null +++ b/doc/source/jittor.models.md @@ -0,0 +1,14 @@ +jittor.models +===================== + +这里是Jittor的骨干网络模块的API文档,您可以通过`from jittor import models`来获取该模块。 + +```eval_rst + +.. automodule:: jittor.models + :members: + :imported-members: + :undoc-members: + :exclude-members: ResNet,ShuffleNetV2,SqueezeNet,VGG +``` + diff --git a/doc/source/jittor.mpi.md b/doc/source/jittor.mpi.md new file mode 100644 index 00000000..d0892cf9 --- /dev/null +++ b/doc/source/jittor.mpi.md @@ -0,0 +1,215 @@ +jittor.mpi +===================== + +计图分布式基于MPI(Message Passing Interface),本文档主要阐述使用计图MPI,进行多卡和分布式训练的教程。 + + +## 计图MPI安装 + +计图依赖`OpenMPI`,用户可以使用如下命令安装`OpenMPI`: + +```bash +sudo apt install openmpi-bin openmpi-common libopenmpi-dev +``` + +也可以参考 [OpenMPI 文档](https://www.open-mpi.org/faq/?category=building#easy-build),自行编译安装。 + +计图会自动检测环境变量中是否包含`mpicc`,如果计图成功的检测到了`mpicc`,那么会输出如下信息: + +``` +[i 0502 14:09:55.758481 24 __init__.py:203] Found mpicc(1.10.2) at /usr/bin/mpicc +``` + +如果计图没有在环境变量中找到mpi,用户也可以手动指定mpicc的路径告诉计图,添加环境变量即可:`export mpicc_path=/you/mpicc/path` + +`OpenMPI`安装完成以后,用户无需修改代码,需要做的仅仅是修改启动命令行,计图就会用数据并行的方式自动完成并行操作。 + +```bash +# 单卡训练代码 +python3.7 -m jittor.test.test_resnet +# 分布式多卡训练代码 +mpirun -np 4 python3.7 -m jittor.test.test_resnet +# 指定特定显卡的多卡训练代码 +CUDA_VISIBLE_DEVICES="2,3" mpirun -np 2 python3.7 -m jittor.test.test_resnet +``` + +这种便捷性的背后是计图的分布式算子的支撑,计图支持的mpi算子后端会使用nccl进行进一步的加速。计图所有分布式算法的开发均在Python前端完成,这让分布式算法的灵活度增强,开发分布式算法的难度也大大降低。 + +## 如何从单卡代码适配多卡代码 + +使用`mpirun`时,以下几种模块会自动检测mpi环境并且自动切换成多卡版本: + +* jittor.optimizer: 自动同步梯度 +* jittor.nn.BatchNorm*: 同步batch norm +* jittor.dataset: 自动数据并行 + +用户在使用MPI进行分布式训练时,计图内部的Dataset类会自动并行分发数据,需要注意的是Dataset类中设置的Batch size是**所有节点的batch size之和**,也就是总batch size, 不是单个节点接收到的batch size。 + +大部分情况下,单卡训练的代码可以直接使用`mpirun`实现分布式多卡运行。 但仍然如下几种情况下,需要对代码进行调整: + +1. 对硬盘进行写操作(保存模型,保存曲线) +2. 需要统计全局信息(validation 上的全局准确率) + +### 对硬盘进行写操作 + +对于第一点,假设原来您的代码如下: + +```python +for i, (images, labels) in enumerate(dataset): + output = model(images) + loss = nn.cross_entropy_loss(output, labels) + acc1 = accuracy(output, labels) + SGD.step(loss) + loss_data = loss.data + writer.add_scalar("Train/loss") +``` + +更改后的代码如下: + +```python +for i, (images, labels) in enumerate(dataset): + output = model(images) + loss = nn.cross_entropy_loss(output, labels) + acc1 = accuracy(output, labels) + SGD.step(loss) + loss_data = loss.data + if jt.rank == 0: + writer.add_scalar("Train/loss") +``` + +这里我们使用了 jt.rank 来限制,只允许第一个进程可以写 loss,这个代码在单卡下也是有效的,因为单卡的 jt.rank 值为 0, 需要注意的是,在 `if jt.rank == 0` 代码块里面的代码,不允许调用任何jittor的api,因为这很有可能导致多卡之间的api调用不一致而产生**死锁**! + +### 需要统计全局信息 + +统计全局信息有两种方法,第一种是使用提供的 mpi op 来实现全局信息统计, 如下所示, 是一个validation的代码: + +```python +def val(epoch): + global min_error + model.eval() + correct_nums = 0 + for i, (images, labels) in enumerate(valdataset): + output = model(images) + correct_nums += top1error(output, labels) + correct_nums.sync() + top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len + if top1_error < min_error: + print("[*] Best model is updated ...") + model.save('model_best.pkl') +``` + +更改方案如下: + +```python +def val(epoch): + global min_error + model.eval() + correct_nums = 0 + for i, (images, labels) in enumerate(valdataset): + output = model(images) + correct_nums += top1error(output, labels) + correct_nums.sync() + if jt.in_mpi: + correct_nums = correct_nums.mpi_all_reduce() + top1_error = (valdataset.total_len - correct_nums.data[0]) / valdataset.total_len + if jt.rank == 0 and top1_error < min_error: + print("[*] Best model is updated ...") + model.save('model_best.pkl') +``` + +可以留意到我们首先使用了 `mpi_all_reduce`, 来统计多卡的正确数量(mpi_all_reduce会将多个mpi进程的结果累加起来), 然后在 `jt.rank == 0` 的情况下才更新模型。 + +第二种方法是使用`@jt.single_process_scope()`,被装饰的代码会直接以单进程的方式执行,无需处理多卡。 + +```python +@jt.single_process_scope() +def val(epoch): + ...... +``` + + +## MPI接口 + +下面是 jittor 的 mpi api reference. +目前MPI开放接口如下: + +* `jt.in_mpi`: 当计图不在MPI环境下时,`jt.mpi == False`, 用户可以用这个判断是否在mpi环境下。 +* `jt.world_size`: 获取当前进程总数量,如果没有用mpi,则为1。 +* `jt.rank`: 获取当前进程的编号,区间为`0 ~ jt.world_size-1`, 如果没有用mpi,则为0。 +* `jt.mpi`: 计图的MPI模块。 +* `jt.Module.mpi_param_broadcast(root=0)`: 将模块的参数从root节点广播给其他节点。 +* `jt.mpi.mpi_reduce(x, op='add', root=0)`: 将所有节点的变量x使用算子op,reduce到root节点。如果op是'add'或者'sum',该接口会把所有变量求和,如果op是'mean',该接口会取均值。 + + + +* `jt.mpi.mpi_broadcast(x, root=0)`: 将变量x从root节点广播到所有节点。 + + + +* `jt.mpi.mpi_all_reduce(x, op='add')`: 将所有节点的变量x使用一起reduce,并且吧reduce的结果再次广播到所有节点。如果op是'add'或者'sum',该接口会把所有变量求和,如果op是'mean',该接口会取均值。 + + + + + +```eval_rst +.. automodule:: jittor_mpi_core + :members: + :undoc-members: +.. automodule:: jittor_mpi_core.ops + :members: + :undoc-members: +``` + +## 实例:MPI实现分布式同步批归一化层 + + +下面的代码是使用计图实现分布式同步批归一化层的实例代码,在原来批归一化层的基础上,只需增加三行代码,就可以实现分布式的batch norm,添加的代码如下: + +```python +# 将均值和方差,通过all reduce同步到所有节点 +if self.sync and jt.mpi: + xmean = xmean.mpi_all_reduce("mean") + x2mean = x2mean.mpi_all_reduce("mean") +``` + +> 注:计图内部已经实现了同步的批归一化层,用户不需要自己实现 + +分布式同步批归一化层的完整代码: + +```python +class BatchNorm(Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True): + assert affine == None + + self.sync = sync + self.num_features = num_features + self.is_train = is_train + self.eps = eps + self.momentum = momentum + self.weight = init.constant((num_features,), "float32", 1.0) + self.bias = init.constant((num_features,), "float32", 0.0) + self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad() + self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad() + + def execute(self, x): + if self.is_train: + xmean = jt.mean(x, dims=[0,2,3], keepdims=1) + x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1) + # 将均值和方差,通过all reduce同步到所有节点 + if self.sync and jt.mpi: + xmean = xmean.mpi_all_reduce("mean") + x2mean = x2mean.mpi_all_reduce("mean") + + xvar = x2mean-xmean*xmean + norm_x = (x-xmean)/jt.sqrt(xvar+self.eps) + self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum + self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum + else: + running_mean = self.running_mean.broadcast(x, [0,2,3]) + running_var = self.running_var.broadcast(x, [0,2,3]) + norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps) + w = self.weight.broadcast(x, [0,2,3]) + b = self.bias.broadcast(x, [0,2,3]) + return norm_x * w + b +``` diff --git a/doc/source/jittor.nn.md b/doc/source/jittor.nn.md new file mode 100644 index 00000000..0ef2b09f --- /dev/null +++ b/doc/source/jittor.nn.md @@ -0,0 +1,24 @@ +jittor.nn +===================== + +这里是Jittor的神经网络模块的API文档,您可以通过`from jittor import nn`来获取该模块。 + +```eval_rst +.. automodule:: jittor.nn + :members: + :undoc-members: + +.. automodule:: jittor.nn + :imported-members: + :members: Pool, pool, AdaptiveAvgPool2d, Pool3d, AdaptiveMaxPool2d, AdaptiveAvgPool3d, AdaptiveMaxPool2d, pool3d, AvgPool2d, AvgPool3d, avg_pool2d, MaxPool2d, MaxPool3d, max_pool2d, max_pool3d, MaxUnpool2d, MaxUnpool3d + :undoc-members: + +.. autoclass:: jittor.nn.ReLU + :members: +.. autoclass:: jittor.nn.ReLU6 + :members: +.. autoclass:: jittor.nn.LeakyReLU + :members: +.. autoclass:: jittor.nn.Softmax + :members: +``` diff --git a/doc/source/jittor.optim.md b/doc/source/jittor.optim.md new file mode 100644 index 00000000..a0ca96e4 --- /dev/null +++ b/doc/source/jittor.optim.md @@ -0,0 +1,18 @@ +jittor.optim +===================== + +这里是Jittor的优化器模块的API文档,您可以通过`from jittor import optim`来获取该模块。 + +```eval_rst +.. automodule:: jittor.optim + :members: + :undoc-members: +``` + +以下是Jittor的学习率调度模块的API文档,学习率调度模块需要配合优化器使用,您可以通过`from jittor import lr_scheduler`来获取该模块。 + +```eval_rst +.. automodule:: jittor.lr_scheduler + :members: + :undoc-members: +``` \ No newline at end of file diff --git a/doc/source/jittor.transform.md b/doc/source/jittor.transform.md new file mode 100644 index 00000000..23d3da64 --- /dev/null +++ b/doc/source/jittor.transform.md @@ -0,0 +1,10 @@ +jittor.transform +===================== + +这里是Jittor的 数据变换 模块的API文档,您可以通过`from jittor import transform`来获取该模块。 + +```eval_rst +.. automodule:: jittor.transform + :members: + :undoc-members: +``` diff --git a/doc/source/todo.md b/doc/source/todo.md new file mode 100644 index 00000000..ac7015c9 --- /dev/null +++ b/doc/source/todo.md @@ -0,0 +1,12 @@ +TODO +===================== + +## 文档相关 + +* 文档语法规范 +* 文档加上教程链接 +* MPI接口文档 +* 文档自动更新 +* 首页到文档的链接 +* 模型库的文档(GAN,segmentation,detection...) +* 文档补全,重要的类加上使用example diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py new file mode 100644 index 00000000..5a443cfc --- /dev/null +++ b/python/jittor/__init__.py @@ -0,0 +1,2179 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# Meng-Hao Guo +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +__version__ = '1.3.100.10' +from jittor_utils import lock +with lock.lock_scope(): + ori_int = int + ori_float = float + ori_bool = bool + from . import compiler + from .compiler import LOG, has_cuda + from .compiler import compile_custom_ops, compile_custom_op + import jittor_core + import jittor_core as core + from jittor_core import * + from jittor_core.ops import * + from . import compile_extern + from .compile_extern import mkl_ops, mpi, mpi_ops, in_mpi, rank, world_size + if core.get_device_count() == 0: + has_cuda = compile_extern.has_cuda = compiler.has_cuda = False + from .compile_extern import cudnn, curand, cublas, cufft + from .init_cupy import numpy2cupy + +from typing import List, Tuple +import contextlib +import numpy as np +from collections import OrderedDict +from collections.abc import Sequence, Mapping +import types +import pickle +import hashlib +import sys, os +import traceback + +if "SKEY" in os.environ: + import jittor_utils.student_queue + +def dfs_to_numpy(x): + if isinstance(x, list): + for i in range(len(x)): + x[i] = dfs_to_numpy(x[i]) + elif isinstance(x, dict): + for k in x: + x[k] = dfs_to_numpy(x[k]) + elif isinstance(x, Var): + return x.numpy() + return x + +def safepickle(obj, path): + if path.endswith(".pth") or path.endswith(".pt") or path.endswith(".bin"): + from jittor_utils.save_pytorch import save_pytorch + save_pytorch(path, obj) + return + # Protocol version 4 was added in Python 3.4. It adds support for very large objects, pickling more kinds of objects, and some data format optimizations. + # ref: + # obj = dfs_to_numpy(obj) + s = pickle.dumps(obj, 4) + checksum = hashlib.sha1(s).digest() + s += bytes(checksum) + s += b"HCAJSLHD" + with open(path, 'wb') as f: + f.write(s) + +def _load_pkl(s, path): + try: + return pickle.loads(s) + except Exception as e: + msg = str(e) + msg += f"\nPath: \"{path}\"" + if "trunc" in msg: + msg += "\nThis file maybe corrupted, please consider remove it" \ + " and re-download." + raise RuntimeError(msg) + +def _upload(path, url, jk, tdir=""): + tdir = tdir + '/' if tdir != "" else "" + prefix = f"https://cg.cs.tsinghua.edu.cn/jittor/{tdir}assets" + if url.startswith("jittorhub://"): + url = url.replace("jittorhub://", prefix+"/build/checkpoints/") + assert url.startswith(prefix) + suffix = url[len(prefix):] + dir_suffix = "/".join(suffix.split("/")[:-1]) + jkey = flags.cache_path+"/_jkey" + with open(jkey, 'w') as f: + f.write(jk) + assert os.system(f"chmod 600 \"{jkey}\"") == 0 + print(dir_suffix) + assert os.system(f"s""s""h"f" -i \"{jkey}\" jittor" "@" "166" f".111.68.30 mkdir -p Documents/jittor-blog/{tdir}assets{dir_suffix}") == 0 + assert os.system(f"s""c""p"+f" -i \"{jkey}\" \"{path}\" jittor" "@" "166" f".111.68.30:Documents/jittor-blog/{tdir}assets{suffix}") == 0 + assert os.system(f"s""s""h"f" -i \"{jkey}\" jittor" "@" "166" ".111.68.30 Documents/jittor-blog.git/hooks/post-update") == 0 + + +def safeunpickle(path): + if path.startswith("jittorhub://"): + path = path.replace("jittorhub://", f"https://cg.cs.tsinghua.edu.cn/jittor/assets/build/checkpoints/") + if path.startswith("https:") or path.startswith("http:"): + base = path.split("/")[-1] + fname = os.path.join(compiler.ck_path, base) + from jittor_utils.misc import download_url_to_local + download_url_to_local(path, base, compiler.ck_path, None) + path = fname + if not (path.endswith(".pth") or path.endswith(".pkl") or path.endswith(".pt")): + return path + if path.endswith(".pth") or path.endswith(".pt") or path.endswith(".bin") : + from jittor_utils.load_pytorch import load_pytorch + model_dict = load_pytorch(path) + return model_dict + with open(path, "rb") as f: + s = f.read() + if not s.endswith(b"HCAJSLHD"): + return _load_pkl(s, path) + checksum = s[-28:-8] + s = s[:-28] + if hashlib.sha1(s).digest() != checksum: + raise ValueError("Pickle checksum does not match! path: "+path, + " This file maybe corrupted, please consider remove it" + " and re-download.") + return _load_pkl(s, path) + +class _call_no_record_scope: + def __enter__(self): pass + def __exit__(self, *exc): pass + def __call__(self, func): + def inner(*args, **kw): + with self: + ret = func(*args, **kw) + return ret + return inner + +class flag_scope(_call_no_record_scope): + def __init__(self, **jt_flags): + self.jt_flags = jt_flags + + def __enter__(self): + flags_bk = self.flags_bk = {} + try: + for k,v in self.jt_flags.items(): + origin = getattr(flags, k) + flags_bk[k] = origin + # merge dict attrs + if isinstance(origin, dict): + for ok, ov in origin.items(): + if ok not in v: + v[ok] = ov + setattr(flags, k, v) + except: + self.__exit__() + raise + + def __exit__(self, *exc): + for k,v in self.flags_bk.items(): + setattr(flags, k, v) + +class no_grad(flag_scope): + ''' no_grad scope, all variable created inside this +scope will stop grad. + +Example:: + + import jittor as jt + + with jt.no_grad(): + ... + + ''' + def __init__(self, **jt_flags): + self.jt_flags = jt_flags + jt_flags["no_grad"] = 1 + +class enable_grad(flag_scope): + ''' enable_grad scope, all variable created inside this +scope will start grad. + +Example:: + + import jittor as jt + + with jt.enable_grad(): + ... + + ''' + def __init__(self, **jt_flags): + self.jt_flags = jt_flags + jt_flags["no_grad"] = 0 + +single_log_capture = None + +class log_capture_scope(_call_no_record_scope): + """log capture scope + + Example:: + + with jt.log_capture_scope(log_v=0) as logs: + LOG.v("...") + print(logs) + """ + def __init__(self, **jt_flags): + jt_flags["use_parallel_op_compiler"] = 0 + self.fs = flag_scope(**jt_flags) + + def __enter__(self): + global single_log_capture + assert not single_log_capture + single_log_capture = True + self.logs = [] + LOG.log_capture_start() + try: + self.fs.__enter__() + if "log_v" in self.fs.jt_flags: + LOG.log_v = self.fs.jt_flags["log_v"] + return self.logs + except: + LOG.log_capture_stop() + single_log_capture = None + raise + + def __exit__(self, *exc): + global single_log_capture + self.fs.__exit__(*exc) + if "log_v" in self.fs.jt_flags: + LOG.log_v = flags.log_v + LOG.log_capture_stop() + self.logs.extend(LOG.log_capture_read()) + single_log_capture = None + + +class profile_scope(_call_no_record_scope): + """ profile scope + + example:: + + with jt.profile_scope() as report: + ...... + print(report) + """ + def __init__(self, warmup=0, rerun=0, **jt_flags): + self.fs = flag_scope(**jt_flags) + self.warmup = warmup + self.rerun = rerun + + def __enter__(self): + assert not flags.profiler_enable + self.report = [] + try: + self.fs.__enter__() + profiler.start(self.warmup, self.rerun) + return self.report + except: + profiler.stop() + raise + + def __exit__(self, *exc): + profiler.stop() + self.report.extend(profiler.report()) + self.fs.__exit__(*exc) + + +class profile_mark(_call_no_record_scope): + def __init__(self, mark_name: str): + ''' profiler mark is used for profiling part of code, + + Example:: + + a = jt.rand(1000,1000) + b = jt.rand(1000,1000) + jt.sync_all() + results = [] + with jt.profile_scope() as rep: + results.append(jt.matmul(a, b)) + with jt.profile_mark("mark1"): + results.append(jt.matmul(a, b)) + with jt.profile_mark("mark2"): + results.append(jt.matmul(a, b)) + with jt.profile_mark("mark3"): + results.append(jt.matmul(a, b)) + results.append(jt.matmul(a, b)) + + Output:: + + Total time: 46.8ms + Total Memory Access: 57.2MB + [Mark mark3] time: 9ms + [Mark mark2] time: 8.28ms + [Mark mark1] time: 17.7ms + + ''' + self.mark_name = mark_name + def __enter__(self): + self.options = flags.compile_options + new_options = flags.compile_options + prev_marks = "_marks:" + for x in self.options: + if x.startswith(prev_marks): + prev_marks = x + del new_options[x] + new_marks = prev_marks + self.mark_name + ',' + new_options[new_marks] = 1 + flags.compile_options = new_options + + def __exit__(self, *exc): + flags.compile_options = self.options + +class __single_process_scope: + def __init__(self, rank=0): + self.rank = rank + + def __enter__(self): + global in_mpi + self.bk_in_mpi = in_mpi + if mpi: + self.bk_mpi_state = mpi.get_state() + if not in_mpi: + return True + + ret = self.rank == mpi.world_rank() + in_mpi = compile_extern.in_mpi = False + mpi.set_state(False) + return ret + + def __exit__(self, *exc): + global in_mpi + in_mpi = compile_extern.in_mpi = self.bk_in_mpi + if mpi: + mpi.set_state(self.bk_mpi_state) + +def single_process_scope(rank=0): + """ single_process_scope + + Code in this scope will only be executed by single process. + + All the mpi code inside this scope will have not affect. + mpi.world_rank() and mpi.local_rank() will return 0, world_size() will return 1, + + example:: + + @jt.single_process_scope(rank=0) + def xxx(): + ... + """ + def outer(func): + def inner(*args, **kw): + ret = None + sync_all() + with __single_process_scope(rank) as flag: + if flag: + ret = func(*args, **kw) + return ret + return inner + return outer + +def clean(): + import gc + # make sure python do a full collection + gc.collect() + core.gc() + +cast = unary +Var.cast = Var.cast + +def array(data, dtype=None): + ''' Constructs a jittor Var from a number, List, numpy array or another jittor Var. + + :param data: The data to initialize the Var. + :type data: number, list, numpy.ndarray, or jittor.Var. + :param dtype: The data type of the Var. If None, the data type will be inferred from the data. + :type dtype: str, jittor type-cast function, or None. + + ---------------- + + Example:: + + >>> jt.array(1) + jt.Var([1], dtype=int32) + >>> jt.array([0, 2.71, 3.14]) + jt.Var([0. 2.71 3.14], dtype=float32) + >>> jt.array(np.arange(4, dtype=np.uint8)) + jt.Var([0 1 2 3], dtype=uint8) + ''' + if isinstance(data, core.Var): + if dtype is None: + ret = data.clone() + else: + ret = cast(data, dtype) + elif dtype is not None: + if isinstance(dtype, NanoString): + dtype = str(dtype) + elif callable(dtype): + dtype = dtype.__name__ + with jt.flag_scope(auto_convert_64_to_32=0): + ret = ops.array(np.array(data, dtype)) + else: + ret = ops.array(data) + # TODO: move those code to core + amp_reg = jt.flags.amp_reg + if amp_reg and ret.numel() != 1 and ret.dtype.is_float(): + if amp_reg & 16: + if amp_reg & 1: + if ret.dtype != "float32": + return ret.float32() + elif amp_reg & 2: + if ret.dtype != "float16": + return ret.float16() + return ret + +def random(shape, dtype="float32", type="uniform"): + ''' Constructs a random jittor Var. + + :param shape: The shape of the random Var. + :type shape: list or tuple. + :param dtype: The data type of the random Var. + :type dtype: str, jittor type-cast function, or None. + :param type: The random distribution, can be 'uniform' or 'normal'. + :type type: str + + ---------------- + + Example:: + + >>> jt.random((2, 3)) + jt.Var([[0.96788853 0.28334728 0.30482838] + [0.46107793 0.62798643 0.03457401]], dtype=float32) + ''' + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") + ret = ops.random(shape, "float32", type) + ## TODO: move those code to core + #if dtype in ["float16", "bfloat16"]: + # # TODO: make curand support fp16 + # ret = ops.random(shape, "float32", type).cast(dtype) + #else: + # ret = ops.random(shape, dtype, type) + amp_reg = jt.flags.amp_reg + if amp_reg: + if amp_reg & 16: + if amp_reg & 1: + if ret.dtype != "float32": + return ret.float32() + elif amp_reg & 2: + if ret.dtype != "float16": + return ret.float16() + return ret + +def float_auto(x): + if jt.flags.amp_reg & 2: + return x.float16() + return x.float32() +Var.float_auto = float_auto + +def array64(data, dtype=None): + with jt.flag_scope(auto_convert_64_to_32=0): + return array(data, dtype) + +def grad(loss, targets, retain_graph=True): + if type(targets) == core.Var: + return core.grad(loss, [targets], retain_graph)[0] + return core.grad(loss, targets, retain_graph) + +def liveness_info(): + return { + "hold_vars": core.number_of_hold_vars(), + "lived_vars": core.number_of_lived_vars(), + "lived_ops": core.number_of_lived_ops(), + } + +def ones(*shape, dtype="float32"): + ''' Constructs a jittor Var with all elements set to 1. + + :param shape: The shape of the output Var. + :type shape: list or tuple. + :param dtype: The data type of the output Var. + :type dtype: str, jittor type-cast function, or None. + :return: The output Var. + :rtype: jittor.Var + ''' + if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)): + dtype = shape[-1] + shape = shape[:-1] + if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)): + shape = shape[0] + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") + return unary(1, dtype).broadcast(shape) + +def new_ones(x, size): + return ones(size, x.dtype) +Var.new_ones = new_ones + +def ones_like(x): + ''' Constructs a jittor Var with all elements set to 1 and shape same with x. + + :param x: The reference jittor Var. + :type x: jt.Var + :return: The output Var. + :rtype: jittor.Var + ''' + return ones(x.shape,x.dtype) + +def zeros(*shape, dtype="float32"): + ''' Constructs a jittor Var with all elements set to 0. + + :param shape: The shape of the output Var. + :type shape: list or tuple. + :param dtype: The data type of the output Var. + :type dtype: str, jittor type-cast function, or None. + :return: The output Var. + :rtype: jittor.Var + ''' + if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)): + dtype = shape[-1] + shape = shape[:-1] + if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)): + shape = shape[0] + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") + return unary(0, dtype).broadcast(shape) + +def new_zeros(x, size): + return zeros(size, x.dtype) +Var.new_zeros = new_zeros + +def empty(*shape, dtype="float32"): + if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)): + dtype = shape[-1] + shape = shape[:-1] + if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)): + shape = shape[0] + return ops.empty(shape, dtype) + +def new_empty(x, size): + return empty(size, x.dtype) +Var.new_empty = new_empty + +def full(shape,val,dtype="float32"): + ''' Constructs a jittor Var with all elements set to val. + + :param shape: The shape of the output Var. + :type shape: list or tuple. + :param val: The value of the output Var. + :type val: number. + :param dtype: The data type of the output Var. Defaults to jt.float32. + :type dtype: str, jittor type-cast function, or None. + :return: The output Var. + :rtype: jittor.Var + ''' + if not isinstance(shape, (NanoVector, Sequence)): + shape = (shape,) + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") + return unary(val, dtype).broadcast(shape) + +def new_full(x, size, val): + return full(size, val, x.dtype) +Var.new_full = new_full + +def ne(x,y): + return x!=y +Var.ne = ne + +def full_like(x, val, dtype=None) -> Var: + ''' Constructs a jittor Var with all elements set to val and shape same with x. + + :param x: The reference jittor Var. + :type x: jt.Var. + :param val: The value of the output Var. + :type val: number. + :param dtype: if None, the dtype of the output is the same as x. + Otherwise, use the specified dtype. Defaults to None. + :type dtype: str, optional + :return: The output Var. + :rtype: jittor.Var + ''' + if dtype is None: dtype = x.dtype + return full(x.shape, val, dtype) + +def zeros_like(x, dtype=None) -> Var: + ''' Constructs a jittor Var with all elements set to 0 and shape same with x. + + :param x: The reference jittor Var. + :type x: jt.Var + :param dtype: if None, the dtype of the output is the same as x. + Otherwise, use the specified dtype. Defaults to None. + :type dtype: str, optional + :return: The output Var. + :rtype: jittor.Var + ''' + if dtype is None: dtype = x.dtype + return zeros(x.shape, dtype) + +flags = core.Flags() + +def var(x, dim=None, dims=None, unbiased=False, keepdims=False): + """ return the sample variance. If unbiased is True, Bessel's correction will be used. + + :param x: the input jittor Var. + :type x: jt.Var. + :param dim: the dimension to compute the variance. If both dim and dims are None, the variance of the whole tensor will be computed. + :type dim: int. + :param dims: the dimensions to compute the variance. If both dim and dims are None, the variance of the whole tensor will be computed. + :type dims: tuple of int. + :param unbiased: if True, Bessel's correction will be used. + :type unbiased: bool. + :param keepdim: if True, the output shape is same as input shape except for the dimension in dim. + :type keepdim: bool. + + Example:: + + >>> a = jt.rand(3) + >>> a + jt.Var([0.79613626 0.29322362 0.19785859], dtype=float32) + >>> a.var() + jt.Var([0.06888353], dtype=float32) + >>> a.var(unbiased=True) + jt.Var([0.10332529], dtype=float32) + """ + shape = x.shape + new_shape = list(x.shape) + + assert dim is None or dims is None, "dim and dims can not be both set" + if dim is None and dims is None: + dims = list(range(len(shape))) + elif dim is not None: + dims = [dim] + + mean = jt.mean(x, dims, keepdims=True) + mean = jt.broadcast(mean, shape) + + n = 1 + for d in dims: + n *= shape[d] + new_shape[d] = 1 + + sqr = (x - mean) ** 2 + sqr = jt.sum(sqr, dims=dims, keepdims=False) + if unbiased: + n -= 1 + sqr /= n + + if keepdims: + sqr = sqr.view(new_shape) + return sqr +Var.var = var + +def std(x): + matsize=1 + for i in x.shape: + matsize *= i + out=(x-x.mean()).sqr().sum() + out=out/(matsize-1) + out=out.maximum(1e-6).sqrt() + return out +Var.std = std + +def norm(x, p=2, dim=-1, keepdims=False, eps=1e-30, keepdim=False): + keepdim = keepdim or keepdims + assert p==1 or p==2 + if p==1: + return x.abs().sum(dim, keepdim) + if p==2: + return (x.sqr()).sum(dim, keepdim).maximum(eps).sqrt() +Var.norm = norm + +origin_reshape = reshape +def reshape(x, *shape): + if len(shape) == 1 and isinstance(shape[0], (Sequence, NanoVector)): + shape = shape[0] + return origin_reshape(x, shape) +reshape.__doc__ = origin_reshape.__doc__ +Var.view = Var.reshape = view = reshape + +origin_transpose = transpose +def transpose(x, *dim): + if len(dim) == 1 and isinstance(dim[0], (Sequence, NanoVector)): + dim = dim[0] + elif len(dim) == 2: + axes = list(range(x.ndim)) + a, b = dim + axes[a], axes[b] = axes[b], axes[a] + dim = axes + return origin_transpose(x, dim) +transpose.__doc__ = origin_transpose.__doc__ +Var.transpose = Var.permute = permute = transpose + +def flatten(input, start_dim=0, end_dim=-1): + '''flatten dimentions by reshape''' + in_shape = input.shape + start_dim = len(in_shape) + start_dim if start_dim < 0 else start_dim + end_dim = len(in_shape) + end_dim if end_dim < 0 else end_dim + assert end_dim >= start_dim, "end_dim should be larger than or equal to start_dim for flatten function" + if len(in_shape) <= end_dim: + raise IndexError(f"Dimension out of range (expected to be in range of [{-len(in_shape)}, {len(in_shape) - 1}], but got {end_dim})") + out_shape = [] + for i in range(0,start_dim,1): out_shape.append(in_shape[i]) + dims = 1 + for i in range(start_dim, end_dim+1, 1): dims *= in_shape[i] + out_shape.append(dims) + for i in range(end_dim+1,len(in_shape),1): out_shape.append(in_shape[i]) + return input.reshape(out_shape) +Var.flatten = flatten + +Var.detach_inplace = Var.start_grad + +def detach(x): + return x.detach() + +def unsqueeze(x, dim): + shape = list(x.shape) + if dim < 0: dim += len(shape) + 1 + assert dim <= len(shape) + return x.reshape(shape[:dim] + [1] + shape[dim:]) +Var.unsqueeze = unsqueeze + +def squeeze(x, dim=None): + shape = list(x.shape) + if dim is None: + new_shape = [s for s in shape if s > 1] + return x.reshape(new_shape) + else: + if dim < 0: dim += len(shape) + assert dim < len(shape) and dim >= 0 + assert shape[dim] == 1 + return x.reshape(shape[:dim] + shape[dim+1:]) +Var.squeeze = squeeze + +def clamp(x, min_v=None, max_v=None): + if x.shape[0]==0: + return x + if min_v is not None and max_v is not None: + assert min_v <= max_v + if min_v is not None: + x = x.maximum(min_v) + if max_v is not None: + x = x.minimum(max_v) + return x + +Var.clamp = clamp + +def clamp_(x, min_v=None, max_v=None): + ''' In-place version of clamp(). + + Args: + x (Jittor Var): + the input var + min_v ( Number or Var, optional) - lower-bound of clamp range + max_v ( Number or Var, optional) - upper-bound of clamp range + + Return: + x itself after clamp. + + ''' + return x.assign(x.clamp(min_v=min_v, max_v=max_v)) +Var.clamp_ = clamp_ + + +def outer(x, y): + ''' Returns the outer product of two 1-D vectors. + + :param x: the input Var. + :type x: jt.Var, numpy array, or python sequence. + :param y: the input Var. + :type y: jt.Var, numpy array, or python sequence. + + + Example:: + + >>> x = jt.arange(3) + >>> y = jt.arange(4) + >>> jt.outer(x, y) + jt.Var([[0 0 0 0] + [0 1 2 3] + [0 2 4 6]], dtype=int32) + >>> x.outer(y) + jt.Var([[0 0 0 0] + [0 1 2 3] + [0 2 4 6]], dtype=int32) + ''' + return jt.multiply(x.unsqueeze(1), y.unsqueeze(0)) +Var.outer = outer + +def erfinv_(x): + ''' In-place version of erfinv(). + ''' + return x.assign(x.erfinv()) +Var.erfinv_ = erfinv_ + +def erf_(x): + ''' In-place version of erf(). + ''' + return x.assign(x.erf()) +Var.erf_ = erf_ + +def abs_(x): + ''' In-place version of abs(). + ''' + return x.assign(x.abs()) +Var.abs_ = abs_ + +def sigmoid_(x): + ''' In-place version of sigmoid(). + ''' + return x.assign(x.sigmoid()) +Var.sigmoid_ = sigmoid_ + +def sqrt_(x): + ''' In-place version of sqrt(). + ''' + return x.assign(x.sqrt()) +Var.sqrt_ = sqrt_ + +def add_(x, y): + ''' In-place version of add(). + ''' + return x.assign(x.add(y)) +Var.add_ = add_ + +def multiply_(x, y): + ''' In-place version of multiply(). + ''' + return x.assign(x.multiply(y)) +Var.multiply_ = multiply_ + +def type_as(a, b): + return a.unary(op=b.dtype) +Var.type_as = type_as +Var.astype = Var.cast + +def masked_fill(x, mask, value): + return jt.ternary(mask, value, x) +Var.masked_fill = masked_fill + + +def sqr(x): return x*x +Var.sqr = sqr + +def pow(x, y): + ''' computes x^y, element-wise. + + This operation is equivalent to ``x ** y``. + + :param x: the first input. + :type x: a python number or jt.Var. + :param y: the second input. + :type y: a python number or jt.Var. + ''' + if isinstance(x,Var) and isinstance(y, (ori_int, ori_float)) and y == 2: + return x.sqr() + return core.ops.pow(x, y) +Var.pow = Var.__pow__ = pow + +def argmax(x: Var, dim: int, keepdims:bool=False): + ''' Returns the indices and values of the maximum elements along the specified dimension. + + :param x: the input Var. + :type x: jt.Var, numpy array, or python sequence. + :param dim: the dimension to reduce. + :type dim: int. + :param keepdims: whether the output Var has dim retained or not. Defaults to False + :type keepdims: bool, optional + + Example:: + + >>> a = jt.randn((2, 4)) + >>> a + jt.Var([[-0.33272865 -0.4951588 1.4128606 0.13734372] + [-1.633469 0.19593953 -0.7803732 -0.5260756 ]], dtype=float32) + >>> a.argmax(dim=0) + (jt.Var([0 1 0 0], dtype=int32), jt.Var([-0.33272865 0.19593953 1.4128606 0.13734372], dtype=float32)) + >>> a.argmax(dim=1) + (jt.Var([2 1], dtype=int32), jt.Var([1.4128606 0.19593953], dtype=float32)) + ''' + if dim is None: + dim = 0 + x = x.flatten() + return jt.arg_reduce(x, "max", dim, keepdims) +Var.argmax = argmax + +def argmin(x, dim: int, keepdims:bool=False): + ''' Returns the indices and values of the minimum elements along the specified dimension. + + :param x: the input Var. + :type x: jt.Var, numpy array, or python sequence. + :param dim: the dimension to reduce. + :type dim: int. + :param keepdims: whether the output Var has dim retained or not. Defaults to False + :type keepdims: bool, optional + + Example:: + + >>> a = jt.randn((2, 4)) + >>> a + jt.Var([[-0.33272865 -0.4951588 1.4128606 0.13734372] + [-1.633469 0.19593953 -0.7803732 -0.5260756 ]], dtype=float32) + >>> a.argmin(dim=0) + (jt.Var([1 0 1 1], dtype=int32), jt.Var([-1.633469 -0.4951588 -0.7803732 -0.5260756], dtype=float32)) + >>> a.argmin(dim=1) + (jt.Var([1 0], dtype=int32), jt.Var([-0.4951588 -1.633469 ], dtype=float32)) + ''' + return jt.arg_reduce(x, "min", dim, keepdims) +Var.argmin = argmin + +def randn(*size, dtype="float32", requires_grad=True) -> Var: + ''' samples random numbers from a standard normal distribution. + + :param size: shape of the output. + :type size: int or a sequence of int + + :param dtype: data type, defaults to "float32". + :type dtype: str, optional + + :param requires_grad: whether to enable gradient back-propgation, defaults to True. + :type requires_grad: bool, optional + + Example:: + + >>> jt.randn(3) + jt.Var([-1.019889 -0.30377278 -1.4948598 ], dtype=float32) + >>> jt.randn(2, 3) + jt.Var([[-0.15989183 -1.5010914 0.5476955 ] + [-0.612632 -1.1471151 -1.1879086 ]], dtype=float32) + ''' + if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0] + for dim in size: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {size}") + arr = jt.random(size, dtype, "normal") + if not requires_grad: return arr.stop_grad() + return arr + +def rand(*size, dtype="float32", requires_grad=True) -> Var: + ''' samples random numbers from a uniform distribution on the interval [0, 1). + + :param size: shape of the output. + :type size: int or a sequence of int + + :param dtype: data type, defaults to "float32". + :type dtype: str, optional + + :param requires_grad: whether to enable gradient back-propgation. defaults to True. + :type requires_grad: bool, optional + + Example:: + + >>> jt.rand(3) + jt.Var([0.31005102 0.02765604 0.8150749 ], dtype=float32) + >>> jt.rand(2, 3) + jt.Var([[0.96414304 0.3519264 0.8268017 ] + [0.05658621 0.04449705 0.86190987]], dtype=float32) + ''' + if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0] + arr = jt.random(size, dtype) + if not requires_grad: return arr.stop_grad() + return arr + +def rand_like(x, dtype=None) -> Var: + ''' samples random values from standard uniform distribution with the same shape as x. + + :param x: reference variable. + :type x: jt.Var + + :param dtype: if None, the dtype of the output is the same as x. + Otherwise, use the specified dtype. Defaults to None. + :type dtype: str, optional + + Example:: + + >>> x = jt.zeros((2, 3)) + >>> jt.rand_like(x) + jt.Var([[0.6164821 0.21476883 0.61959815] + [0.58626485 0.35345772 0.5638483 ]], dtype=float32) + ''' + if dtype is None: dtype = x.dtype + return jt.random(x.shape, dtype) + +def randn_like(x, dtype=None) -> Var: + ''' samples random values from standard normal distribution with the same shape as x. + + :param x: reference variable. + :type x: jt.Var + + :param dtype: if None, the dtype of the output is the same as x. + Otherwise, use the specified dtype. Defaults to None. + :type dtype: str, optional + + Example:: + + >>> x = jt.zeros((2, 3)) + >>> jt.randn_like(x) + jt.Var([[-1.1647032 0.34847224 -1.3061888 ] + [ 1.068085 -0.34366122 0.13172573]], dtype=float32) + ''' + if dtype is None: dtype = x.dtype + return jt.random(x.shape, x.dtype, "normal") + +def randint(low, high=None, shape=(1,), dtype="int32") -> Var: + ''' samples random integers from a uniform distribution on the interval [low, high). + + :param low: lowest intergers to be drawn from the distribution, defaults to 0. + :type low: int, optional + + :param high: One above the highest integer to be drawn from the distribution. + :type high: int + + :param shape: shape of the output size, defaults to (1,). + :type shape: tuple, optional + + :param dtype: data type of the output, defaults to "int32". + :type dtype: str, optional + + Example:: + + >>> jt.randint(3, shape=(3, 3)) + jt.Var([[2 0 2] + [2 1 2] + [2 0 1]], dtype=int32) + >>> jt.randint(1, 3, shape=(3, 3)) + jt.Var([[2 2 2] + [1 1 2] + [1 1 1]], dtype=int32) + ''' + if high is None: low, high = 0, low + for dim in shape: + if dim < 0: + raise RuntimeError(f"Trying to create tensor with negative dimension {dim}: {shape}") + v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5) + v = jt.floor_int(v) + return v.astype(dtype) + +def randint_like(x, low, high=None) -> Var: + ''' samples random values from standard normal distribution with the same shape as x. + + :param x: reference variable. + :type x: jt.Var + + :param low: lowest intergers to be drawn from the distribution, defaults to 0. + :type low: int, optional + + :param high: One above the highest integer to be drawn from the distribution. + :type high: int + + Example:: + + >>> x = jt.zeros((2, 3)) + >>> jt.randint_like(x, 10) + jt.Var([[9. 3. 4.] + [4. 8. 5.]], dtype=float32) + >>> jt.randint_like(x, 10, 20) + jt.Var([[17. 11. 18.] + [14. 17. 15.]], dtype=float32) + ''' + + return randint(low, high, x.shape, x.dtype) + +def normal(mean, std, size=None, dtype="float32") -> Var: + ''' samples random values from a normal distribution. + + :param mean: means of the normal distributions. + :type mean: int or jt.Var + + :param std: standard deviations of the normal distributions. + :type std: int or jt.Var + + :param size: shape of the output size. if not specified, the + shape of the output is determined by mean or std. Exception will be + raised if mean and std are all integers or have different shape in + this case. Defaults to None + :type size: tuple, optional + + :param dtype: data type of the output, defaults to "float32". + :type dtype: str, optional + + Example:: + + >>> jt.normal(5, 3, size=(2,3)) + jt.Var([[ 8.070848 7.654219 10.252696 ] + [ 6.383718 7.8817277 3.0786133]], dtype=float32) + >>> mean = jt.randint(low=0, high=10, shape=(10,)) + >>> jt.normal(mean, 0.1) + jt.Var([1.9524184 1.0749301 7.9864206 5.9407325 8.1596155 4.824019 7.955083 + 8.972998 6.0674286 8.88026 ], dtype=float32) + ''' + if size is None: + if isinstance(mean, Var) and isinstance(std, Var): + assert mean.shape == std.shape + size = mean.shape + else: + if isinstance(mean, Var): size = mean.shape + if isinstance(std, Var): size = std.shape + return jt.init.gauss(size, dtype, mean, std) + +def attrs(var): + return { + "is_stop_fuse": var.is_stop_fuse(), + "is_stop_grad": var.is_stop_grad(), + "shape": var.shape, + "dtype": var.dtype, + } +Var.attrs = attrs + +def fetch(*args): + ''' Async fetch vars with function closure. + +Example 1:: + + for img,label in enumerate(your_dataset): + pred = your_model(img) + loss = critic(pred, label) + acc = accuracy(pred, label) + jt.fetch(acc, loss, + lambda acc, loss: + print(f"loss:{loss} acc:{acc}" + ) + +Example 2:: + + for i,(img,label) in enumerate(your_dataset): + pred = your_model(img) + loss = critic(pred, label) + acc = accuracy(pred, label) + # variable i will be bind into function closure + jt.fetch(i, acc, loss, + lambda i, acc, loss: + print(f"#{i}, loss:{loss} acc:{acc}" + ) + ''' + assert len(args)>=1 + func = args[-1] + assert callable(func) + args = list(args[:-1]) + if len(args)>0 and isinstance(args[0], Sequence) \ + and len(args[0])>=1 and isinstance(args[0][0], Var): + raise TypeError("jt.Var should not inside a list or tuple.") + + var_map = [] + variables = [] + for i, v in enumerate(args): + if isinstance(v, Var): + variables.append(v) + var_map.append(i) + args[i] = None + def callback(*results): + for i,v in enumerate(results): + args[var_map[i]] = v + func(*args) + core.ops.fetch(variables, callback) + +Var.fetch = fetch + +def display_memory_info(): + import inspect, os + f = inspect.currentframe() + fileline = inspect.getframeinfo(f.f_back) + fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}" + core.display_memory_info(fileline) + +def load(path: str): + ''' loads an object from a file. + ''' + model_dict = safeunpickle(path) + return model_dict + +def save(params_dict, path: str): + ''' saves the parameter dictionary to a file. + + :param params_dict: parameters to be saved + :type params_dict: list or dictionary + :param path: file path + :type path: str + ''' + safepickle(params_dict, path) + +def _uniq(x): + a = set() + b = [] + for i in x: + j = id(i) + if j not in a: + a.add(j) + b.append(i) + return b + +class Module: + def __init__(self, *args, **kw): + pass + def execute(self, *args, **kw): + ''' Executes the module computation. + + Raises NotImplementedError if the subclass does not override the method. + ''' + raise NotImplementedError("Please implement 'execute' method of "+str(type(self))) + + def __call__(self, *args, **kw): + return self.execute(*args, **kw) + def __repr__(self): + return self.__str__() + def _get_name(self): + return self.__class__.__name__ + def __name__(self): + pass + + def dfs(self, parents, k, callback, callback_leave=None, recurse=True): + ''' An utility function to traverse the module. ''' + n_children = 0 + for v in self.__dict__.values(): + if isinstance(v, Module): + n_children += 1 + ret = callback(parents, k, self, n_children) + if ret == False: return + if recurse: + for k,v in self.__dict__.items(): + if not isinstance(v, Module): + continue + parents.append(self) + v.dfs(parents, k, callback, callback_leave) + parents.pop() + if callback_leave: + callback_leave(parents, k, self, n_children) + + def __str__(self): + ss = [] + def callback(parents, k, v, n): + # indent key:class_name(extra_repr) + k = f"{k}: " if k is not None else "" + s = f"{' '*(len(parents)*4)}{k}{v.__class__.__name__}" + if n: + s += '(' + else: + s += f"({v.extra_repr()})" + ss.append(s) + def callback_leave(parents, k, v, n): + if n: + ss.append(' '*(len(parents)*4)+')') + self.dfs([], None, callback, callback_leave) + return "\n".join(ss) + + def parameters(self, recurse=True) -> List: + ''' Returns a list of module parameters. + + ---------------- + + Example:: + + >>> net = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2)) + >>> for p in net.parameters(): + ... print(p.name) + ... + >>> for p in net.parameters(): + ... print(p.name()) + ... + 0.weight + 0.bias + 2.weight + 2.bias + ''' + ps = [] + stack = [] + def callback(parents, k, v, n): + stack.append(str(k)) + dc = v.__dict__ + if isinstance(v, nn.ParameterList): + dc = v.params + for k2, p in dc.items(): + if isinstance(k2, str) and k2.startswith("_"): continue + if isinstance(p, Var): + ps.append(p) + pname = ".".join(stack[1:]+[str(k2)]) + if len(pname) > len(p.name()): + p.name(pname) + def callback_leave(parents, k, v, n): + stack.pop() + self.dfs([], None, callback, callback_leave, recurse) + return _uniq(ps) + + def state_dict(self, to=None, recurse=True): + ''' Returns a dictionary containing + Jittor Var of the module and its descendants. + + Args: + to: target type of var, canbe None or 'numpy' or 'torch' + + Return: + dictionary of module's states. + + Example:: + + import jittor as jt + from jittor.models import resnet50 + jittor_model = resnet50() + dict = jittor_model.state_dict() + jittor_model.load_state_dict(dict) + + Example2(export Jittor params to PyTorch):: + + import jittor as jt + from jittor.models import resnet50 + jittor_model = resnet50() + import torch + from torchvision.models import resnet50 + torch_model = resnet50() + torch_model.load_state_dict(jittor_model.state_dict(to="torch")) + + ''' + uniq_set = set() + ps = {} + stack = [] + def callback(parents, k, v, n): + stack.append(str(k)) + dc = v.__dict__ + if isinstance(v, nn.ParameterList): + dc = v.params + for k2, p in dc.items(): + if isinstance(k2, str) and k2.startswith("_"): continue + if isinstance(p, Var): + if id(p) in uniq_set: continue + if not getattr(p, "persistent", True): + continue + uniq_set.add(id(p)) + pname = ".".join(stack[1:]+[str(k2)]) + ps[pname] = p + if len(pname) > len(p.name()): + p.name(pname) + def callback_leave(parents, k, v, n): + stack.pop() + self.dfs([], None, callback, callback_leave, recurse) + if to == "numpy": + for k,v in ps.items(): + if isinstance(v, Var): + ps[k] = v.numpy() + elif to == "torch": + import torch + for k,v in ps.items(): + if isinstance(v, Var): + ps[k] = torch.Tensor(v.numpy()) + return ps + + def named_parameters(self, recurse=True) -> List[Tuple[str, Var]]: + ''' Returns a list of module parameters and their names. + + ---------------- + + Example:: + + >>> net = nn.Linear(2, 5) + >>> net.named_parameters() + [('weight', jt.Var([[ 0.5964666 -0.3175258 ] + [ 0.41493994 -0.66982657] + [-0.32677156 0.49614117] + [-0.24102807 -0.08656466] + [ 0.15868133 -0.12468725]], dtype=float32)), + ('bias', jt.Var([-0.38282675 0.36271113 -0.7063226 0.02899247 0.52210844], dtype=float32))] + + ''' + state_dict = self.state_dict(recurse=recurse) + return list(state_dict.items()) + + def load_state_dict(self, params) -> None: + ''' + Loads the module's parameters from a dictionary. + ''' + self.load_parameters(params) + + def _load_from_state_dict(self, state, prefix="", *args, **kw): + if len(prefix): + new_state = {} + for k,v in state.items(): + if k.startswith(prefix): + new_state[k[len(prefix):]] = v + state = new_state + self.load_state_dict(state) + + def cuda(self, device=None): + flags.use_cuda = 1 + return self + + def npu(self, device=None): + flags.use_cuda = 1 + return self + + def modules(self) -> List: + ''' Returns a list of sub-modules in the module recursively. + + ---------------- + + Example:: + + >>> net = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2)) + >>> net.modules() + [Sequential( + 0: Linear(2, 10, float32[10,], None) + 1: relu() + 2: Linear(10, 2, float32[2,], None) + ), Linear(2, 10, float32[10,], None), relu(), Linear(10, 2, float32[2,], None)] + ''' + ms = [] + def callback(parents, k, v, n): + if isinstance(v, Module): + ms.append(v) + self.dfs([], None, callback, None) + return _uniq(ms) + + def named_modules(self): + ''' Returns a list of sub-modules and their names recursively. + + ---------------- + + Example:: + + >>> net = nn.Sequential(nn.Linear(2, 10), nn.ReLU(), nn.Linear(10, 2)) + >>> net.named_modules() + [('', Sequential( + 0: Linear(2, 10, float32[10,], None) + 1: relu() + 2: Linear(10, 2, float32[2,], None) + )), ('0', Linear(2, 10, float32[10,], None)), ('1', relu()), ('2', Linear(10, 2, float32[2,], None))] + ''' + ms = [] + stack = [] + def callback(parents, k, v, n): + if isinstance(v, Module): + stack.append(str(k)) + name = ".".join(stack[1:]) + ms.append((name, v)) + def callback_leave(parents, k, v, n): + stack.pop() + self.dfs([], "", callback, callback_leave) + return ms + + def add_module(self, name, module): + setattr(self, name ,module) + return self + + @property + def _modules(self): + return { k:v for k,v in self.__dict__.items() if isinstance(v, Module) } + + @property + def _parameters(self): + return { k:v for k,v in self.__dict__.items() if isinstance(v, Var) } + + def requires_grad_(self, requires_grad=True): + ''' Sets requires_grad for all parameters and sub-modules. ''' + self._requires_grad = requires_grad + self._place_hooker() + return self + + def __hooked_call__(self, *args, **kw): + if hasattr(self, "__fhook2__"): + if len(kw): + self.__fhook2__(self, args, kw) + else: + self.__fhook2__(self, args) + if hasattr(self, "__bihook__"): + if len(kw): + LOG.w("backward hook not support kw") + args = grad_hooker(args, self.__bihook__) + if hasattr(self, "_requires_grad") and not self._requires_grad: + with jt.no_grad(): + ret = self.__hooked_call__(*args, **kw) + else: + ret = self.__hooked_call__(*args, **kw) + if hasattr(self, "__bohook__"): + if len(kw): + LOG.w("backward hook not support kw") + if isinstance(ret, Var): + ret = grad_hooker((ret,), self.__bohook__)[0] + else: + ret = grad_hooker(ret, self.__bohook__) + if hasattr(self, "__fhook__"): + if len(kw): + self.__fhook__(self, args, ret, kw) + else: + self.__fhook__(self, args, ret) + return ret + + def _place_hooker(self): + cls = self.__class__ + if hasattr(cls, "__hooked__"): + return + cls.__hooked__ = True + cls.__call__, cls.__hooked_call__ = \ + cls.__hooked_call__, cls.__call__ + + def register_forward_hook(self, func): + ''' Register a forward function hook that will be called after Module.execute. + + The hook function will be called with the following arguments:: + + hook(module, input_args, output) + or:: + hook(module, input_args, output, input_kwargs) + ''' + self.__fhook__ = func + self._place_hooker() + + def remove_forward_hook(self): + ''' Removes the current forward hook. ''' + if hasattr(self,"__fhook__"): + delattr(self,"__fhook__") + + def register_pre_forward_hook(self, func): + ''' Register a forward function hook that will be called before Module.execute. + + The hook function will be called with the following arguments:: + + hook(module, input_args) + or:: + hook(module, input_args, input_kwargs) + + ''' + self.__fhook2__ = func + self._place_hooker() + + def remove_pre_forward_hook(self): + ''' Removes the current pre-forward hook. ''' + if hasattr(self,"__fhook2__"): + delattr(self,"__fhook2__") + + def register_input_backward_hook(self, func): + self.__bihook__ = func + self._place_hooker() + + def remove_input_backward_hook(self): + if hasattr(self,"__bihook__"): + delattr(self,"__bihook__") + + def register_output_backward_hook(self, func): + self.__bohook__ = func + self._place_hooker() + + def remove_output_backward_hook(self): + if hasattr(self,"__bohook__"): + delattr(self,"__bohook__") + + def register_backward_hook(self, func): + ''' hook both input and output on backpropergation of this module. + +Arguments of hook are defined as:: + + hook(module, grad_input:tuple(jt.Var), grad_output:tuple(jt.Var)) -> tuple(jt.Var) or None + +`grad_input` is the origin gradients of input of this module, `grad_input` is the gradients of output of this module, return value is used to replace the gradient of input. + ''' + _grad_output = None + def bohook(grad_output): + nonlocal _grad_output + _grad_output = grad_output + def bihook(grad_input): + return func(self, grad_input, _grad_output) + self.register_input_backward_hook(bihook) + self.register_output_backward_hook(bohook) + + def remove_backward_hook(self): + ''' Removes the backward input and output hooks. + ''' + self.remove_input_backward_hook() + self.remove_output_backward_hook() + + def children(self) -> List: + ''' Returns an List of the children modules. ''' + cd = [] + def callback(parents, k, v, n): + if len(parents) == 1 and isinstance(v, Module): + cd.append(v) + return False + self.dfs([], None, callback, None) + return cd + + def extra_repr(self): + ss = [] + n = len(self.__init__.__code__.co_varnames) + if self.__init__.__defaults__ is not None: + n -= len(self.__init__.__defaults__) + for i, k in enumerate(self.__init__.__code__.co_varnames[1:]): + v = getattr(self, k) if hasattr(self, k) else None + if isinstance(v, Var): v = v.peek() + s = f"{k}={v}" if i >= n else str(v) + ss.append(s) + return ", ".join(ss) + + def apply(self, func): + ''' Applies a function to all sub-modules recursively. ''' + for m in self.modules(): + func(m) + + def load_parameters(self, params): + ''' loads parameters to the Module. + + :param params: dictionary of parameter names and parameters. + ''' + n_failed = 0 + for key in params.keys(): + v = self + key_ = key.split('.') + end = 0 + for k in key_: + if isinstance(v, nn.Sequential): + if (k in v.layers): + v = v[k] + elif k.isdigit() and (ori_int(k) in v.layers): + v = v[ori_int(k)] + else: + end=1 + break + else: + if hasattr(v, k): + v = getattr(v, k) + if v is None: + continue + assert isinstance(v, (Module, Var)), \ + f"expect a jittor Module or Var, but got <{v.__class__.__name__}>, key: {key}" + else: + end = 1 + break + if end == 1: + if not key.endswith("num_batches_tracked"): + n_failed += 1 + LOG.w(f'load parameter {key} failed ...') + else: + assert isinstance(v, Var), \ + f"expect a jittor Var, but got <{v.__class__.__name__}>, key: {key}" + if isinstance(params[key], np.ndarray) or isinstance(params[key], list): + param = array(params[key]) + elif isinstance(params[key], Var): + param = params[key] + else: + # assume is pytorch tensor + param = array(params[key].cpu().detach().numpy()) + if param.shape == v.shape: + LOG.v(f'load parameter {key} success ...') + v.update(param) + v.sync(False, False) + else: + n_failed += 1 + LOG.e(f'load parameter {key} failed: expect the shape of {key} to be {v.shape}, but got {param.shape}') + if n_failed: + LOG.w(f"load total {len(params)} params, {n_failed} failed") + + def save(self, path: str): + ''' saves parameters to a file. + + :param path: path to save. + :type path: str + + Example:: + + >>> class Net(nn.Module): + >>> ... + >>> net = Net() + >>> net.save('net.pkl') + >>> net.load('net.pkl') + ''' + params = self.state_dict() + safepickle(params, path) + + def load(self, path: str): + ''' loads parameters from a file. + + :param path: path to load. + :type path: str + + Example:: + + >>> class Net(nn.Module): + >>> ... + >>> net = Net() + >>> net.save('net.pkl') + >>> net.load('net.pkl') + + This method also supports loading a state dict from a pytorch .pth file. + + .. note:: + 当载入的参数与模型定义不一致时, jittor 会输出错误信息, 但是不会抛出异常. + 若载入参数出现模型定义中没有的参数名, 则会输出如下信息, 并忽略此参数: + + >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ... + + 若载入参数的 shape 与模型定义不一致, 则会输出如下信息, 并忽略此参数: + + >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,] + + 如载入过程中出现错误, jittor 会输出概要信息, 您需要仔细核对错误信息 + + >>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed + ''' + self.load_parameters(load(path)) + + def eval(self): + ''' Sets the module in evaluation mode. ''' + def callback(parents, k, v, n): + if isinstance(v, Module): + v.is_train = False + self.dfs([], None, callback, None) + + # backup stop grad or not + if not hasattr(self, "backup_grad_state"): + self.backup_grad_state = {} + for p in self.parameters(): + if id(p) not in self.backup_grad_state: + self.backup_grad_state[id(p)] = not p.is_stop_grad() + p.stop_grad() + return self + + def train(self): + ''' Sets the module in training mode. ''' + def callback(parents, k, v, n): + if isinstance(v, Module): + v.is_train = True + self.dfs([], None, callback, None) + + # backup stop grad or not + if hasattr(self, "backup_grad_state"): + for p in self.parameters(): + if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]: + p.start_grad() + return self + + def is_training(self) -> bool: + ''' Returns whether the module is in training mode.''' + if not hasattr(self, "is_train"): + self.is_train = True + return self.is_train + + @property + def training(self): + if not hasattr(self, "is_train"): + self.is_train = True + return self.is_train + + @training.setter + def training(self, value): + self.is_train = value + + def mpi_param_broadcast(self, root=0): + if not in_mpi: return + for p in self.parameters(): + p.update(p.mpi_broadcast(root)) + + def __setattr__(self, key, value): + object.__setattr__(self, key, value) + + def __getattr__(self, key): + return object.__getattribute__(self, key) + + def register_buffer(self, key, value, persistent=True): + value.persistent = persistent + object.__setattr__(self, key, value) + return value + + @property + def _buffers(self): + buffers = {} + for k,v in self.__dict__.items(): + if isinstance(v, jt.Var): + buffers[k] = v + return buffers + + def named_buffers(self,recurse=False): + + buffers = [] + for k,v in self.__dict__.items(): + if isinstance(v, jt.Var): + buffers.append((k,v)) + return buffers + + def named_children(self,): + childs = [] + for k,v in self.__dict__.items(): + if isinstance(v,Module): + childs.append((k,v)) + return childs + + def float64(self): + '''convert all parameters to float16''' + self._amp_level = 0 + for p in self.parameters(): + if p.dtype.is_float(): + p.assign(p.float64()) + return self + + def float32(self): + '''convert all parameters to float32''' + self._amp_level = 0 + for p in self.parameters(): + if p.dtype.is_float(): + p.assign(p.float32()) + return self + + def float16(self): + '''convert all parameters to float16''' + # self._amp_level = 3 if flags.th_mode else 4 + # amp level better set globally + self._amp_level = -1 + if self._amp_level >= 0: + cls = self.__class__ + cls.__call__ = cls.__half_call__ + for p in self.parameters(): + if p.dtype.is_float(): + p.assign(p.float16()) + return self + + def bfloat16(self): + '''convert all parameters to bfloat16''' + # self._amp_level = 3 if flags.th_mode else 4 + # amp level better set globally + self._amp_level = -1 + if self._amp_level >= 0: + cls = self.__class__ + cls.__call__ = cls.__half_call__ + for p in self.parameters(): + if p.dtype.is_float(): + p.assign(p.bfloat16()) + return self + + def __half_call__(self, *args, **kw): + amp_level = getattr(self, "_amp_level", -1) + if amp_level >= 0: + with flag_scope(amp_level=amp_level): + return self.execute(*args, **kw) + else: + return self.execute(*args, **kw) + + def half(self): + '''convert all parameters to float16''' + return self.float16() + + def float_auto(self): + '''convert all parameters to float16 or float32 automatically + by jt.flags.auto_mixed_precision_level and jt.flags.amp_reg''' + self._amp_level = -1 + for p in self.parameters(): + if p.dtype.is_float(): + p.assign(p.float_auto()) + return self + + + +class Function(Module): + ''' Function Module for customized backward operations + +Example 1 (Function can have multiple input and multiple output, and user +can store value for backward computation):: + + import jittor as jt + from jittor import Function + + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc.apply + c,d = func(a, b) + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 9 + +Example 2(Function can return None for no gradiant, and gradiant +can also be None):: + + import jittor as jt + from jittor import Function + + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + assert grad1 is None + return grad0 * self.y, None + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc.apply + c,d = func(a, b) + d.stop_grad() + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 0 + + ''' + def __call__(self, *args): + if flags.no_grad: + return self.execute(*args) + backup = args + args = list(args) + taped_inputs = [] + taped_outputs = [] + input_mask = [-1] * len(args) + for i,v in enumerate(args): + if isinstance(v, Var): + if v.is_stop_grad(): + # -2 in input_mask represents it is stop_grad + input_mask[i] = -2 + continue + v = v.tape() + input_mask[i] = len(taped_inputs) + args[i] = v + taped_inputs.append(v) + ori_res = self.execute(*args) + if not isinstance(ori_res, Sequence): + res = [ori_res] + else: + res = list(ori_res) + output_mask = [-1] * len(res) + for i,v in enumerate(res): + if isinstance(v, Var): + v = v.tape() + output_mask[i] = len(taped_outputs) + res[i] = v + taped_outputs.append(v) + self.input_mask = input_mask + self.output_mask = output_mask + # tape output and input together so + # backward treat them as one operator + tape_together(taped_inputs, taped_outputs, self._grad) + if isinstance(ori_res, Sequence): + return res + else: + return res[0] + + def _grad(self, *args): + new_args = ( (args[i] if i>=0 else None) for i in self.output_mask ) + ret = self.grad(*new_args) + if not isinstance(ret, Sequence): + ret = (ret,) + new_ret = [] + for i, r in enumerate(ret): + j = self.input_mask[i] + if j<0: + # -2 in input_mask represents it is stop_grad + assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\ + "because the input value is not jittor variable." + else: + new_ret.append(r) + return new_ret + + def dfs(self, parents, k, callback, callback_leave=None, recurse=True): + pass + + @classmethod + def apply(cls, *args, **kw): + func = cls() + return func(*args, **kw) + +class GradHooker(Function): + def __init__(self, hook): + self.hook = hook + + def execute(self, *args): + return args + + def grad(self, *grad_input): + ret = self.hook(grad_input) + if ret: grad_input = ret + return grad_input + +def grad_hooker(args, hook): + hooker = GradHooker(hook) + return hooker(*args) + +def register_hook(v, hook): + """ register hook of any jittor Variables, if hook return not None, +the gradient of this variable will be alter, + + Example:: + + x = jt.array([0.0, 0.0]) + y = x * [1,2] + y.register_hook(lambda g: g*2) + dx = jt.grad(y, x) + print(dx) + # will be [2, 4] + + """ + def _hook(grads): + g = hook(grads[0]) + if g is not None: + return (g,) + return None + hooker = GradHooker(_hook) + v.swap(hooker(v)[0]) + return v + +Var.register_hook = register_hook + +def make_module(func, exec_n_args=1): + class MakeModule(Module): + def __init__(self, *args, **kw): + self.args = args + self.kw = kw + def execute(self, *args): + return func(*args, *self.args, **self.kw) + def __str__(self): + return f"{func.__name__}({self.extra_repr()})" + def extra_repr(self): + return ",".join(map(str, self.args)) + MakeModule.__name__ = func.__name__ + return MakeModule + + +def dirty_fix_pytorch_runtime_error(): + ''' This funtion should be called before pytorch. + + Example:: + + import jittor as jt + jt.dirty_fix_pytorch_runtime_error() + import torch + ''' + import os, platform + + if platform.system() == 'Linux': + os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND + import jittor_utils + with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW): + import torch + + +import atexit + +class ExitHooks(object): + def __init__(self): + self.exit_code = None + self.exception = None + + def hook(self): + self._orig_exit = sys.exit + sys.exit = self.exit + sys.excepthook = self.exc_handler + + def exit(self, code=0): + self.exit_code = code + self._orig_exit(code) + + def exc_handler(self, exc_type, exc, *args): + self.exception = exc + traceback.print_exception(exc_type, exc, *args) + +hooks = ExitHooks() +hooks.hook() + +def jittor_exit(): + if hooks.exit_code is not None: + pass + elif hooks.exception is not None: + pass + else: + pass + # core.sync_all(True) + core.cleanup() +atexit.register(jittor_exit) + +def vtos(v): + data_str = f"jt.Var({v.numpy()}, dtype={v.dtype})" + data_str = data_str.replace("\n", "\n ") + return data_str + +Var.__str__ = vtos +Var.__repr__ = vtos +Var.peek = lambda x: f"{x.dtype}{x.shape}" + +def size(v, dim=None): + if dim is None: + return v.shape + return v.shape[dim] +Var.size = size + + +def to_int(v): + return ori_int(v.item()) + +def to_float(v): + return ori_float(v.item()) + +def to_bool(v): + assert v.dtype.is_int() or v.dtype.is_bool() + return ori_bool(v.item()) + +Var.__int__ = to_int +Var.__float__ = to_float +Var.__bool__ = to_bool + +def format(v, spec): + return v.item().__format__(spec) +Var.__format__ = format + +def get_len(var): + return var.shape[0] + +Var.__len__ = get_len +int = int32 +Var.int = Var.int32 +Var.long = Var.int32 +float = float32 +Var.float = Var.float32 +double = float64 +Var.double = Var.float64 +half = float16 +Var.half = Var.float16 + +def is_var(v): + return isinstance(v, Var) + +# __array__ interface is used for np.array(jt_var) +Var.__array__ = Var.numpy +Var.__array_priority__ = 2000 +# __reduce__, __module__ is used for pickle.dump and pickle.load +Var.__module__ = "jittor" +Var.__reduce__ = lambda self: (Var, (self.data,)) + +from . import nn +from . import attention +from . import lr_scheduler +from . import linalg +from .linalg import einsum +from .nn import matmul, \ + bmm, bmm_transpose, \ + baddbmm +from . import contrib +from . import numpy2cupy +from .contrib import concat, cat +from .misc import * +from . import sparse +from . import optim +from . import dataset +from . import init + +dtype = NanoString + +import jittor_utils + +for backend in jittor_utils.backends: + if hasattr(backend, "post_process"): + backend.post_process() + +# impl x.func(...) -> func_(...) +args = {"x", "input", "self"} +_white_list = {"mul", "add", "sub"} +for k,v in list(Var.__dict__.items()): + if k.startswith("_"): continue + if k.endswith("_"): continue + if not callable(v): continue + + if k not in _white_list: + if not hasattr(v, "__code__"): continue + conames = v.__code__.co_varnames + if len(conames) == 0: continue + arg_name = conames[0] + if arg_name not in args: continue + + new_k = k+"_" + if hasattr(Var, new_k): continue + def inplace_wrapper(new_k, prev_func): + setattr(Var, new_k, lambda x, *args, **kw: x.assign(prev_func(x, *args, **kw))) + inplace_wrapper(new_k, v) + +from . import math_util +from .math_util import * +from . import distributions + +if jt.compiler.has_acl: + from jittor.extern.acl.acl_compiler import change_function + change_function() \ No newline at end of file diff --git a/python/jittor/__init__.pyi b/python/jittor/__init__.pyi new file mode 100644 index 00000000..b849af4c --- /dev/null +++ b/python/jittor/__init__.pyi @@ -0,0 +1,7995 @@ +from jittor_core import * +from jittor_core.ops import * +from .misc import * +from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse +from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size +from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops +from .contrib import concat as concat +from .nn import bmm as bmm, bmm_transpose as bmm_transpose, matmul as matmul +from collections import OrderedDict as OrderedDict +from collections.abc import Mapping as Mapping +from typing import Any, List, Tuple + + +def safepickle(obj, path) -> None: ... +def safeunpickle(path): ... + +class _call_no_record_scope: + def __enter__(self) -> None: ... + def __exit__(self, *exc) -> None: ... + def __call__(self, func): ... + +class flag_scope(_call_no_record_scope): + jt_flags: Any + def __init__(self, **jt_flags) -> None: ... + def __enter__(self) -> None: ... + def __exit__(self, *exc) -> None: ... + +class no_grad(flag_scope): + jt_flags: Any + def __init__(self, **jt_flags) -> None: ... + +class enable_grad(flag_scope): + jt_flags: Any + def __init__(self, **jt_flags) -> None: ... + +single_log_capture: Any + +class log_capture_scope(_call_no_record_scope): + fs: Any + def __init__(self, **jt_flags) -> None: ... + logs: Any + def __enter__(self): ... + def __exit__(self, *exc) -> None: ... + +class profile_scope(_call_no_record_scope): + fs: Any + warmup: Any + rerun: Any + def __init__(self, warmup: int = ..., rerun: int = ..., **jt_flags) -> None: ... + report: Any + def __enter__(self): ... + def __exit__(self, *exc) -> None: ... + +class __single_process_scope: + rank: Any + def __init__(self, rank: int = ...) -> None: ... + bk_in_mpi: Any + bk_mpi_state: Any + def __enter__(self): ... + def __exit__(self, *exc) -> None: ... + +def single_process_scope(rank: int = ...): ... +def clean() -> None: ... +cast = unary + +def array(data, dtype: Any | None = ...): ... +def random(shape, dtype: str = ..., type: str = ...): ... +def float_auto(x): ... +def array64(data, dtype: Any | None = ...): ... +def grad(loss, targets, retain_graph: bool = ...): ... +def liveness_info(): ... +def ones(shape, dtype: str = ...): ... +def ones_like(x): ... +def zeros(shape, dtype: str = ...): ... +def full(shape, val, dtype: str = ...): ... +def full_like(x, val, dtype: Any | None = ...) -> Var: ... +def zeros_like(x, dtype: Any | None = ...) -> Var: ... + +def var(x, dim: Any | None = ..., dims: Any | None = ..., unbiased: bool = ..., keepdims: bool = ...): ... +def std(x): ... +def norm(x, p: int = ..., dim: int = ..., keepdims: bool = ..., eps: float = ..., keepdim: bool = ...): ... +origin_reshape = reshape + +def reshape(x, *shape): ... +view = reshape +origin_transpose = transpose + +def transpose(x, *dim): ... +permute = transpose +def flatten(input, start_dim: int = ..., end_dim: int = ...): ... +def detach(x): ... +def unsqueeze(x, dim): ... +def squeeze(x, dim): ... +def clamp(x, min_v: Any | None = ..., max_v: Any | None = ...): ... +def type_as(a, b): ... +def masked_fill(x, mask, value): ... +def sqr(x): ... +def pow(x, y): ... +def argmax(x, dim, keepdims: bool = ...): ... +def argmin(x, dim, keepdims: bool = ...): ... +def randn(*size, dtype: str = ..., requires_grad: bool = ...) -> Var: ... +def rand(*size, dtype: str = ..., requires_grad: bool = ...) -> Var: ... +def rand_like(x, dtype: Any | None = ...) -> Var: ... +def randn_like(x, dtype: Any | None = ...) -> Var: ... +def randint(low, high: Any | None = ..., shape=..., dtype: str = ...) -> Var: ... +def randint_like(x, low, high: Any | None = ...) -> Var: ... +def normal(mean, std, size: Any | None = ..., dtype: str = ...) -> Var: ... +def attrs(var): ... +def fetch(*args) -> None: ... +def display_memory_info() -> None: ... +def load(path: str): ... +def save(params_dict, path: str): ... + +class Module: + def __init__(self, *args, **kw) -> None: ... + def execute(self, *args, **kw) -> None: ... + def __call__(self, *args, **kw): ... + def __name__(self) -> None: ... + def dfs(self, parents, k, callback, callback_leave: Any | None = ...) -> None: ... + def parameters(self) -> List: ... + def state_dict(self, to: Any | None = ...): ... + def named_parameters(self) -> List[Tuple[str, Var]]: ... + def load_state_dict(self, params) -> None: ... + def modules(self) -> List: ... + def named_modules(self): ... + def requires_grad_(self, requires_grad: bool = ...): ... + def __hooked_call__(self, *args, **kw): ... + __fhook__: Any + def register_forward_hook(self, func) -> None: ... + def remove_forward_hook(self) -> None: ... + __fhook2__: Any + def register_pre_forward_hook(self, func) -> None: ... + def remove_pre_forward_hook(self) -> None: ... + __bihook__: Any + def register_input_backward_hook(self, func) -> None: ... + def remove_input_backward_hook(self) -> None: ... + __bohook__: Any + def register_output_backward_hook(self, func) -> None: ... + def remove_output_backward_hook(self) -> None: ... + def register_backward_hook(self, func): ... + def remove_backward_hook(self) -> None: ... + def children(self) -> List: ... + def extra_repr(self): ... + def apply(self, func) -> None: ... + def load_parameters(self, params) -> None: ... + def save(self, path: str): ... + def load(self, path: str): ... + backup_grad_state: Any + def eval(self) -> None: ... + def train(self) -> None: ... + is_train: bool + def is_training(self) -> bool: ... + def mpi_param_broadcast(self, root: int = ...) -> None: ... + def __setattr__(self, key, value) -> None: ... + def __getattr__(self, key): ... + def float64(self): ... + def float16(self): ... + def half(self): ... + def float_auto(self): ... + +class Function(Module): + input_mask: Any + output_mask: Any + def __call__(self, *args): ... + def dfs(self, parents, k, callback, callback_leave: Any | None = ...) -> None: ... + @classmethod + def apply(cls, *args, **kw): ... + +class GradHooker(Function): + hook: Any + def __init__(self, hook) -> None: ... + def execute(self, *args): ... + def grad(self, *grad_input): ... + +def grad_hooker(args, hook): ... +def register_hook(v, hook): ... +def make_module(func, exec_n_args: int = ...): ... +def dirty_fix_pytorch_runtime_error() -> None: ... + +class ExitHooks: + exit_code: Any + exception: Any + def __init__(self) -> None: ... + def hook(self) -> None: ... + def exit(self, code: int = ...) -> None: ... + def exc_handler(self, exc_type, exc, *args) -> None: ... + +hooks: Any + +def jittor_exit() -> None: ... +def vtos(v): ... +def size(v, dim: Any | None = ...): ... +def to_int(v): ... +def to_float(v): ... +def to_bool(v): ... +def format(v, spec): ... +def get_len(var): ... +half = float16 + +def is_var(v): ... +from typing import List, Tuple, Callable, overload +import numpy +def ternary(cond: Var, x: Var, y: Var)-> Var: + ... +@overload +def reindex(x: Var, shape: Tuple[int], indexes: List[str], overflow_value: float=0, overflow_conditions: List[str]={}, extras: List[Var]={})-> Var: + '''Document: + * + Reindex Operator is a one-to-many map operator. + It performs equivalent Python-pseudo implementation below:: + + # input is x, output is y + n = len(shape)-1 + m = len(x.shape)-1 + k = len(overflow_conditions)-1 + y = np.zeros(shape, x.dtype) + for i0 in range(shape[0]): # 1-st loop + for i1 in range(shape[1]): # 2-nd loop + ...... # many loops + for in in range(shape[n]) # n+1 -th loop + if is_overflow(i0,i1,...,in): + y[i0,i1,...,in] = overflow_value + else: + # indexes[i] is a c++ style integer expression consisting of i0,i1,...,in + y[i0,i1,...,in] = x[indexes[0],indexes[1],...,indexes[m]] + + # is_overflow is defined as following + def is_overflow(i0,i1,...,in): + return ( + indexes[0] < 0 || indexes[0] >= x.shape[0] || + indexes[1] < 0 || indexes[1] >= x.shape[1] || + ...... + indexes[m] < 0 || indexes[m] >= x.shape[m] || + + # overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in + overflow_conditions[0] || + overflow_conditions[1] || + ...... + overflow_conditions[k] + ) + ---------------- + * [in] x: A input jittor Var + + * [in] shape: the output shape, a integer array + + * [in] indexes: array of c++ style integer expression, its length should be the same with the number of dimension of x, some buildin variables it can use are:: + + XDIM, xshape0, ..., xshapen, xstride0, ..., xstriden + YDIM, yshape0, ..., yshapem, ystride0, ..., ystridem + i0, i1, ..., in + @e0(...), @e1(...) for extras input index + e0p, e1p , ... for extras input pointer + + * [in] overflow_value: overflow value + + * [in] overflow_conditions: array of c++ style boolean expression, it length can be vary. the buildin variables it can use are the same with indexes + + * [in] extras: extra var used for index + + ---------------- + Example + Convolution implemented by reindex operation:: + + def conv(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [ + 'i0', # Nid + 'i1+i3', # Hid+Khid + 'i2+i4', # Wid+KWid + 'i5', # Cid + ]) + ww = w.broadcast_var(xx) + yy = xx*ww + y = yy.sum([3,4,5]) # Kh, Kw, C + return y, yy''' + ... +@overload +def reindex(x: Var, indexes: List[Var], overflow_value: float=0, overflow_conditions: List[str]={})-> Var: + '''Document: + * Alias x.reindex([i,j,k]) -> + x.reindex(i.shape, ['@e0(...)','@e1(...)','@e2(...)',], extras=[i,j,k])''' + ... +def reindex_var(x: Var, indexes: List[Var], overflow_value: float=0, overflow_conditions: List[str]={})-> Var: + '''Document: + * Alias x.reindex([i,j,k]) -> + x.reindex(i.shape, ['@e0(...)','@e1(...)','@e2(...)',], extras=[i,j,k])''' + ... +@overload +def index(shape: Tuple[int], dim: int, dtype: str="int32")-> Var: + '''Document: + * + Index Operator generate index of shape. + + It performs equivalent Python-pseudo implementation below:: + + n = len(shape)-1 + x = np.zeros(shape, dtype) + for i0 in range(shape[0]): # 1-st loop + for i1 in range(shape[1]): # 2-nd loop + ...... # many loops + for in in range(shape[n]) # n+1 -th loop + x[i0,i1,...,in] = i@dim + + * [in] shape: the output shape, a integer array + * [in] dim: the dim of the index. + * [in] dtype: the data type string, default int32 + + Example:: + + print(jt.index([2,2], 0)()) + # output: [[0,0],[1,1]] + print(jt.index([2,2], 1)()) + # output: [[0,1],[0,1]]''' + ... +@overload +def index(shape: Tuple[int], dtype: str="int32")-> Tuple[Var]: + '''Document: + * + Index Operator generate index of shape. + + It performs equivalent Python-pseudo implementation below:: + + n = len(shape)-1 + x = np.zeros(shape, dtype) + for i0 in range(shape[0]): # 1-st loop + for i1 in range(shape[1]): # 2-nd loop + ...... # many loops + for in in range(shape[n]) # n+1 -th loop + x[i0,i1,...,in] = i@dim + + * [in] shape: the output shape, a integer array + * [in] dim: the dim of the index. + * [in] dtype: the data type string, default int32 + + Example:: + + print(jt.index([2,2], 0)()) + # output: [[0,0],[1,1]] + print(jt.index([2,2], 1)()) + # output: [[0,1],[0,1]]''' + ... +@overload +def index(a: Var, dim: int, dtype: str="int32")-> Var: + '''Document: + * shape dependency version of index op + jt.index_var(a, 1) similar with jt.index(a.shape, 1)''' + ... +@overload +def index(a: Var, dtype: str="int32")-> Tuple[Var]: + '''Document: + * shape dependency version of index op + jt.index_var(a) similar with jt.index(a.shape)''' + ... +@overload +def index_var(a: Var, dim: int, dtype: str="int32")-> Var: + '''Document: + * shape dependency version of index op + jt.index_var(a, 1) similar with jt.index(a.shape, 1)''' + ... +@overload +def index_var(a: Var, dtype: str="int32")-> Tuple[Var]: + '''Document: + * shape dependency version of index op + jt.index_var(a) similar with jt.index(a.shape)''' + ... +def binary(x: Var, y: Var, p: str)-> Var: + ... +def pow(x: Var, y: Var)-> Var: + '''Document: + * + Computes ``x^y``, element-wise. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def maximum(x: Var, y: Var)-> Var: + '''Document: + * + Returns the element-wise maximum of ``x`` and ``y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def minimum(x: Var, y: Var)-> Var: + '''Document: + * + Returns the element-wise minimum of ``x`` and ``y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def add(x: Var, y: Var)-> Var: + '''Document: + * + Element-wise adds ``x`` and ``y`` and returns a new Var. + + This operation is equivalent to ``x + y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def subtract(x: Var, y: Var)-> Var: + '''Document: + * + Element-wise subtract ``y`` from ``x`` and returns a new Var. + + This operation is equivalent to ``x - y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def multiply(x: Var, y: Var)-> Var: + '''Document: + * + Element-wise muliplies ``x`` with ``y`` and returns a new Var. + + This operation is equivalent to ``x * y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def divide(x: Var, y: Var)-> Var: + '''Document: + * + Element-wise divide ``x`` by ``y`` and returns a new Var. + + This operation is equivalent to ``x / y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.empty((3,), dtype=jt.int32) + >>> a + jt.Var([707406378 707406378 707406378], dtype=int32) + >>> b = jt.empty((3,), dtype=jt.int32) + >>> b + jt.Var([674510453 171649398 538976288], dtype=int32) + >>> jt.divide(a, b) + jt.Var([1.0487701 4.1212287 1.3125001], dtype=float32) + >>> a / b + jt.Var([1.0487701 4.1212287 1.3125001], dtype=float32) + + .. note :: + returns float value even if the dtype of input Vars are both integers. + @see jt.ops.floor_divide() for floor division.''' + ... +def floor_divide(x: Var, y: Var)-> Var: + '''Document: + * + Element-wise divide ``x`` by ``y`` and returns the floor of the result. + + This operation is equivalent to ``x // y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.randint(1, 10, (3,), dtype=jt.int32) + >>> a + jt.Var([9 2 7], dtype=int32) + >>> b = jt.randint(1, 10, (3,), dtype=jt.int32) + >>> b + jt.Var([6 4 6], dtype=int32) + >>> jt.floor_divide(a, b) + jt.Var([1 0 1], dtype=int32) + >>> a // b + jt.Var([1 0 1], dtype=int32)''' + ... +def mod(x: Var, y: Var)-> Var: + '''Document: + * + Returns the element-wise remainder of division. + + This operation is equivalent to ``x % y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.rand(3) + >>> a + jt.Var([0.3989529 0.20159635 0.22973768], dtype=float32) + >>> b = jt.rand(3) + >>> b + jt.Var([0.20121202 0.7704864 0.5654395 ], dtype=float32) + >>> jt.mod(a, b) + jt.Var([0.19774088 0.20159635 0.22973768], dtype=float32) + >>> a % b + jt.Var([0.19774088 0.20159635 0.22973768], dtype=float32)''' + ... +def less(x: Var, y: Var)-> Var: + '''Document: + * + Returns ``x < y`` element-wise. + + This operation is equivalent to ``x < y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def less_equal(x: Var, y: Var)-> Var: + '''Document: + * + Returns ``x <= y`` element-wise. + + This operation is equivalent to ``x <= y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def greater(x: Var, y: Var)-> Var: + '''Document: + * + Returns ``x > y`` element-wise. + + This operation is equivalent to ``x > y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def greater_equal(x: Var, y: Var)-> Var: + '''Document: + * + Returns ``x >= y`` element-wise. + + This operation is equivalent to ``x >= y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def equal(x: Var, y: Var)-> Var: + '''Document: + * + Returns ``x == y`` element-wise. + + This operation is equivalent to ``x == y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def not_equal(x: Var, y: Var)-> Var: + '''Document: + * + Returns ``x != y`` element-wise. + + This operation is equivalent to ``x != y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... +def left_shift(x: Var, y: Var)-> Var: + '''Document: + * + Shifts the bits of ``x`` to the left by ``y``. + + Bits are shifted to the left by appending ``y`` 0s at the right of ``x``. + This operation is equivalent to ``x << y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var (int32 or int64). + + * [in] y: the second input, a python number or jt.Var (int32 or int64). + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 10, shape=(3,)) + >>> a + jt.Var([7 6 7], dtype=int32) + >>> b = jt.randint(0, 10, shape=(3,)) + >>> b + jt.Var([3 9 8], dtype=int32) + >>> jt.left_shift(a, b) + jt.Var([ 56 3072 1792], dtype=int32) + >>> a << b + jt.Var([ 56 3072 1792], dtype=int32)''' + ... +def right_shift(x: Var, y: Var)-> Var: + '''Document: + * + Shifts the bits of ``x`` to the right by ``y``. + + This operation is equivalent to ``x >> y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var (int32 or int64). + + * [in] y: the second input, a python number or jt.Var (int32 or int64). + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 1024, shape=(3,)) + >>> a + jt.Var([439 113 92], dtype=int32) + >>> b = jt.randint(0, 10, shape=(3,)) + >>> b + jt.Var([6 8 4], dtype=int32) + >>> jt.right_shift(a, b) + jt.Var([6 0 5], dtype=int32)''' + ... +def logical_and(x: Var, y: Var)-> Var: + '''Document: + * + Returns the element-wise logical AND of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var.''' + ... +def logical_or(x: Var, y: Var)-> Var: + '''Document: + * + Returns the element-wise logical OR of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var.''' + ... +def logical_xor(x: Var, y: Var)-> Var: + '''Document: + * + Returns the element-wise logical XOR of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var.''' + ... +def bitwise_and(x: Var, y: Var)-> Var: + '''Document: + * + Computes the bitwise AND of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean).''' + ... +def bitwise_or(x: Var, y: Var)-> Var: + '''Document: + * + Computes the bitwise OR of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean).''' + ... +def bitwise_xor(x: Var, y: Var)-> Var: + '''Document: + * + Computes the bitwise XOR of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean).''' + ... +def tape(x: Var)-> Var: + ... +@overload +def where(cond: Var, dtype: str="int32")-> Tuple[Var]: + '''Document: + * + Where Operator generate index of true condition. + + * [in] cond: condition for index generation + + * [in] dtype: type of return indexes + + * [out] out: return an array of indexes, same length with number of dims of cond + + Example:: + + jt.where([[0,0,1],[1,0,0]]) + # return [jt.Var([0 1], dtype=int32), jt.Var([2 0], dtype=int32)]''' + ... +@overload +def where(cond: Var, x: Var, y: Var)-> Var: + '''Document: + * + * Condition operator, perform cond ? x : y + *''' + ... +def argsort(x: Var, dim: int=-1, descending: bool=False, dtype: str="int32")-> Tuple[Var]: + '''Document: + * + Argsort Operator Perform an indirect sort by given key or compare function. + + x is input, y is output index, satisfy: + + x[y[0]] <= x[y[1]] <= x[y[2]] <= ... <= x[y[n]] + + or + + key(y[0]) <= key(y[1]) <= key(y[2]) <= ... <= key(y[n]) + + or + + compare(y[0], y[1]) && compare(y[1], y[2]) && ... + + * [in] x: input var for sort + + * [in] dim: sort alone which dim + + * [in] descending: the elements are sorted in descending order or not(default False). + + * [in] dtype: type of return indexes + + * [out] index: index have the same size with sorted dim + + * [out] value: sorted value + + + Example:: + + index, value = jt.argsort([11,13,12]) + # return [0 2 1], [11 12 13] + index, value = jt.argsort([11,13,12], descending=True) + # return [1 2 0], [13 12 11] + index, value = jt.argsort([[11,13,12], [12,11,13]]) + # return [[0 2 1],[1 0 2]], [[11 12 13],[11 12 13]] + index, value = jt.argsort([[11,13,12], [12,11,13]], dim=0) + # return [[0 1 0],[1 0 1]], [[11 11 12],[12 13 13]]''' + ... +def fetch(inputs: List[Var], func: Callable)-> Var: + ... +def arg_reduce(x: Var, op: str, dim: int, keepdims: bool)-> Tuple[Var]: + '''Document: + * + Returns the indices of the maximum / minimum of the input across a dimension. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] op: "max" or "min". + + * [in] dim: int. Specifies which dimension to be reduced. + + * [in] keepdims: bool. Whether the output has ``dim`` retained or not. + + ---------------- + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 3)) + >>> x + jt.Var([[4 2 5] + [6 7 1]], dtype=int32) + >>> jt.arg_reduce(x, 'max', dim=1, keepdims=False) + [jt.Var([2 1], dtype=int32), jt.Var([5 7], dtype=int32)] + >>> jt.arg_reduce(x, 'min', dim=1, keepdims=False) + [jt.Var([1 2], dtype=int32), jt.Var([2 1], dtype=int32)]''' + ... +def random(shape: Tuple[int], dtype: str="float32", type: str="uniform")-> Var: + ... +@overload +def reduce(x: Var, op: str, dim: int, keepdims: bool=False)-> Var: + ... +@overload +def reduce(x: Var, op: str, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + ... +@overload +def max(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... +@overload +def max(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... +@overload +def max(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... +@overload +def reduce_maximum(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... +@overload +def reduce_maximum(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... +@overload +def reduce_maximum(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... +@overload +def min(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... +@overload +def min(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... +@overload +def min(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... +@overload +def reduce_minimum(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... +@overload +def reduce_minimum(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... +@overload +def reduce_minimum(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... +@overload +def sum(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... +@overload +def sum(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... +@overload +def sum(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... +@overload +def reduce_add(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... +@overload +def reduce_add(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... +@overload +def reduce_add(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... +@overload +def prod(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def prod(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def prod(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def product(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def product(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def product(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def reduce_multiply(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def reduce_multiply(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def reduce_multiply(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... +@overload +def reduce_logical_and(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def reduce_logical_and(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def reduce_logical_and(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def all_(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def all_(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def all_(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def reduce_logical_or(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def reduce_logical_or(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def reduce_logical_or(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def any_(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def any_(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def any_(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... +@overload +def reduce_logical_xor(x: Var, dim: int, keepdims: bool=False)-> Var: + ... +@overload +def reduce_logical_xor(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + ... +@overload +def reduce_logical_xor(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + ... +@overload +def reduce_bitwise_and(x: Var, dim: int, keepdims: bool=False)-> Var: + ... +@overload +def reduce_bitwise_and(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + ... +@overload +def reduce_bitwise_and(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + ... +@overload +def reduce_bitwise_or(x: Var, dim: int, keepdims: bool=False)-> Var: + ... +@overload +def reduce_bitwise_or(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + ... +@overload +def reduce_bitwise_or(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + ... +@overload +def reduce_bitwise_xor(x: Var, dim: int, keepdims: bool=False)-> Var: + ... +@overload +def reduce_bitwise_xor(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + ... +@overload +def reduce_bitwise_xor(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + ... +@overload +def mean(x: Var, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the mean value of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[9 4 4] + [1 9 6]], dtype=int32) + >>> jt.mean(x) + jt.Var([5.5000005], dtype=float32) + >>> x.mean() + jt.Var([5.5000005], dtype=float32) + >>> x.mean(dim=1) + jt.Var([5.666667 5.3333335], dtype=float32) + >>> x.mean(dim=1, keepdims=True) + jt.Var([[5.666667 ] + [5.3333335]], dtype=float32)''' + ... +@overload +def mean(x: Var, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the mean value of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[9 4 4] + [1 9 6]], dtype=int32) + >>> jt.mean(x) + jt.Var([5.5000005], dtype=float32) + >>> x.mean() + jt.Var([5.5000005], dtype=float32) + >>> x.mean(dim=1) + jt.Var([5.666667 5.3333335], dtype=float32) + >>> x.mean(dim=1, keepdims=True) + jt.Var([[5.666667 ] + [5.3333335]], dtype=float32)''' + ... +@overload +def mean(x: Var, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the mean value of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[9 4 4] + [1 9 6]], dtype=int32) + >>> jt.mean(x) + jt.Var([5.5000005], dtype=float32) + >>> x.mean() + jt.Var([5.5000005], dtype=float32) + >>> x.mean(dim=1) + jt.Var([5.666667 5.3333335], dtype=float32) + >>> x.mean(dim=1, keepdims=True) + jt.Var([[5.666667 ] + [5.3333335]], dtype=float32)''' + ... +def clone(x: Var)-> Var: + ... +def unary(x: Var, op: str)-> Var: + ... +def cast(x: Var, op: str)-> Var: + ... +def int8(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int8. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int8() + jt.Var([4 2 8], dtype=int8) + >>> jt.int8(x) + jt.Var([4 2 8], dtype=int8)''' + ... +def int16(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int16. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int16() + jt.Var([4 2 8], dtype=int16) + >>> jt.int16(x) + jt.Var([4 2 8], dtype=int16)''' + ... +def int32(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int() + jt.Var([4 2 8], dtype=int32) + >>> jt.int(x) + jt.Var([4 2 8], dtype=int32) + >>> x.int32() + jt.Var([4 2 8], dtype=int32) + >>> jt.int32(x) + jt.Var([4 2 8], dtype=int32) + >>> x.long() + jt.Var([4 2 8], dtype=int32) + >>> jt.long(x) + jt.Var([4 2 8], dtype=int32)''' + ... +def int64(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int64. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int64() + jt.Var([4 2 8], dtype=int64) + >>> jt.int64(x) + jt.Var([4 2 8], dtype=int64)''' + ... +def uint8(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int8. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint8() + jt.Var([4 2 8], dtype=uint8) + >>> jt.uint8(x) + jt.Var([4 2 8], dtype=uint8)''' + ... +def uint16(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int16. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint16() + jt.Var([4 2 8], dtype=uint16) + >>> jt.uint16(x) + jt.Var([4 2 8], dtype=uint16)''' + ... +def uint32(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint32() + jt.Var([4 2 8], dtype=uint32) + >>> jt.uint32(x) + jt.Var([4 2 8], dtype=uint32)''' + ... +def uint64(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int64. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint64() + jt.Var([4 2 8], dtype=uint64) + >>> jt.uint64(x) + jt.Var([4 2 8], dtype=uint64)''' + ... +def float16(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to float16 (half-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.half() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.half(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> x.float16() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.float16(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16)''' + ... +def float32(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to float32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.arange(3) + >>> x + jt.Var([0 1 2], dtype=int32) + >>> x.float() + jt.Var([0. 1. 2.], dtype=float32) + >>> jt.float(x) + jt.Var([0. 1. 2.], dtype=float32) + >>> x.float32() + jt.Var([0. 1. 2.], dtype=float32) + >>> jt.float32(x) + jt.Var([0. 1. 2.], dtype=float32)''' + ... +def float64(x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to float64 (double-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.arange(3) + >>> x + jt.Var([0 1 2], dtype=int32) + >>> x.double() + jt.Var([0. 1. 2.], dtype=float64) + >>> jt.double(x) + jt.Var([0. 1. 2.], dtype=float64) + >>> x.float64() + jt.Var([0. 1. 2.], dtype=float64) + >>> jt.float64(x) + jt.Var([0. 1. 2.], dtype=float64)''' + ... +def abs(x: Var)-> Var: + '''Document: + * + Returns the absolute value of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> jt.abs(jt.float32([-1, 0, 1])) + jt.Var([1. 0. 1.], dtype=float32)''' + ... +def negative(x: Var)-> Var: + '''Document: + * + Returns the negative value of the input ``x``. + + This operator is equavilant to ``-x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> jt.negative(jt.float32([-1, 0, 1])) + jt.Var([ 1. -0. -1.], dtype=float32)''' + ... +def logical_not(x: Var)-> Var: + '''Document: + * + Returns the logical NOT of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var, integal or boolean. + + ---------------- + + Example-1:: + >>> jt.logical_not(jt.int32([-1, 0, 1])) + jt.Var([False True False], dtype=bool)''' + ... +def bitwise_not(x: Var)-> Var: + '''Document: + * + Returns the bitwise NOT of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var, integal or boolean. + + ---------------- + + Example-1:: + >>> jt.bitwise_not(jt.int32([1, 2, -3])) + jt.Var([-2 -3 2], dtype=int32)''' + ... +def log(x: Var)-> Var: + '''Document: + * + Returns the natural logarithm of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([0.02863695 1.30122 1.6048753 1.140261 ], dtype=float32) + >>> jt.log(x) + jt.Var([-3.5530574 0.26330233 0.47304606 0.13125724], dtype=float32) + >>> x.log() + jt.Var([-3.5530574 0.26330233 0.47304606 0.13125724], dtype=float32)''' + ... +def exp(x: Var)-> Var: + '''Document: + * + Returns the exponential of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([1.9841381 1.4103996 0.5855549 1.4212812], dtype=float32) + >>> jt.exp(x) + jt.Var([7.2727766 4.0975924 1.7959872 4.1424246], dtype=float32) + >>> x.exp() + jt.Var([7.2727766 4.0975924 1.7959872 4.1424246], dtype=float32)''' + ... +def sqrt(x: Var)-> Var: + '''Document: + * + Returns the square root of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([0.81957287 0.5609612 0.07435933 1.7571875 ], dtype=float32) + >>> jt.sqrt(x) + jt.Var([0.90530264 0.7489734 0.27268907 1.3255895 ], dtype=float32) + >>> x.sqrt() + jt.Var([0.90530264 0.7489734 0.27268907 1.3255895 ], dtype=float32)''' + ... +def round(x: Var)-> Var: + '''Document: + * + Returns the closest integer of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32) + >>> jt.round(x) + jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32) + >>> x.round() + jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32)''' + ... +def floor(x: Var)-> Var: + '''Document: + * + Returns the largest integer less than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.floor(x) + jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32) + >>> x.floor + jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32)''' + ... +def ceil(x: Var)-> Var: + '''Document: + * + Returns the smallest integer greater than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.ceil(x) + jt.Var([-1.0 0.0 0.0 0.0], dtype=float32) + >>> x.ceil() + jt.Var([-1.0 0.0 0.0 0.0], dtype=float32)''' + ... +def round_int(x: Var)-> Var: + '''Document: + * + Returns the closest integer of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32) + >>> jt.round_int(x) + jt.Var([ 2 0 0 -1], dtype=int32) + >>> x.round_int + jt.Var([ 2 0 0 -1], dtype=int32)''' + ... +def floor_int(x: Var)-> Var: + '''Document: + * + Returns the largest integer less than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.floor_int(x) + jt.Var([-2 -1 -1 -1], dtype=int32) + >>> x.floor_int + jt.Var([-2 -1 -1 -1], dtype=int32)''' + ... +def ceil_int(x: Var)-> Var: + '''Document: + * + Returns the smallest integer greater than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.ceil_int(x) + jt.Var([-1 0 0 0], dtype=int32) + >>> x.ceil_int() + jt.Var([-1 0 0 0], dtype=int32)''' + ... +def sin(x: Var)-> Var: + '''Document: + * + Returns the sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.sin(x) + jt.Var([ 0.32303742 -0.6527857 -0.76586854 0.9738172 ], dtype=float32) + >>> x.sin() + jt.Var([ 0.32303742 -0.6527857 -0.76586854 0.9738172 ], dtype=float32)''' + ... +def asin(x: Var)-> Var: + '''Document: + * + Returns the arcsine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.09342023 -0.42522037 0.9264933 -0.785264 ], dtype=float32) + >>> jt.asin(x) + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32) + >>> x.asin() + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32)''' + ... +def arcsin(x: Var)-> Var: + '''Document: + * + Returns the arcsine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.09342023 -0.42522037 0.9264933 -0.785264 ], dtype=float32) + >>> jt.asin(x) + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32) + >>> x.asin() + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32)''' + ... +def sinh(x: Var)-> Var: + '''Document: + * + Returns the hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.sinh(x) + jt.Var([ 0.3349012 -0.77276015 -0.9873369 2.9425898 ], dtype=float32) + >>> x.sinh + jt.Var([ 0.3349012 -0.77276015 -0.9873369 2.9425898 ], dtype=float32)''' + ... +def asinh(x: Var)-> Var: + '''Document: + * + Returns the inverse hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.9749726 -0.52341473 0.8906148 1.0338128 ], dtype=float32) + >>> jt.asinh(x) + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32) + >>> x.asinh() + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32)''' + ... +def arcsinh(x: Var)-> Var: + '''Document: + * + Returns the inverse hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.9749726 -0.52341473 0.8906148 1.0338128 ], dtype=float32) + >>> jt.asinh(x) + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32) + >>> x.asinh() + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32)''' + ... +def tan(x: Var)-> Var: + '''Document: + * + Returns the tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.tan(x) + jt.Var([ 0.34133783 -0.8617148 -1.1910915 -4.283673 ], dtype=float32) + >>> x.tan() + jt.Var([ 0.34133783 -0.8617148 -1.1910915 -4.283673 ], dtype=float32)''' + ... +def atan(x: Var)-> Var: + '''Document: + * + Returns the inverse tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.atan(x) + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32) + >>> x.atan() + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32)''' + ... +def arctan(x: Var)-> Var: + '''Document: + * + Returns the inverse tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.atan(x) + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32) + >>> x.atan() + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32)''' + ... +def tanh(x: Var)-> Var: + '''Document: + * + Returns the hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.tanh(x) + jt.Var([-0.6956678 0.82989657 0.4402144 0.7439787 ], dtype=float32) + >>> x.tanh() + jt.Var([-0.6956678 0.82989657 0.4402144 0.7439787 ], dtype=float32)''' + ... +def atanh(x: Var)-> Var: + '''Document: + * + Returns the inverse hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.9062414 -0.799802 -0.27219176 -0.7274077 ], dtype=float32) + >>> jt.atanh(x) + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32) + >>> x.atanh() + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32)''' + ... +def arctanh(x: Var)-> Var: + '''Document: + * + Returns the inverse hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.9062414 -0.799802 -0.27219176 -0.7274077 ], dtype=float32) + >>> jt.atanh(x) + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32) + >>> x.atanh() + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32)''' + ... +def cos(x: Var)-> Var: + '''Document: + * + Returns the cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.cos(x) + jt.Var([ 0.9463862 0.7575426 0.6429972 -0.2273323], dtype=float32) + >>> x.cos() + jt.Var([ 0.9463862 0.7575426 0.6429972 -0.2273323], dtype=float32)''' + ... +def acos(x: Var)-> Var: + '''Document: + * + Returns the inverse cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.5876564 0.740723 -0.667666 0.5371753], dtype=float32) + >>> jt.acos(x) + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32) + >>> x.acos() + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32)''' + ... +def arccos(x: Var)-> Var: + '''Document: + * + Returns the inverse cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.5876564 0.740723 -0.667666 0.5371753], dtype=float32) + >>> jt.acos(x) + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32) + >>> x.acos() + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32)''' + ... +def cosh(x: Var)-> Var: + '''Document: + * + Returns the hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.cosh(x) + jt.Var([1.0545894 1.2637873 1.405288 3.1078668], dtype=float32) + >>> x.cosh() + jt.Var([1.0545894 1.2637873 1.405288 3.1078668], dtype=float32)''' + ... +def acosh(x: Var)-> Var: + '''Document: + * + Returns the inverse hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) + 1 + >>> x + jt.Var([1.3609099 1.8137748 1.1146184 1.3911307], dtype=float32) + >>> jt.acosh(x) + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32) + >>> x.acosh() + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32)''' + ... +def arccosh(x: Var)-> Var: + '''Document: + * + Returns the inverse hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) + 1 + >>> x + jt.Var([1.3609099 1.8137748 1.1146184 1.3911307], dtype=float32) + >>> jt.acosh(x) + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32) + >>> x.acosh() + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32)''' + ... +def sigmoid(x: Var)-> Var: + '''Document: + * + Returns the sigmoid of the input ``x``. + + .. math:: + out_i = \frac{1}{1 + e^{x_i}} + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.49443012 0.4305426 -1.0364404 -1.2628382 ], dtype=float32) + >>> jt.sigmoid(x) + jt.Var([0.62114954 0.6060032 0.2618374 0.2204857 ], dtype=float32) + >>> x.sigmoid() + jt.Var([0.62114954 0.6060032 0.2618374 0.2204857 ], dtype=float32)''' + ... +def erf(x: Var)-> Var: + '''Document: + * + Computes the error function of each element. The error function is defined as follows: + + .. math:: + erf(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.49443012 0.4305426 -1.0364404 -1.2628382 ], dtype=float32) + >>> jt.erf(x) + jt.Var([ 0.51559156 0.45739546 -0.85728306 -0.9258883 ], dtype=float32) + >>> x.erf() + jt.Var([ 0.51559156 0.45739546 -0.85728306 -0.9258883 ], dtype=float32)''' + ... +def erfinv(x: Var)-> Var: + '''Document: + * + Computes the inverse error function of each element. + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.00277209 -0.26642472 0.7869792 0.5415418 ], dtype=float32) + >>> jt.erfinv(x) + jt.Var([ 0.00245671 -0.24068035 0.8805613 0.5242405 ], dtype=float32) + >>> x.erfinv() + jt.Var([ 0.00245671 -0.24068035 0.8805613 0.5242405 ], dtype=float32)''' + ... +def transpose(x: Var, axes: Tuple[int]=())-> Var: + ... +def fuse_transpose(x: Var, axes: Tuple[int]=())-> Var: + ... +def safe_clip(x: Var, left: float, right: float)-> Var: + '''Document: + * Safe clip value to a range, and keep + the gradient pass thought. + + * [in] x: input value + * [in] left: float64 clip min value. + * [in] right: float64 clip max value.''' + ... +def array_(args: numpy.ndarray)-> Var: + ... +def array(obj: float | int | numpy.ndarray | Var)-> Var: + ... +@overload +def getitem(x: Var, slices: slice)-> Var: + ... +@overload +def getitem(x: Var, slices: slice, _: int)-> Tuple[Var]: + ... +def candidate(x: Var, fail_cond: str, dtype: str="int32")-> Var: + '''Document: + * + Candidate Operator Perform an indirect candidate filter by given a fail condition. + + x is input, y is output index, satisfy:: + + not fail_cond(y[0], y[1]) and + not fail_cond(y[0], y[2]) and not fail_cond(y[1], y[2]) and + ... + ... and not fail_cond(y[m-2], y[m-1]) + + Where m is number of selected candidates. + + Pseudo code:: + + y = [] + for i in range(n): + pass = True + for j in y: + if (@fail_cond): + pass = false + break + if (pass): + y.append(i) + return y + + * [in] x: input var for filter + + * [in] fail_cond: code for fail condition + + * [in] dtype: type of return indexes + + * [out] index: . + + Example:: + + jt.candidate(jt.random(100,2), '(@x(j,0)>@x(i,0))or(@x(j,1)>@x(i,1))') + # return y satisfy: + # x[y[0], 0] <= x[y[1], 0] and x[y[1], 0] <= x[y[2], 0] and ... and x[y[m-2], 0] <= x[y[m-1], 0] and + # x[y[0], 1] <= x[y[1], 1] and x[y[1], 1] <= x[y[2], 1] and ... and x[y[m-2], 1] <= x[y[m-1], 1]''' + ... +@overload +def numpy_code(shape: Tuple[int], dtype: str, inputs: List[Var], forward: Callable, backward: List[Callable])-> Var: + '''Document: + * + Numpy Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] forward: function, represents forward python function + + * [in] backward: A list of function, represents gradiant for each input + + ---------------- + + Example-1:: + + def forward_code(np, data): + a = data["inputs"][0] + b = data["outputs"][0] + np.add(a,a,out=b) + + def backward_code(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + + a = jt.random((5,1)) + b = jt.numpy_code( + a.shape, + a.dtype, + [a], + forward_code, + [backward_code], + ) + + Example-2:: + + def forward_code(np, data): + a,b = data["inputs"] + c,d = data["outputs"] + np.add(a,b,out=c) + np.subtract(a,b,out=d) + + def backward_code1(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout) + + def backward_code2(np, data): + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + if out_index==0: + np.copyto(out, dout) + else: + np.negative(dout, out) + + a = jt.random((5,1)) + b = jt.random((5,1)) + c, d = jt.numpy_code( + [a.shape, a.shape], + [a.dtype, a.dtype], + [a, b], + forward_code, + [backward_code1,backward_code2], + )''' + ... +@overload +def numpy_code(shapes: List[Tuple[int]], dtypes: List[str], inputs: List[Var], forward: Callable, backward: List[Callable])-> Tuple[Var]: + '''Document: + * + Numpy Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] forward: function, represents forward python function + + * [in] backward: A list of function, represents gradiant for each input + + ---------------- + + Example-1:: + + def forward_code(np, data): + a = data["inputs"][0] + b = data["outputs"][0] + np.add(a,a,out=b) + + def backward_code(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + + a = jt.random((5,1)) + b = jt.numpy_code( + a.shape, + a.dtype, + [a], + forward_code, + [backward_code], + ) + + Example-2:: + + def forward_code(np, data): + a,b = data["inputs"] + c,d = data["outputs"] + np.add(a,b,out=c) + np.subtract(a,b,out=d) + + def backward_code1(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout) + + def backward_code2(np, data): + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + if out_index==0: + np.copyto(out, dout) + else: + np.negative(dout, out) + + a = jt.random((5,1)) + b = jt.random((5,1)) + c, d = jt.numpy_code( + [a.shape, a.shape], + [a.dtype, a.dtype], + [a, b], + forward_code, + [backward_code1,backward_code2], + )''' + ... +@overload +def numpy_code(shape: Tuple[int], dtype: str, inputs: List[Var], forward: Callable)-> Var: + '''Document: + * + Numpy Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] forward: function, represents forward python function + + * [in] backward: A list of function, represents gradiant for each input + + ---------------- + + Example-1:: + + def forward_code(np, data): + a = data["inputs"][0] + b = data["outputs"][0] + np.add(a,a,out=b) + + def backward_code(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + + a = jt.random((5,1)) + b = jt.numpy_code( + a.shape, + a.dtype, + [a], + forward_code, + [backward_code], + ) + + Example-2:: + + def forward_code(np, data): + a,b = data["inputs"] + c,d = data["outputs"] + np.add(a,b,out=c) + np.subtract(a,b,out=d) + + def backward_code1(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout) + + def backward_code2(np, data): + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + if out_index==0: + np.copyto(out, dout) + else: + np.negative(dout, out) + + a = jt.random((5,1)) + b = jt.random((5,1)) + c, d = jt.numpy_code( + [a.shape, a.shape], + [a.dtype, a.dtype], + [a, b], + forward_code, + [backward_code1,backward_code2], + )''' + ... +@overload +def numpy_code(shapes: List[Tuple[int]], dtypes: List[str], inputs: List[Var], forward: Callable)-> Tuple[Var]: + '''Document: + * + Numpy Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] forward: function, represents forward python function + + * [in] backward: A list of function, represents gradiant for each input + + ---------------- + + Example-1:: + + def forward_code(np, data): + a = data["inputs"][0] + b = data["outputs"][0] + np.add(a,a,out=b) + + def backward_code(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + + a = jt.random((5,1)) + b = jt.numpy_code( + a.shape, + a.dtype, + [a], + forward_code, + [backward_code], + ) + + Example-2:: + + def forward_code(np, data): + a,b = data["inputs"] + c,d = data["outputs"] + np.add(a,b,out=c) + np.subtract(a,b,out=d) + + def backward_code1(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout) + + def backward_code2(np, data): + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + if out_index==0: + np.copyto(out, dout) + else: + np.negative(dout, out) + + a = jt.random((5,1)) + b = jt.random((5,1)) + c, d = jt.numpy_code( + [a.shape, a.shape], + [a.dtype, a.dtype], + [a, b], + forward_code, + [backward_code1,backward_code2], + )''' + ... +@overload +def code(shape: Tuple[int], dtype: str, inputs: List[Var]={}, cpu_src: str="", cpu_grad_src: List[str]={}, cpu_header: str="", cuda_src: str="", cuda_grad_src: List[str]={}, cuda_header: str="")-> Var: + '''Document: + * + Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] cpu_src: cpu source code string, buildin value: + + * in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...) + * out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...) + * out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...) + + * [in] cpu_header: cpu header code string. + + * [in] cuda_src: cuda source code string. + + * [in] cuda_header: cuda header code string. + + ---------------- + + Example-1:: + + from jittor import Function + import jittor as jt + + class Func(Function): + def execute(self, x): + self.save_vars = x + return jt.code(x.shape, x.dtype, [x], + cpu_src=""" + for (int i=0; i + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i + using namespace std; + """, + cpu_src=""" + @alias(a, in0) + @alias(b, out0) + @alias(c, out1) + @b(0) = @c(0) = @a(0); + for (int i=0; i0) + @b(num_b++) = @a(i); + else + @c(num_c++) = @a(i); + } + b->set_shape({num_b}); + c->set_shape({num_c}); + """ + ) + assert (b.data == [5,3,1]).all() + assert (c.data == [-4,-2]).all() + + Example-5:: + + # This example shows how to customize code op + # compilation flags, such as add include search + # path, add definitions, or any command line options + + a = jt.random([10]) + b = jt.code(a.shape, a.dtype, [a], + cpu_src=""" + @out0(0) = HAHAHA; + """) + # HAHAHA is defined in flags below + # /any/include/path can be change to any path you want to include + b.compile_options = {"FLAGS: -DHAHAHA=233 -I/any/include/path ": 1} + print(b[0]) + # will output 233 + + + CUDA Example-1:: + + #This example shows how to use CUDA in code op. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + a = jt.random([100000]) + b = jt.random([100000]) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b])) + + CUDA Example-2:: + + #This example shows how to use multi dimension data with CUDA. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + a = jt.random((100,100)) + b = jt.random((100,100)) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b]))''' + ... +@overload +def code(shapes: List[Tuple[int]], dtypes: List[str], inputs: List[Var]={}, cpu_src: str="", cpu_grad_src: List[str]={}, cpu_header: str="", cuda_src: str="", cuda_grad_src: List[str]={}, cuda_header: str="")-> Tuple[Var]: + '''Document: + * + Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] cpu_src: cpu source code string, buildin value: + + * in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...) + * out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...) + * out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...) + + * [in] cpu_header: cpu header code string. + + * [in] cuda_src: cuda source code string. + + * [in] cuda_header: cuda header code string. + + ---------------- + + Example-1:: + + from jittor import Function + import jittor as jt + + class Func(Function): + def execute(self, x): + self.save_vars = x + return jt.code(x.shape, x.dtype, [x], + cpu_src=""" + for (int i=0; i + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i + using namespace std; + """, + cpu_src=""" + @alias(a, in0) + @alias(b, out0) + @alias(c, out1) + @b(0) = @c(0) = @a(0); + for (int i=0; i0) + @b(num_b++) = @a(i); + else + @c(num_c++) = @a(i); + } + b->set_shape({num_b}); + c->set_shape({num_c}); + """ + ) + assert (b.data == [5,3,1]).all() + assert (c.data == [-4,-2]).all() + + Example-5:: + + # This example shows how to customize code op + # compilation flags, such as add include search + # path, add definitions, or any command line options + + a = jt.random([10]) + b = jt.code(a.shape, a.dtype, [a], + cpu_src=""" + @out0(0) = HAHAHA; + """) + # HAHAHA is defined in flags below + # /any/include/path can be change to any path you want to include + b.compile_options = {"FLAGS: -DHAHAHA=233 -I/any/include/path ": 1} + print(b[0]) + # will output 233 + + + CUDA Example-1:: + + #This example shows how to use CUDA in code op. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + a = jt.random([100000]) + b = jt.random([100000]) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b])) + + CUDA Example-2:: + + #This example shows how to use multi dimension data with CUDA. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + a = jt.random((100,100)) + b = jt.random((100,100)) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b]))''' + ... +@overload +def code(inputs: List[Var], outputs: List[Var], cpu_src: str="", cpu_grad_src: List[str]={}, cpu_header: str="", cuda_src: str="", cuda_grad_src: List[str]={}, cuda_header: str="")-> Tuple[Var]: + '''Document: + * + Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] cpu_src: cpu source code string, buildin value: + + * in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...) + * out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...) + * out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...) + + * [in] cpu_header: cpu header code string. + + * [in] cuda_src: cuda source code string. + + * [in] cuda_header: cuda header code string. + + ---------------- + + Example-1:: + + from jittor import Function + import jittor as jt + + class Func(Function): + def execute(self, x): + self.save_vars = x + return jt.code(x.shape, x.dtype, [x], + cpu_src=""" + for (int i=0; i + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i + using namespace std; + """, + cpu_src=""" + @alias(a, in0) + @alias(b, out0) + @alias(c, out1) + @b(0) = @c(0) = @a(0); + for (int i=0; i0) + @b(num_b++) = @a(i); + else + @c(num_c++) = @a(i); + } + b->set_shape({num_b}); + c->set_shape({num_c}); + """ + ) + assert (b.data == [5,3,1]).all() + assert (c.data == [-4,-2]).all() + + Example-5:: + + # This example shows how to customize code op + # compilation flags, such as add include search + # path, add definitions, or any command line options + + a = jt.random([10]) + b = jt.code(a.shape, a.dtype, [a], + cpu_src=""" + @out0(0) = HAHAHA; + """) + # HAHAHA is defined in flags below + # /any/include/path can be change to any path you want to include + b.compile_options = {"FLAGS: -DHAHAHA=233 -I/any/include/path ": 1} + print(b[0]) + # will output 233 + + + CUDA Example-1:: + + #This example shows how to use CUDA in code op. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + a = jt.random([100000]) + b = jt.random([100000]) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b])) + + CUDA Example-2:: + + #This example shows how to use multi dimension data with CUDA. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + a = jt.random((100,100)) + b = jt.random((100,100)) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b]))''' + ... +def copy(x: Var)-> Var: + ... +def setitem(x: Var, slices: slice, y: Var, op: str="void")-> Var: + ... +@overload +def broadcast(x: Var, shape: Tuple[int], dims: Tuple[int]=())-> Var: + '''Document: + * + Broadcast ``x`` to a given shape. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] shape: the output shape. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> jt.broadcast(x, shape=(2, 3, 2), dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32)''' + ... +@overload +def broadcast(x: Var, y: Var, dims: Tuple[int]=())-> Var: + '''Document: + * + Broadcast ``x`` to the same shape as ``y``. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] y: the reference jt.Var. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + .. note:: + jt.broadcast_var(x, y, dims) is an alias of jt.broadcast(x, y, dims) + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> y = jt.randint(0, 10, shape=(2, 3, 2)) + >>> jt.broadcast(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + >>> jt.broadcast_var(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32)''' + ... +def broadcast_var(x: Var, y: Var, dims: Tuple[int]=())-> Var: + '''Document: + * + Broadcast ``x`` to the same shape as ``y``. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] y: the reference jt.Var. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + .. note:: + jt.broadcast_var(x, y, dims) is an alias of jt.broadcast(x, y, dims) + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> y = jt.randint(0, 10, shape=(2, 3, 2)) + >>> jt.broadcast(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + >>> jt.broadcast_var(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32)''' + ... +def reshape(x: Var, shape: Tuple[int])-> Var: + '''Document: + * + Returns a tensor with the same data and number of elements as input, but with the specified shape. + + A single dimension may be -1, in which case it's inferred from the remaining dimensions and the number of elements in input. + + ---------------- + + * [in] x: the input jt.Var + + * [in] shape: the output shape, an integer array + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 10, shape=(12,)) + >>> a + jt.Var([4 0 8 4 6 3 1 8 1 1 2 2], dtype=int32) + >>> jt.reshape(a, (3, 4)) + jt.Var([[4 0 8 4] + [6 3 1 8] + [1 1 2 2]], dtype=int32) + >>> jt.reshape(a, (-1, 6)) + jt.Var([[4 0 8 4 6 3] + [1 8 1 1 2 2]], dtype=int32)''' + ... +def empty(shape: Tuple[int], dtype: str="float32")-> Var: + ... +def reindex_reduce(y: Var, op: str, shape: Tuple[int], indexes: List[str], overflow_conditions: List[str]={}, extras: List[Var]={})-> Var: + '''Document: + * + Reindex Reduce Operator is a many-to-one map operator. + It performs equivalent Python-pseudo implementation below:: + + # input is y, output is x + n = len(y.shape)-1 + m = len(shape)-1 + k = len(overflow_conditions)-1 + x = np.zeros(shape, y.dtype) + x[:] = initial_value(op) + for i0 in range(y.shape[0]): # 1-st loop + for i1 in range(y.shape[1]): # 2-nd loop + ...... # many loops + for in in range(y.shape[n]) # n+1 -th loop + # indexes[i] is a c++ style integer expression consisting of i0,i1,...,in + xi0,xi1,...,xim = indexes[0],indexes[1],...,indexes[m] + if not is_overflow(xi0,xi1,...,xim): + x[xi0,xi1,...,xim] = op(x[xi0,xi1,...,xim], y[i0,i1,...,in]) + + # is_overflow is defined as following + def is_overflow(xi0,xi1,...,xim): + return ( + xi0 < 0 || xi0 >= shape[0] || + xi1 < 0 || xi1 >= shape[1] || + ...... + xim < 0 || xim >= shape[m] || + + # overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in + overflow_conditions[0] || + overflow_conditions[1] || + ...... + overflow_conditions[k] + ) + + * [in] y: A input jittor Var + + * [in] op: a string represent the reduce operation type + + * [in] shape: the output shape, a integer array + + * [in] indexes: array of c++ style integer expression, its length should be the same with length of output shape, some buildin variables it can use are:: + + XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem + YDIM, yshape0, ..., yshapen, ystride0, ..., ystriden + i0, i1, ..., in + @e0(...), @e1(...) for extras input index + e0p, e1p , ... for extras input pointer + + * [in] overflow_conditions: array of c++ style boolean expression, it length can be vary. the buildin variables it can use are the same with indexes. + + * [in] extras: extra var used for index + + Example + + Pooling implemented by reindex operation:: + + def pool(x, size, op): + N,H,W,C = x.shape + h = (H+size-1)//size + w = (W+size-1)//size + return x.reindex_reduce(op, [N,h,w,C], [ + "i0", # Nid + f"i1/{size}", # Hid + f"i2/{size}", # Wid + "i3", # Cid + ])''' + ... +class Var: + '''Variable that stores multi-dimensional data.''' + def ternary(self, x: Var, y: Var)-> Var: ... + @overload + def reindex(self, shape: Tuple[int], indexes: List[str], overflow_value: float=0, overflow_conditions: List[str]={}, extras: List[Var]={})-> Var: + '''Document: + * + Reindex Operator is a one-to-many map operator. + It performs equivalent Python-pseudo implementation below:: + + # input is x, output is y + n = len(shape)-1 + m = len(x.shape)-1 + k = len(overflow_conditions)-1 + y = np.zeros(shape, x.dtype) + for i0 in range(shape[0]): # 1-st loop + for i1 in range(shape[1]): # 2-nd loop + ...... # many loops + for in in range(shape[n]) # n+1 -th loop + if is_overflow(i0,i1,...,in): + y[i0,i1,...,in] = overflow_value + else: + # indexes[i] is a c++ style integer expression consisting of i0,i1,...,in + y[i0,i1,...,in] = x[indexes[0],indexes[1],...,indexes[m]] + + # is_overflow is defined as following + def is_overflow(i0,i1,...,in): + return ( + indexes[0] < 0 || indexes[0] >= x.shape[0] || + indexes[1] < 0 || indexes[1] >= x.shape[1] || + ...... + indexes[m] < 0 || indexes[m] >= x.shape[m] || + + # overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in + overflow_conditions[0] || + overflow_conditions[1] || + ...... + overflow_conditions[k] + ) + ---------------- + * [in] x: A input jittor Var + + * [in] shape: the output shape, a integer array + + * [in] indexes: array of c++ style integer expression, its length should be the same with the number of dimension of x, some buildin variables it can use are:: + + XDIM, xshape0, ..., xshapen, xstride0, ..., xstriden + YDIM, yshape0, ..., yshapem, ystride0, ..., ystridem + i0, i1, ..., in + @e0(...), @e1(...) for extras input index + e0p, e1p , ... for extras input pointer + + * [in] overflow_value: overflow value + + * [in] overflow_conditions: array of c++ style boolean expression, it length can be vary. the buildin variables it can use are the same with indexes + + * [in] extras: extra var used for index + + ---------------- + Example + Convolution implemented by reindex operation:: + + def conv(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [ + 'i0', # Nid + 'i1+i3', # Hid+Khid + 'i2+i4', # Wid+KWid + 'i5', # Cid + ]) + ww = w.broadcast_var(xx) + yy = xx*ww + y = yy.sum([3,4,5]) # Kh, Kw, C + return y, yy''' + ... + @overload + def reindex(self, indexes: List[Var], overflow_value: float=0, overflow_conditions: List[str]={})-> Var: + '''Document: + * Alias x.reindex([i,j,k]) -> + x.reindex(i.shape, ['@e0(...)','@e1(...)','@e2(...)',], extras=[i,j,k])''' + ... + def reindex_var(self, indexes: List[Var], overflow_value: float=0, overflow_conditions: List[str]={})-> Var: + '''Document: + * Alias x.reindex([i,j,k]) -> + x.reindex(i.shape, ['@e0(...)','@e1(...)','@e2(...)',], extras=[i,j,k])''' + ... + @overload + def index(self, dim: int, dtype: str="int32")-> Var: + '''Document: + * shape dependency version of index op + jt.index_var(a, 1) similar with jt.index(a.shape, 1)''' + ... + @overload + def index(self, dtype: str="int32")-> Tuple[Var]: + '''Document: + * shape dependency version of index op + jt.index_var(a) similar with jt.index(a.shape)''' + ... + @overload + def index_var(self, dim: int, dtype: str="int32")-> Var: + '''Document: + * shape dependency version of index op + jt.index_var(a, 1) similar with jt.index(a.shape, 1)''' + ... + @overload + def index_var(self, dtype: str="int32")-> Tuple[Var]: + '''Document: + * shape dependency version of index op + jt.index_var(a) similar with jt.index(a.shape)''' + ... + def binary(self, y: Var, p: str)-> Var: ... + def pow(self, y: Var)-> Var: + '''Document: + * + Computes ``x^y``, element-wise. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def maximum(self, y: Var)-> Var: + '''Document: + * + Returns the element-wise maximum of ``x`` and ``y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def minimum(self, y: Var)-> Var: + '''Document: + * + Returns the element-wise minimum of ``x`` and ``y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def add(self, y: Var)-> Var: + '''Document: + * + Element-wise adds ``x`` and ``y`` and returns a new Var. + + This operation is equivalent to ``x + y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def subtract(self, y: Var)-> Var: + '''Document: + * + Element-wise subtract ``y`` from ``x`` and returns a new Var. + + This operation is equivalent to ``x - y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def multiply(self, y: Var)-> Var: + '''Document: + * + Element-wise muliplies ``x`` with ``y`` and returns a new Var. + + This operation is equivalent to ``x * y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def divide(self, y: Var)-> Var: + '''Document: + * + Element-wise divide ``x`` by ``y`` and returns a new Var. + + This operation is equivalent to ``x / y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.empty((3,), dtype=jt.int32) + >>> a + jt.Var([707406378 707406378 707406378], dtype=int32) + >>> b = jt.empty((3,), dtype=jt.int32) + >>> b + jt.Var([674510453 171649398 538976288], dtype=int32) + >>> jt.divide(a, b) + jt.Var([1.0487701 4.1212287 1.3125001], dtype=float32) + >>> a / b + jt.Var([1.0487701 4.1212287 1.3125001], dtype=float32) + + .. note :: + returns float value even if the dtype of input Vars are both integers. + @see jt.ops.floor_divide() for floor division.''' + ... + def floor_divide(self, y: Var)-> Var: + '''Document: + * + Element-wise divide ``x`` by ``y`` and returns the floor of the result. + + This operation is equivalent to ``x // y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.randint(1, 10, (3,), dtype=jt.int32) + >>> a + jt.Var([9 2 7], dtype=int32) + >>> b = jt.randint(1, 10, (3,), dtype=jt.int32) + >>> b + jt.Var([6 4 6], dtype=int32) + >>> jt.floor_divide(a, b) + jt.Var([1 0 1], dtype=int32) + >>> a // b + jt.Var([1 0 1], dtype=int32)''' + ... + def mod(self, y: Var)-> Var: + '''Document: + * + Returns the element-wise remainder of division. + + This operation is equivalent to ``x % y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.rand(3) + >>> a + jt.Var([0.3989529 0.20159635 0.22973768], dtype=float32) + >>> b = jt.rand(3) + >>> b + jt.Var([0.20121202 0.7704864 0.5654395 ], dtype=float32) + >>> jt.mod(a, b) + jt.Var([0.19774088 0.20159635 0.22973768], dtype=float32) + >>> a % b + jt.Var([0.19774088 0.20159635 0.22973768], dtype=float32)''' + ... + def less(self, y: Var)-> Var: + '''Document: + * + Returns ``x < y`` element-wise. + + This operation is equivalent to ``x < y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def less_equal(self, y: Var)-> Var: + '''Document: + * + Returns ``x <= y`` element-wise. + + This operation is equivalent to ``x <= y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def greater(self, y: Var)-> Var: + '''Document: + * + Returns ``x > y`` element-wise. + + This operation is equivalent to ``x > y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def greater_equal(self, y: Var)-> Var: + '''Document: + * + Returns ``x >= y`` element-wise. + + This operation is equivalent to ``x >= y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def equal(self, y: Var)-> Var: + '''Document: + * + Returns ``x == y`` element-wise. + + This operation is equivalent to ``x == y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def not_equal(self, y: Var)-> Var: + '''Document: + * + Returns ``x != y`` element-wise. + + This operation is equivalent to ``x != y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var.''' + ... + def left_shift(self, y: Var)-> Var: + '''Document: + * + Shifts the bits of ``x`` to the left by ``y``. + + Bits are shifted to the left by appending ``y`` 0s at the right of ``x``. + This operation is equivalent to ``x << y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var (int32 or int64). + + * [in] y: the second input, a python number or jt.Var (int32 or int64). + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 10, shape=(3,)) + >>> a + jt.Var([7 6 7], dtype=int32) + >>> b = jt.randint(0, 10, shape=(3,)) + >>> b + jt.Var([3 9 8], dtype=int32) + >>> jt.left_shift(a, b) + jt.Var([ 56 3072 1792], dtype=int32) + >>> a << b + jt.Var([ 56 3072 1792], dtype=int32)''' + ... + def right_shift(self, y: Var)-> Var: + '''Document: + * + Shifts the bits of ``x`` to the right by ``y``. + + This operation is equivalent to ``x >> y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var (int32 or int64). + + * [in] y: the second input, a python number or jt.Var (int32 or int64). + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 1024, shape=(3,)) + >>> a + jt.Var([439 113 92], dtype=int32) + >>> b = jt.randint(0, 10, shape=(3,)) + >>> b + jt.Var([6 8 4], dtype=int32) + >>> jt.right_shift(a, b) + jt.Var([6 0 5], dtype=int32)''' + ... + def logical_and(self, y: Var)-> Var: + '''Document: + * + Returns the element-wise logical AND of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var.''' + ... + def logical_or(self, y: Var)-> Var: + '''Document: + * + Returns the element-wise logical OR of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var.''' + ... + def logical_xor(self, y: Var)-> Var: + '''Document: + * + Returns the element-wise logical XOR of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var.''' + ... + def bitwise_and(self, y: Var)-> Var: + '''Document: + * + Computes the bitwise AND of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean).''' + ... + def bitwise_or(self, y: Var)-> Var: + '''Document: + * + Computes the bitwise OR of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean).''' + ... + def bitwise_xor(self, y: Var)-> Var: + '''Document: + * + Computes the bitwise XOR of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean).''' + ... + def tape(self)-> Var: ... + @overload + def where(self, dtype: str="int32")-> Tuple[Var]: + '''Document: + * + Where Operator generate index of true condition. + + * [in] cond: condition for index generation + + * [in] dtype: type of return indexes + + * [out] out: return an array of indexes, same length with number of dims of cond + + Example:: + + jt.where([[0,0,1],[1,0,0]]) + # return [jt.Var([0 1], dtype=int32), jt.Var([2 0], dtype=int32)]''' + ... + @overload + def where(self, x: Var, y: Var)-> Var: + '''Document: + * + * Condition operator, perform cond ? x : y + *''' + ... + def argsort(self, dim: int=-1, descending: bool=False, dtype: str="int32")-> Tuple[Var]: + '''Document: + * + Argsort Operator Perform an indirect sort by given key or compare function. + + x is input, y is output index, satisfy: + + x[y[0]] <= x[y[1]] <= x[y[2]] <= ... <= x[y[n]] + + or + + key(y[0]) <= key(y[1]) <= key(y[2]) <= ... <= key(y[n]) + + or + + compare(y[0], y[1]) && compare(y[1], y[2]) && ... + + * [in] x: input var for sort + + * [in] dim: sort alone which dim + + * [in] descending: the elements are sorted in descending order or not(default False). + + * [in] dtype: type of return indexes + + * [out] index: index have the same size with sorted dim + + * [out] value: sorted value + + + Example:: + + index, value = jt.argsort([11,13,12]) + # return [0 2 1], [11 12 13] + index, value = jt.argsort([11,13,12], descending=True) + # return [1 2 0], [13 12 11] + index, value = jt.argsort([[11,13,12], [12,11,13]]) + # return [[0 2 1],[1 0 2]], [[11 12 13],[11 12 13]] + index, value = jt.argsort([[11,13,12], [12,11,13]], dim=0) + # return [[0 1 0],[1 0 1]], [[11 11 12],[12 13 13]]''' + ... + def fetch(self, func: Callable)-> Var: ... + def arg_reduce(self, op: str, dim: int, keepdims: bool)-> Tuple[Var]: + '''Document: + * + Returns the indices of the maximum / minimum of the input across a dimension. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] op: "max" or "min". + + * [in] dim: int. Specifies which dimension to be reduced. + + * [in] keepdims: bool. Whether the output has ``dim`` retained or not. + + ---------------- + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 3)) + >>> x + jt.Var([[4 2 5] + [6 7 1]], dtype=int32) + >>> jt.arg_reduce(x, 'max', dim=1, keepdims=False) + [jt.Var([2 1], dtype=int32), jt.Var([5 7], dtype=int32)] + >>> jt.arg_reduce(x, 'min', dim=1, keepdims=False) + [jt.Var([1 2], dtype=int32), jt.Var([2 1], dtype=int32)]''' + ... + @overload + def reduce(self, op: str, dim: int, keepdims: bool=False)-> Var: ... + @overload + def reduce(self, op: str, dims: Tuple[int]=(), keepdims: bool=False)-> Var: ... + @overload + def max(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... + @overload + def max(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... + @overload + def max(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... + @overload + def reduce_maximum(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... + @overload + def reduce_maximum(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... + @overload + def reduce_maximum(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32)''' + ... + @overload + def min(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... + @overload + def min(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... + @overload + def min(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... + @overload + def reduce_minimum(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... + @overload + def reduce_minimum(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... + @overload + def reduce_minimum(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32)''' + ... + @overload + def sum(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... + @overload + def sum(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... + @overload + def sum(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... + @overload + def reduce_add(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... + @overload + def reduce_add(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... + @overload + def reduce_add(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32)''' + ... + @overload + def prod(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def prod(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def prod(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def product(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def product(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def product(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def reduce_multiply(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def reduce_multiply(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def reduce_multiply(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32)''' + ... + @overload + def reduce_logical_and(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def reduce_logical_and(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def reduce_logical_and(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def all_(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def all_(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def all_(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def reduce_logical_or(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def reduce_logical_or(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def reduce_logical_or(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def any_(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def any_(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def any_(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32)''' + ... + @overload + def reduce_logical_xor(self, dim: int, keepdims: bool=False)-> Var: ... + @overload + def reduce_logical_xor(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: ... + @overload + def reduce_logical_xor(self, dims_mask: int, keepdims_mask: int)-> Var: ... + @overload + def reduce_bitwise_and(self, dim: int, keepdims: bool=False)-> Var: ... + @overload + def reduce_bitwise_and(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: ... + @overload + def reduce_bitwise_and(self, dims_mask: int, keepdims_mask: int)-> Var: ... + @overload + def reduce_bitwise_or(self, dim: int, keepdims: bool=False)-> Var: ... + @overload + def reduce_bitwise_or(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: ... + @overload + def reduce_bitwise_or(self, dims_mask: int, keepdims_mask: int)-> Var: ... + @overload + def reduce_bitwise_xor(self, dim: int, keepdims: bool=False)-> Var: ... + @overload + def reduce_bitwise_xor(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: ... + @overload + def reduce_bitwise_xor(self, dims_mask: int, keepdims_mask: int)-> Var: ... + @overload + def mean(self, dim: int, keepdims: bool=False)-> Var: + '''Document: + * + Returns the mean value of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[9 4 4] + [1 9 6]], dtype=int32) + >>> jt.mean(x) + jt.Var([5.5000005], dtype=float32) + >>> x.mean() + jt.Var([5.5000005], dtype=float32) + >>> x.mean(dim=1) + jt.Var([5.666667 5.3333335], dtype=float32) + >>> x.mean(dim=1, keepdims=True) + jt.Var([[5.666667 ] + [5.3333335]], dtype=float32)''' + ... + @overload + def mean(self, dims: Tuple[int]=(), keepdims: bool=False)-> Var: + '''Document: + * + Returns the mean value of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[9 4 4] + [1 9 6]], dtype=int32) + >>> jt.mean(x) + jt.Var([5.5000005], dtype=float32) + >>> x.mean() + jt.Var([5.5000005], dtype=float32) + >>> x.mean(dim=1) + jt.Var([5.666667 5.3333335], dtype=float32) + >>> x.mean(dim=1, keepdims=True) + jt.Var([[5.666667 ] + [5.3333335]], dtype=float32)''' + ... + @overload + def mean(self, dims_mask: int, keepdims_mask: int)-> Var: + '''Document: + * + Returns the mean value of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[9 4 4] + [1 9 6]], dtype=int32) + >>> jt.mean(x) + jt.Var([5.5000005], dtype=float32) + >>> x.mean() + jt.Var([5.5000005], dtype=float32) + >>> x.mean(dim=1) + jt.Var([5.666667 5.3333335], dtype=float32) + >>> x.mean(dim=1, keepdims=True) + jt.Var([[5.666667 ] + [5.3333335]], dtype=float32)''' + ... + def clone(self)-> Var: ... + def unary(self, op: str)-> Var: ... + def cast(self, op: str)-> Var: ... + def int8(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int8. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int8() + jt.Var([4 2 8], dtype=int8) + >>> jt.int8(x) + jt.Var([4 2 8], dtype=int8)''' + ... + def int16(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int16. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int16() + jt.Var([4 2 8], dtype=int16) + >>> jt.int16(x) + jt.Var([4 2 8], dtype=int16)''' + ... + def int32(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int() + jt.Var([4 2 8], dtype=int32) + >>> jt.int(x) + jt.Var([4 2 8], dtype=int32) + >>> x.int32() + jt.Var([4 2 8], dtype=int32) + >>> jt.int32(x) + jt.Var([4 2 8], dtype=int32) + >>> x.long() + jt.Var([4 2 8], dtype=int32) + >>> jt.long(x) + jt.Var([4 2 8], dtype=int32)''' + ... + def int64(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to int64. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int64() + jt.Var([4 2 8], dtype=int64) + >>> jt.int64(x) + jt.Var([4 2 8], dtype=int64)''' + ... + def uint8(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int8. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint8() + jt.Var([4 2 8], dtype=uint8) + >>> jt.uint8(x) + jt.Var([4 2 8], dtype=uint8)''' + ... + def uint16(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int16. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint16() + jt.Var([4 2 8], dtype=uint16) + >>> jt.uint16(x) + jt.Var([4 2 8], dtype=uint16)''' + ... + def uint32(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint32() + jt.Var([4 2 8], dtype=uint32) + >>> jt.uint32(x) + jt.Var([4 2 8], dtype=uint32)''' + ... + def uint64(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to unsigned int64. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint64() + jt.Var([4 2 8], dtype=uint64) + >>> jt.uint64(x) + jt.Var([4 2 8], dtype=uint64)''' + ... + def float16(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to float16 (half-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.half() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.half(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> x.float16() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.float16(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16)''' + ... + def float32(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to float32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.arange(3) + >>> x + jt.Var([0 1 2], dtype=int32) + >>> x.float() + jt.Var([0. 1. 2.], dtype=float32) + >>> jt.float(x) + jt.Var([0. 1. 2.], dtype=float32) + >>> x.float32() + jt.Var([0. 1. 2.], dtype=float32) + >>> jt.float32(x) + jt.Var([0. 1. 2.], dtype=float32)''' + ... + def float64(self)-> Var: + '''Document: + * + Returns a copy of the input var, casted to float64 (double-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.arange(3) + >>> x + jt.Var([0 1 2], dtype=int32) + >>> x.double() + jt.Var([0. 1. 2.], dtype=float64) + >>> jt.double(x) + jt.Var([0. 1. 2.], dtype=float64) + >>> x.float64() + jt.Var([0. 1. 2.], dtype=float64) + >>> jt.float64(x) + jt.Var([0. 1. 2.], dtype=float64)''' + ... + def abs(self)-> Var: + '''Document: + * + Returns the absolute value of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> jt.abs(jt.float32([-1, 0, 1])) + jt.Var([1. 0. 1.], dtype=float32)''' + ... + def negative(self)-> Var: + '''Document: + * + Returns the negative value of the input ``x``. + + This operator is equavilant to ``-x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> jt.negative(jt.float32([-1, 0, 1])) + jt.Var([ 1. -0. -1.], dtype=float32)''' + ... + def logical_not(self)-> Var: + '''Document: + * + Returns the logical NOT of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var, integal or boolean. + + ---------------- + + Example-1:: + >>> jt.logical_not(jt.int32([-1, 0, 1])) + jt.Var([False True False], dtype=bool)''' + ... + def bitwise_not(self)-> Var: + '''Document: + * + Returns the bitwise NOT of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var, integal or boolean. + + ---------------- + + Example-1:: + >>> jt.bitwise_not(jt.int32([1, 2, -3])) + jt.Var([-2 -3 2], dtype=int32)''' + ... + def log(self)-> Var: + '''Document: + * + Returns the natural logarithm of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([0.02863695 1.30122 1.6048753 1.140261 ], dtype=float32) + >>> jt.log(x) + jt.Var([-3.5530574 0.26330233 0.47304606 0.13125724], dtype=float32) + >>> x.log() + jt.Var([-3.5530574 0.26330233 0.47304606 0.13125724], dtype=float32)''' + ... + def exp(self)-> Var: + '''Document: + * + Returns the exponential of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([1.9841381 1.4103996 0.5855549 1.4212812], dtype=float32) + >>> jt.exp(x) + jt.Var([7.2727766 4.0975924 1.7959872 4.1424246], dtype=float32) + >>> x.exp() + jt.Var([7.2727766 4.0975924 1.7959872 4.1424246], dtype=float32)''' + ... + def sqrt(self)-> Var: + '''Document: + * + Returns the square root of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([0.81957287 0.5609612 0.07435933 1.7571875 ], dtype=float32) + >>> jt.sqrt(x) + jt.Var([0.90530264 0.7489734 0.27268907 1.3255895 ], dtype=float32) + >>> x.sqrt() + jt.Var([0.90530264 0.7489734 0.27268907 1.3255895 ], dtype=float32)''' + ... + def round(self)-> Var: + '''Document: + * + Returns the closest integer of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32) + >>> jt.round(x) + jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32) + >>> x.round() + jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32)''' + ... + def floor(self)-> Var: + '''Document: + * + Returns the largest integer less than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.floor(x) + jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32) + >>> x.floor + jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32)''' + ... + def ceil(self)-> Var: + '''Document: + * + Returns the smallest integer greater than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.ceil(x) + jt.Var([-1.0 0.0 0.0 0.0], dtype=float32) + >>> x.ceil() + jt.Var([-1.0 0.0 0.0 0.0], dtype=float32)''' + ... + def round_int(self)-> Var: + '''Document: + * + Returns the closest integer of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32) + >>> jt.round_int(x) + jt.Var([ 2 0 0 -1], dtype=int32) + >>> x.round_int + jt.Var([ 2 0 0 -1], dtype=int32)''' + ... + def floor_int(self)-> Var: + '''Document: + * + Returns the largest integer less than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.floor_int(x) + jt.Var([-2 -1 -1 -1], dtype=int32) + >>> x.floor_int + jt.Var([-2 -1 -1 -1], dtype=int32)''' + ... + def ceil_int(self)-> Var: + '''Document: + * + Returns the smallest integer greater than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.ceil_int(x) + jt.Var([-1 0 0 0], dtype=int32) + >>> x.ceil_int() + jt.Var([-1 0 0 0], dtype=int32)''' + ... + def sin(self)-> Var: + '''Document: + * + Returns the sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.sin(x) + jt.Var([ 0.32303742 -0.6527857 -0.76586854 0.9738172 ], dtype=float32) + >>> x.sin() + jt.Var([ 0.32303742 -0.6527857 -0.76586854 0.9738172 ], dtype=float32)''' + ... + def asin(self)-> Var: + '''Document: + * + Returns the arcsine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.09342023 -0.42522037 0.9264933 -0.785264 ], dtype=float32) + >>> jt.asin(x) + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32) + >>> x.asin() + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32)''' + ... + def arcsin(self)-> Var: + '''Document: + * + Returns the arcsine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.09342023 -0.42522037 0.9264933 -0.785264 ], dtype=float32) + >>> jt.asin(x) + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32) + >>> x.asin() + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32)''' + ... + def sinh(self)-> Var: + '''Document: + * + Returns the hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.sinh(x) + jt.Var([ 0.3349012 -0.77276015 -0.9873369 2.9425898 ], dtype=float32) + >>> x.sinh + jt.Var([ 0.3349012 -0.77276015 -0.9873369 2.9425898 ], dtype=float32)''' + ... + def asinh(self)-> Var: + '''Document: + * + Returns the inverse hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.9749726 -0.52341473 0.8906148 1.0338128 ], dtype=float32) + >>> jt.asinh(x) + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32) + >>> x.asinh() + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32)''' + ... + def arcsinh(self)-> Var: + '''Document: + * + Returns the inverse hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.9749726 -0.52341473 0.8906148 1.0338128 ], dtype=float32) + >>> jt.asinh(x) + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32) + >>> x.asinh() + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32)''' + ... + def tan(self)-> Var: + '''Document: + * + Returns the tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.tan(x) + jt.Var([ 0.34133783 -0.8617148 -1.1910915 -4.283673 ], dtype=float32) + >>> x.tan() + jt.Var([ 0.34133783 -0.8617148 -1.1910915 -4.283673 ], dtype=float32)''' + ... + def atan(self)-> Var: + '''Document: + * + Returns the inverse tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.atan(x) + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32) + >>> x.atan() + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32)''' + ... + def arctan(self)-> Var: + '''Document: + * + Returns the inverse tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.atan(x) + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32) + >>> x.atan() + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32)''' + ... + def tanh(self)-> Var: + '''Document: + * + Returns the hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.tanh(x) + jt.Var([-0.6956678 0.82989657 0.4402144 0.7439787 ], dtype=float32) + >>> x.tanh() + jt.Var([-0.6956678 0.82989657 0.4402144 0.7439787 ], dtype=float32)''' + ... + def atanh(self)-> Var: + '''Document: + * + Returns the inverse hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.9062414 -0.799802 -0.27219176 -0.7274077 ], dtype=float32) + >>> jt.atanh(x) + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32) + >>> x.atanh() + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32)''' + ... + def arctanh(self)-> Var: + '''Document: + * + Returns the inverse hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.9062414 -0.799802 -0.27219176 -0.7274077 ], dtype=float32) + >>> jt.atanh(x) + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32) + >>> x.atanh() + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32)''' + ... + def cos(self)-> Var: + '''Document: + * + Returns the cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.cos(x) + jt.Var([ 0.9463862 0.7575426 0.6429972 -0.2273323], dtype=float32) + >>> x.cos() + jt.Var([ 0.9463862 0.7575426 0.6429972 -0.2273323], dtype=float32)''' + ... + def acos(self)-> Var: + '''Document: + * + Returns the inverse cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.5876564 0.740723 -0.667666 0.5371753], dtype=float32) + >>> jt.acos(x) + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32) + >>> x.acos() + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32)''' + ... + def arccos(self)-> Var: + '''Document: + * + Returns the inverse cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.5876564 0.740723 -0.667666 0.5371753], dtype=float32) + >>> jt.acos(x) + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32) + >>> x.acos() + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32)''' + ... + def cosh(self)-> Var: + '''Document: + * + Returns the hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.cosh(x) + jt.Var([1.0545894 1.2637873 1.405288 3.1078668], dtype=float32) + >>> x.cosh() + jt.Var([1.0545894 1.2637873 1.405288 3.1078668], dtype=float32)''' + ... + def acosh(self)-> Var: + '''Document: + * + Returns the inverse hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) + 1 + >>> x + jt.Var([1.3609099 1.8137748 1.1146184 1.3911307], dtype=float32) + >>> jt.acosh(x) + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32) + >>> x.acosh() + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32)''' + ... + def arccosh(self)-> Var: + '''Document: + * + Returns the inverse hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) + 1 + >>> x + jt.Var([1.3609099 1.8137748 1.1146184 1.3911307], dtype=float32) + >>> jt.acosh(x) + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32) + >>> x.acosh() + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32)''' + ... + def sigmoid(self)-> Var: + '''Document: + * + Returns the sigmoid of the input ``x``. + + .. math:: + out_i = \frac{1}{1 + e^{x_i}} + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.49443012 0.4305426 -1.0364404 -1.2628382 ], dtype=float32) + >>> jt.sigmoid(x) + jt.Var([0.62114954 0.6060032 0.2618374 0.2204857 ], dtype=float32) + >>> x.sigmoid() + jt.Var([0.62114954 0.6060032 0.2618374 0.2204857 ], dtype=float32)''' + ... + def erf(self)-> Var: + '''Document: + * + Computes the error function of each element. The error function is defined as follows: + + .. math:: + erf(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.49443012 0.4305426 -1.0364404 -1.2628382 ], dtype=float32) + >>> jt.erf(x) + jt.Var([ 0.51559156 0.45739546 -0.85728306 -0.9258883 ], dtype=float32) + >>> x.erf() + jt.Var([ 0.51559156 0.45739546 -0.85728306 -0.9258883 ], dtype=float32)''' + ... + def erfinv(self)-> Var: + '''Document: + * + Computes the inverse error function of each element. + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.00277209 -0.26642472 0.7869792 0.5415418 ], dtype=float32) + >>> jt.erfinv(x) + jt.Var([ 0.00245671 -0.24068035 0.8805613 0.5242405 ], dtype=float32) + >>> x.erfinv() + jt.Var([ 0.00245671 -0.24068035 0.8805613 0.5242405 ], dtype=float32)''' + ... + def transpose(self, axes: Tuple[int]=())-> Var: ... + def fuse_transpose(self, axes: Tuple[int]=())-> Var: ... + def safe_clip(self, left: float, right: float)-> Var: + '''Document: + * Safe clip value to a range, and keep + the gradient pass thought. + + * [in] x: input value + * [in] left: float64 clip min value. + * [in] right: float64 clip max value.''' + ... + def array(self)-> Var: ... + @overload + def getitem(self, slices: slice)-> Var: ... + @overload + def getitem(self, slices: slice, _: int)-> Tuple[Var]: ... + def candidate(self, fail_cond: str, dtype: str="int32")-> Var: + '''Document: + * + Candidate Operator Perform an indirect candidate filter by given a fail condition. + + x is input, y is output index, satisfy:: + + not fail_cond(y[0], y[1]) and + not fail_cond(y[0], y[2]) and not fail_cond(y[1], y[2]) and + ... + ... and not fail_cond(y[m-2], y[m-1]) + + Where m is number of selected candidates. + + Pseudo code:: + + y = [] + for i in range(n): + pass = True + for j in y: + if (@fail_cond): + pass = false + break + if (pass): + y.append(i) + return y + + * [in] x: input var for filter + + * [in] fail_cond: code for fail condition + + * [in] dtype: type of return indexes + + * [out] index: . + + Example:: + + jt.candidate(jt.random(100,2), '(@x(j,0)>@x(i,0))or(@x(j,1)>@x(i,1))') + # return y satisfy: + # x[y[0], 0] <= x[y[1], 0] and x[y[1], 0] <= x[y[2], 0] and ... and x[y[m-2], 0] <= x[y[m-1], 0] and + # x[y[0], 1] <= x[y[1], 1] and x[y[1], 1] <= x[y[2], 1] and ... and x[y[m-2], 1] <= x[y[m-1], 1]''' + ... + @overload + def code(self, outputs: List[Var], cpu_src: str="", cpu_grad_src: List[str]={}, cpu_header: str="", cuda_src: str="", cuda_grad_src: List[str]={}, cuda_header: str="")-> Tuple[Var]: + '''Document: + * + Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] cpu_src: cpu source code string, buildin value: + + * in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...) + * out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...) + * out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...) + + * [in] cpu_header: cpu header code string. + + * [in] cuda_src: cuda source code string. + + * [in] cuda_header: cuda header code string. + + ---------------- + + Example-1:: + + from jittor import Function + import jittor as jt + + class Func(Function): + def execute(self, x): + self.save_vars = x + return jt.code(x.shape, x.dtype, [x], + cpu_src=""" + for (int i=0; i + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i + using namespace std; + """, + cpu_src=""" + @alias(a, in0) + @alias(b, out0) + @alias(c, out1) + @b(0) = @c(0) = @a(0); + for (int i=0; i0) + @b(num_b++) = @a(i); + else + @c(num_c++) = @a(i); + } + b->set_shape({num_b}); + c->set_shape({num_c}); + """ + ) + assert (b.data == [5,3,1]).all() + assert (c.data == [-4,-2]).all() + + Example-5:: + + # This example shows how to customize code op + # compilation flags, such as add include search + # path, add definitions, or any command line options + + a = jt.random([10]) + b = jt.code(a.shape, a.dtype, [a], + cpu_src=""" + @out0(0) = HAHAHA; + """) + # HAHAHA is defined in flags below + # /any/include/path can be change to any path you want to include + b.compile_options = {"FLAGS: -DHAHAHA=233 -I/any/include/path ": 1} + print(b[0]) + # will output 233 + + + CUDA Example-1:: + + #This example shows how to use CUDA in code op. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + """) + + a = jt.random([100000]) + b = jt.random([100000]) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b])) + + CUDA Example-2:: + + #This example shows how to use multi dimension data with CUDA. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=""" + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=""" + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + """) + + a = jt.random((100,100)) + b = jt.random((100,100)) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b]))''' + ... + def copy(self)-> Var: ... + def setitem(self, slices: slice, y: Var, op: str="void")-> Var: ... + @overload + def broadcast(self, shape: Tuple[int], dims: Tuple[int]=())-> Var: + '''Document: + * + Broadcast ``x`` to a given shape. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] shape: the output shape. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> jt.broadcast(x, shape=(2, 3, 2), dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32)''' + ... + @overload + def broadcast(self, y: Var, dims: Tuple[int]=())-> Var: + '''Document: + * + Broadcast ``x`` to the same shape as ``y``. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] y: the reference jt.Var. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + .. note:: + jt.broadcast_var(x, y, dims) is an alias of jt.broadcast(x, y, dims) + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> y = jt.randint(0, 10, shape=(2, 3, 2)) + >>> jt.broadcast(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + >>> jt.broadcast_var(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32)''' + ... + def broadcast_var(self, y: Var, dims: Tuple[int]=())-> Var: + '''Document: + * + Broadcast ``x`` to the same shape as ``y``. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] y: the reference jt.Var. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + .. note:: + jt.broadcast_var(x, y, dims) is an alias of jt.broadcast(x, y, dims) + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> y = jt.randint(0, 10, shape=(2, 3, 2)) + >>> jt.broadcast(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + >>> jt.broadcast_var(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32)''' + ... + def reshape(self, shape: Tuple[int])-> Var: + '''Document: + * + Returns a tensor with the same data and number of elements as input, but with the specified shape. + + A single dimension may be -1, in which case it's inferred from the remaining dimensions and the number of elements in input. + + ---------------- + + * [in] x: the input jt.Var + + * [in] shape: the output shape, an integer array + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 10, shape=(12,)) + >>> a + jt.Var([4 0 8 4 6 3 1 8 1 1 2 2], dtype=int32) + >>> jt.reshape(a, (3, 4)) + jt.Var([[4 0 8 4] + [6 3 1 8] + [1 1 2 2]], dtype=int32) + >>> jt.reshape(a, (-1, 6)) + jt.Var([[4 0 8 4 6 3] + [1 8 1 1 2 2]], dtype=int32)''' + ... + def reindex_reduce(self, op: str, shape: Tuple[int], indexes: List[str], overflow_conditions: List[str]={}, extras: List[Var]={})-> Var: + '''Document: + * + Reindex Reduce Operator is a many-to-one map operator. + It performs equivalent Python-pseudo implementation below:: + + # input is y, output is x + n = len(y.shape)-1 + m = len(shape)-1 + k = len(overflow_conditions)-1 + x = np.zeros(shape, y.dtype) + x[:] = initial_value(op) + for i0 in range(y.shape[0]): # 1-st loop + for i1 in range(y.shape[1]): # 2-nd loop + ...... # many loops + for in in range(y.shape[n]) # n+1 -th loop + # indexes[i] is a c++ style integer expression consisting of i0,i1,...,in + xi0,xi1,...,xim = indexes[0],indexes[1],...,indexes[m] + if not is_overflow(xi0,xi1,...,xim): + x[xi0,xi1,...,xim] = op(x[xi0,xi1,...,xim], y[i0,i1,...,in]) + + # is_overflow is defined as following + def is_overflow(xi0,xi1,...,xim): + return ( + xi0 < 0 || xi0 >= shape[0] || + xi1 < 0 || xi1 >= shape[1] || + ...... + xim < 0 || xim >= shape[m] || + + # overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in + overflow_conditions[0] || + overflow_conditions[1] || + ...... + overflow_conditions[k] + ) + + * [in] y: A input jittor Var + + * [in] op: a string represent the reduce operation type + + * [in] shape: the output shape, a integer array + + * [in] indexes: array of c++ style integer expression, its length should be the same with length of output shape, some buildin variables it can use are:: + + XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem + YDIM, yshape0, ..., yshapen, ystride0, ..., ystriden + i0, i1, ..., in + @e0(...), @e1(...) for extras input index + e0p, e1p , ... for extras input pointer + + * [in] overflow_conditions: array of c++ style boolean expression, it length can be vary. the buildin variables it can use are the same with indexes. + + * [in] extras: extra var used for index + + Example + + Pooling implemented by reindex operation:: + + def pool(x, size, op): + N,H,W,C = x.shape + h = (H+size-1)//size + w = (W+size-1)//size + return x.reindex_reduce(op, [N,h,w,C], [ + "i0", # Nid + f"i1/{size}", # Hid + f"i2/{size}", # Wid + "i3", # Cid + ])''' + ... + def sync(self, device_sync: bool=False, weak_sync: bool=True): ... + def fetch_sync(self)-> numpy.ndarray: + '''Document: + * + * Returns a numpy array copy of the Var.''' + ... + def numpy(self)-> numpy.ndarray: + '''Document: + * + * Returns a numpy array copy of the Var.''' + ... + def assign(self, v: Var)-> Var: + '''Document: + * + * assign the data from another Var.''' + ... + def update(self, v: Var)-> Var: + '''Document: + * + * update parameter and global variable, + * different from assign, it will + * stop grad between origin var and assigned var, and + * will update in the background''' + ... + def _update(self, v: Var)-> Var: + '''Document: + * + * update parameter without set attribute.''' + ... + def swap(self, v: Var)-> Var: + '''Document: + * + * swap the data with another Var.''' + ... + @overload + def name(self, s: str)-> Var: + '''Document: + * + * set the name of the Var.''' + ... + @overload + def name(self)-> str: + '''Document: + * + * set the name of the Var.''' + ... + def numel(self)-> int: + '''Document: + * + * return the number of elements in the Var.''' + ... + def stop_grad(self)-> Var: + '''Document: + * + * disable the gradient calculation for the Var.''' + ... + def is_stop_grad(self)-> bool: + '''Document: + * + * return True if the gradient is stopped.''' + ... + def detach(self)-> Var: + '''Document: + detach the grad''' + ... + def stop_fuse(self)-> Var: + '''Document: + * + * stop operator fusion.''' + ... + def is_stop_fuse(self)-> bool: + '''Document: + * + * return True if operator fusion is stopped.''' + ... + def out_hint(self)-> Var: + '''Document: + * + * output hint for training optimization''' + ... + def start_grad(self)-> Var: + '''Document: + * + * enable the gradient calculation for the Var.''' + ... + def item(self)-> float | int | bool: + '''Document: + * + * returns the Python number if the Var contains only one element. + * For other cases, see data().''' + ... + def share_with(self, other: Var)-> Var: ... + def debug_msg(self)-> str: + '''Document: + * + * print the information of the Var to debug.''' + ... + def _input(self, i: int)-> Var: ... + def _add_dependency(self, vars: List[Var])-> Var: + '''Document: + Add dependency, make var computed after vars''' + ... + def compile_options(self): ... + def data(self)-> numpy.ndarray: + '''Document: + * + * get a numpy array which shares the data with the Var.''' + ... + def dtype(self)-> str: + '''Document: + * + * return the data type of the Var.''' + ... + def grad(self)-> int: + '''Document: + Jittor Var doesn't have this interface, please change your code as below:: + + model = Model() + optimizer = SGD(model.parameters()) + ... + optimizer.backward(loss) + + for p in model.parameters(): + # prev code: + # grad = p.grad + + # change to: + grad = p.opt_grad(optimizer)''' + ... + def ndim(self)-> int: + '''Document: + * + * return the number of dimensions.''' + ... + def requires_grad(self)-> bool: + '''Document: + * + * return True if the Var requires gradient calculation. + * @see is_stop_grad''' + ... + def shape(self)-> Tuple[int]: + '''Document: + * + * return the shape of the Var.''' + ... + def uncertain_shape(self)-> Tuple[int]: ... + def view(self, x: Var, shape: Tuple[int])-> Var: + '''Document: + * + Returns a tensor with the same data and number of elements as input, but with the specified shape. + + A single dimension may be -1, in which case it's inferred from the remaining dimensions and the number of elements in input. + + ---------------- + + * [in] x: the input jt.Var + + * [in] shape: the output shape, an integer array + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 10, shape=(12,)) + >>> a + jt.Var([4 0 8 4 6 3 1 8 1 1 2 2], dtype=int32) + >>> jt.reshape(a, (3, 4)) + jt.Var([[4 0 8 4] + [6 3 1 8] + [1 1 2 2]], dtype=int32) + >>> jt.reshape(a, (-1, 6)) + jt.Var([[4 0 8 4 6 3] + [1 8 1 1 2 2]], dtype=int32)''' + ... + def permute(self, x: Var, axes: Tuple[int]=())-> Var: ... + def detach_inplace(self)-> Var: + '''Document: + * + * enable the gradient calculation for the Var.''' + ... + def astype(self, x: Var, op: str)-> Var: ... + def half(self, x: Var)-> Var: + '''Document: + * + Returns a copy of the input var, casted to float16 (half-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.half() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.half(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> x.float16() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.float16(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16)''' + ... + def expand_as(self, x: Var, y: Var, dims: Tuple[int]=())-> Var: + '''Document: + * + Broadcast ``x`` to the same shape as ``y``. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] y: the reference jt.Var. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + .. note:: + jt.broadcast_var(x, y, dims) is an alias of jt.broadcast(x, y, dims) + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> y = jt.randint(0, 10, shape=(2, 3, 2)) + >>> jt.broadcast(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + >>> jt.broadcast_var(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32)''' + ... +class Flags: + '''A set of flags to configure jittor running behaviors''' + addr2line_path: str + '''Path of addr2line. Default: ""''' + amp_level: int + '''Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16. Default: 0''' + amp_reg: int + '''Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too. Default: 0''' + auto_convert_64_to_32: int + '''auto convert 64bit numpy array into 32bit jittor array. Default: 1''' + auto_mixed_precision_level: int + '''Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16. Default: 0''' + cache_path: str + '''Cache path of jittor. Default: ""''' + cc_flags: str + '''Flags of C++ compiler. Default: ""''' + cc_path: str + '''Path of C++ compiler. Default: ""''' + cc_type: str + '''Type of C++ compiler(clang, icc, g++). Default: ""): Type of C++ compiler(clang, icc, g++''' + check_graph: int + '''Unify graph sanity check. Default: 0''' + compile_options: Any + '''Override the default loop transfrom options. Default: {}''' + disable_lock: bool + '''Disable file lock. Default: 0''' + enable_tuner: int + '''Enable tuner. Default: 1''' + exclude_pass: str + '''Don't run certain pass. Default: ""''' + extra_gdb_cmd: str + '''Extra command pass to GDB, seperate by(;) . Default: ""): Extra command pass to GDB, seperate by(;''' + gdb_attach: int + '''gdb attach self process. Default: 0''' + gdb_path: str + '''Path of GDB. Default: ""''' + gopt_disable: int + '''Disable graph optimizer. Default: 0''' + has_pybt: int + '''GDB has pybt or not. Default: 0''' + jit_search_kernel: int + '''Jit search for the fastest kernel. Default: 0''' + jit_search_rerun: int + '''. Default: 10''' + jit_search_warmup: int + '''. Default: 2''' + jittor_path: str + '''Source path of jittor. Default: ""''' + l1_cache_size: int + '''size of level 1 cache (byte). Default: 32768): size of level 1 cache (byte''' + lazy_execution: int + '''Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance. Default: 1''' + log_file: str + '''log to file, mpi env will add $OMPI_COMM_WORLD_RANK suffix. Default: ""''' + log_op_hash: str + '''Output compiler pass result of certain hash of op. Default: ""''' + log_silent: int + '''The log will be completely silent. Default: 0''' + log_sync: int + '''Set log printed synchronously. Default: 1''' + log_v: int + '''Verbose level of logging. Default: 0''' + log_vprefix: str + '''Verbose level of logging prefix. Default: ""''' + no_fuse: bool + '''No fusion optimization for all jittor Var creation. Default: 0''' + no_grad: bool + '''No grad for all jittor Var creation. Default: 0''' + node_order: int + '''id prior. Default: 0''' + nvcc_flags: str + '''Flags of CUDA C++ compiler. Default: ""''' + nvcc_path: str + '''Path of CUDA C++ compiler. Default: ""''' + para_opt_level: int + '''para_opt_level. Default: 3''' + profile_memory_enable: int + '''Enable memory profiler. Default: 0''' + profiler_enable: int + '''Enable profiler. Default: 0''' + profiler_hide_relay: int + '''Profiler hide relayed op. Default: 0''' + profiler_record_peek: int + '''Profiler record peek mem bandwidth. Default: 0''' + profiler_record_shape: int + '''Profiler record shape for op. Default: 0''' + profiler_rerun: int + '''Profiler rerun. Default: 0''' + profiler_warmup: int + '''Profiler warmup. Default: 0''' + python_path: str + '''Path of python interpreter. Default: ""''' + reuse_array: int + '''try reuse np.array memory into jt.array. Default: 0''' + rewrite_op: int + '''Rewrite source file of jit operator or not. Default: 1''' + stat_allocator_total_alloc_byte: int + '''Total alloc byte. Default: 0''' + stat_allocator_total_alloc_call: int + '''Number of alloc function call. Default: 0''' + stat_allocator_total_free_byte: int + '''Total alloc byte. Default: 0''' + stat_allocator_total_free_call: int + '''Number of alloc function call. Default: 0''' + th_mode: int + '''th mode. Default: 0''' + trace_depth: int + '''trace depth for GDB. Default: 10''' + trace_py_var: int + '''Trace py stack max depth for debug. Default: 0''' + trace_var_data: int + '''Trace py stack max depth for debug. Default: 0''' + try_use_32bit_index: int + '''If not overflow, try to use 32 bit type as index type. Default: 0''' + use_acl: int + '''Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda. Default: 0''' + use_cuda: int + '''Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda. Default: 0''' + use_device: int + '''Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda. Default: 0''' + use_nfef_allocator: int + '''Enable never free exact fit allocator. Default: 0''' + use_parallel_op_compiler: int + '''Number of threads that parallel op comiler used, default 16, set this value to 0 will disable parallel op compiler. Default: 16''' + use_rocm: int + '''Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda. Default: 0''' + use_sfrl_allocator: int + '''Enable sfrl allocator. Default: 1''' + use_stat_allocator: int + '''Enable stat allocator. Default: 0''' + use_temp_allocator: int + '''Enable temp allocator. Default: 1''' + use_tensorcore: int + '''use tensor core. Default: 0''' +flags: Flags +'''Jittor running time flags instance''' diff --git a/python/jittor/attention.py b/python/jittor/attention.py new file mode 100644 index 00000000..a8a486cb --- /dev/null +++ b/python/jittor/attention.py @@ -0,0 +1,176 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import jittor as jt +from jittor import init, Module, nn +import numpy as np +import math + +class MultiheadAttention(Module): + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + q_noise=0.0, + qn_block_size=8, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + assert dropout==0, "TODO: dropout>0" + + self.head_dim = embed_dim // num_heads + assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, ("Self-attention requires query, key and " "value to be of the same size") + + #TODO: quant_noise + self.k_proj = nn.Linear(self.kdim, embed_dim, bias=bias) + self.v_proj = nn.Linear(self.vdim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + assert not add_bias_kv, "TODO: add_bias_kv=True" + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.onnx_trace = False + self.tpu = False + + def reset_parameters(self): + if self.qkv_same_dim: + # Empirically observed the convergence to be much better with + # the scaled initialization + init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) + init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) + init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) + else: + init.xavier_uniform_(self.k_proj.weight) + init.xavier_uniform_(self.v_proj.weight) + init.xavier_uniform_(self.q_proj.weight) + + # init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + init.xavier_normal_(self.bias_v) + + def execute( + self, + query, + key = None, + value = None, + key_padding_mask = None, + incremental_state = None, + need_weights = True, + static_kv = False, + attn_mask = None, + before_softmax = False, + need_head_weights = False, + ): + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.shape + assert embed_dim == self.embed_dim + assert list(query.shape) == [tgt_len, bsz, embed_dim] + + assert incremental_state is None, "TODO: incremental_state is not None" + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + k = self.k_proj(key) + v = self.v_proj(key) + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q = q*self.scaling + + assert self.bias_k is None, "TODO: self.bias_k is not None:" + + q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) + if k is not None: + k = k.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) + if v is not None: + v = v.view(-1, bsz * self.num_heads, self.head_dim).transpose(1, 0, 2) + + assert saved_state is None, "TODO: saved_state is not None" + assert k is not None + src_len = k.shape[1] + + assert key_padding_mask is None, "TODO: key_padding_mask is not None" + assert not self.add_zero_attn, "TODO: self.add_zero_attn=True" + + attn_weights = nn.bmm(q, k.transpose(0, 2, 1)) + + assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + assert attn_mask is None, "TODO: attn_mask is not None" + assert key_padding_mask is None, "TODO: key_padding_mask is not None" + + if before_softmax: + return attn_weights, v + + attn_weights_float = nn.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + + assert v is not None + attn = nn.bmm(attn_weights, v) + assert list(attn.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.shape[1] == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.view(tgt_len, bsz, embed_dim) + else: + attn = attn.transpose(1, 0, 2).view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + attn_weights = None + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0, 2, 3) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dims=[0]) + + return attn, attn_weights diff --git a/python/jittor/ccl/__init__.py b/python/jittor/ccl/__init__.py new file mode 100644 index 00000000..23cfbaa7 --- /dev/null +++ b/python/jittor/ccl/__init__.py @@ -0,0 +1,3 @@ +from .ccl_2d import ccl_2d +from .ccl_3d import ccl_3d +from .ccl_link import ccl_link \ No newline at end of file diff --git a/python/jittor/ccl/ccl_2d.py b/python/jittor/ccl/ccl_2d.py new file mode 100644 index 00000000..9bbd5d11 --- /dev/null +++ b/python/jittor/ccl/ccl_2d.py @@ -0,0 +1,177 @@ +import jittor as jt + + +def ccl_2d(data_2d): + ''' + 2D connected component labelling, original code from https://github.com/DanielPlayne/playne-equivalence-algorithm + Args: + [in]param data_2d: binary two-dimensional vector + type data_2d: jittor array + + Returns: + [out]result: labeled two-dimensional vector + + Example: + >>> import jittor as jt + >>> jt.flags.use_cuda = 1 + >>> import cv2 + >>> import numpy as np + >>> img = cv2.imread('testImg.png', 0) + >>> a = img.mean() + >>> img[img <= a] = 0 + >>> img[img > a] = 1 + >>> img = jt.Var(img) + + >>> result = ccl_2d(img) + >>> print(jt.unique(result, return_counts=True, return_inverse=True)[0], jt.unique(result, return_counts=True, return_inverse=True)[2]) + >>> cv2.imwrite('testImg_result.png', result.numpy().astype(np.uint8) * 50) + ''' + + data_2d = data_2d.astype(jt.uint32) + cY = data_2d.shape[0] + cX = data_2d.shape[1] + data_2d_copy = data_2d.clone() + changed = jt.ones([1], dtype=jt.uint32) + data_2d = data_2d.reshape(cX * cY) + result = jt.code(data_2d.shape, + data_2d.dtype, [data_2d, changed], + cuda_header=''' + @alias(g_image, in0) + @alias(g_labels, out) + ''', + cuda_src=r''' + __global__ void init_labels(@ARGS_DEF, const int cX, const int cY) { + @PRECALC + // Calculate index + const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x; + const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y; + @g_labels(iy*cX + ix) = iy*cX + ix; + } + + __device__ __inline__ unsigned int find_root(@ARGS_DEF, unsigned int label) { + // Resolve Label + unsigned int next = @g_labels(label); + + // Follow chain + while(label != next) { + // Move to next + label = next; + next = @g_labels(label); + } + + // Return label + return label; + } + + __global__ void resolve_labels(@ARGS_DEF, const int cX, const int cY) { + @PRECALC + // Calculate index + const unsigned int id = ((blockIdx.y * blockDim.y) + threadIdx.y) * cX + + ((blockIdx.x * blockDim.x) + threadIdx.x); + + // Check Thread Range + if(id < cX*cY) { + // Resolve Label + @g_labels(id) = find_root(@ARGS, @g_labels(id)); + } + } + + __global__ void label_equivalence(@ARGS_DEF, const int cX, const int cY) { + @PRECALC + // Calculate index + const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x; + const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y; + + // Check Thread Range + if((ix < cX) && (iy < cY)) { + // Get image and label values + const unsigned char cyx = @g_image( iy*cX + ix); + + // Get neighbour labels + const unsigned int lym1x = (iy > 0) ? @g_labels((iy-1)*cX + ix) : 0; + const unsigned int lyxm1 = (ix > 0) ? @g_labels(iy *cX + ix-1) : 0; + const unsigned int lyx = @g_labels(iy *cX + ix); + const unsigned int lyxp1 = (ix < cX-1) ? @g_labels(iy *cX + ix+1) : 0; + const unsigned int lyp1x = (iy < cY-1) ? @g_labels((iy+1)*cX + ix) : 0; + + const unsigned int lym1xm1 = (iy > 0 && ix > 0 ) ? @g_labels((iy-1)*cX + ix-1) : 0; + const unsigned int lym1xp1 = (iy > 0 && ix < cX-1) ? @g_labels((iy-1)*cX + ix+1) : 0; + const unsigned int lyp1xm1 = (iy < cY-1 && ix > 0 ) ? @g_labels((iy+1)*cX + ix-1) : 0; + const unsigned int lyp1xp1 = (iy < cY-1 && ix < cX-1) ? @g_labels((iy+1)*cX + ix+1) : 0; + + const bool nym1x = (iy > 0) ? (cyx == (@g_image((iy-1)*cX + ix))) : false; + const bool nyxm1 = (ix > 0) ? (cyx == (@g_image(iy *cX + ix-1))) : false; + const bool nyxp1 = (ix < cX-1) ? (cyx == (@g_image(iy *cX + ix+1))) : false; + const bool nyp1x = (iy > cY-1) ? (cyx == (@g_image((iy+1)*cX + ix))) : false; + + const bool nym1xm1 = (iy > 0 && ix > 0 ) ? (cyx == (@g_image((iy-1)*cX + ix-1))) : false; + const bool nym1xp1 = (iy > 0 && ix < cX-1) ? (cyx == (@g_image((iy-1)*cX + ix+1))) : false; + const bool nyp1xm1 = (iy < cY-1 && ix > 0 ) ? (cyx == (@g_image((iy+1)*cX + ix-1))) : false; + const bool nyp1xp1 = (iy < cY-1 && ix < cX-1) ? (cyx == (@g_image((iy+1)*cX + ix+1))) : false; + + // Lowest label + unsigned int label = lyx; + + // Find lowest neighbouring label + label = ((nym1x) && (lym1x < label)) ? lym1x : label; + label = ((nyxm1) && (lyxm1 < label)) ? lyxm1 : label; + label = ((nyxp1) && (lyxp1 < label)) ? lyxp1 : label; + label = ((nyp1x) && (lyp1x < label)) ? lyp1x : label; + + label = ((nym1xm1) && (lym1xm1 < label)) ? lym1xm1 : label; + label = ((nym1xp1) && (lym1xp1 < label)) ? lym1xp1 : label; + label = ((nyp1xm1) && (lyp1xm1 < label)) ? lyp1xm1 : label; + label = ((nyp1xp1) && (lyp1xp1 < label)) ? lyp1xp1 : label; + + // If labels are different, resolve them + if(label < lyx) { + // Update label + // Nonatomic write may overwrite another label but on average seems to give faster results + @g_labels(lyx) = label; + + // Record the change + @in1(0) = 1; + } + } + } + ''' + f''' + dim3 block(32, 32); + const int cX= {cX}; + const int cY= {cY};''' + ''' + dim3 grid(ceil(cX/(float)block.x), ceil(cY/(float)block.y)); + dim3 resolve_block(32, 32); + dim3 resolve_grid(ceil(cX/(float)resolve_block.x), ceil(cY/(float)resolve_block.y)); + + // Initialise labels + init_labels <<< grid, block >>>(@ARGS, cX, cY); + + // Resolve the labels + resolve_labels <<< resolve_grid, resolve_block >>>(@ARGS, cX, cY); + + // Changed Flag + int32 changed = 1; + + // While labels have changed + while(changed) { + // Copy changed to device + cudaMemsetAsync(in1_p, 0, 4); + + // Label image + label_equivalence <<< grid, block >>>(@ARGS, cX, cY); + + // Copy changed flag to host + cudaMemcpy(&changed, in1_p, sizeof(int32), cudaMemcpyDeviceToHost); + + // Resolve the labels + resolve_labels <<< resolve_grid, resolve_block>>>(@ARGS, cX, cY); + } + ''') + result = result.reshape((cY, cX)) * data_2d_copy + value = jt.unique(result) + value = value[value != 0] + + map_result = jt.zeros((int(value.max().numpy()[0]) + 1), dtype=jt.uint32) + map_result[value] = jt.index(value.shape)[0] + 1 + result = map_result[result] + + return result diff --git a/python/jittor/ccl/ccl_3d.py b/python/jittor/ccl/ccl_3d.py new file mode 100644 index 00000000..e0611d98 --- /dev/null +++ b/python/jittor/ccl/ccl_3d.py @@ -0,0 +1,196 @@ +import jittor as jt + + +def ccl_3d(data_3d): + ''' + 3D connected component labelling, original code from https://github.com/DanielPlayne/playne-equivalence-algorithm + Args: + [in]param data_3d: binary three-dimensional vector + type data_3d: jittor array + + Returns: + [out]result : labeled three-dimensional vector + + Example: + >>> import jittor as jt + >>> jt.flags.use_cuda = 1 + >>> data_3d = jt.zeros((10, 11, 12), dtype=jt.uint32) + >>> data_3d[2:4, :, :] = 1 + >>> data_3d[5:7, :, :] = 1 + >>> result = ccl_3d(data_3d) + >>> print(result[:, 0, 0]) + >>> print( + jt.unique(result, return_counts=True, return_inverse=True)[0], + jt.unique(result, return_counts=True, return_inverse=True)[2]) + ''' + + data_3d = data_3d.astype(jt.uint32) + cX = data_3d.shape[0] + cY = data_3d.shape[1] + cZ = data_3d.shape[2] + changed = jt.ones([1], dtype=jt.uint32) + data_3d_copy = data_3d.copy() + data_3d = data_3d.reshape(cX * cY * cZ) + result = jt.code(data_3d.shape, + data_3d.dtype, [data_3d, changed], + cuda_header=''' + @alias(g_image, in0) + @alias(g_labels, out) + ''', + cuda_src=r''' + __global__ void init_labels(@ARGS_DEF, const int cX, const int cY, const int cZ, const int pX, const int pY) { + @PRECALC + // Calculate index + const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x; + const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y; + const unsigned int iz = (blockIdx.z * blockDim.z) + threadIdx.z; + + if((ix < cX) && (iy < cY) && (iz < cZ)) { + const unsigned char pzyx = @g_image(iz*pY + iy*pX + ix); + + // Neighbour Connections + const bool nzm1yx = (iz > 0) ? (pzyx == @g_image((iz-1)*pY + iy *pX + ix )) : false; + const bool nzym1x = (iy > 0) ? (pzyx == @g_image( iz *pY + (iy-1)*pX + ix )) : false; + const bool nzyxm1 = (ix > 0) ? (pzyx == @g_image( iz *pY + iy *pX + ix-1)) : false; + + // Label + unsigned int label; + + // Initialise Label + label = (nzyxm1) ? ( iz*pY + iy*pX + ix-1) : (iz*pY + iy*pX + ix); + label = (nzym1x) ? ( iz*pY + (iy-1)*pX + ix) : label; + label = (nzm1yx) ? ((iz-1)*pY + iy*pX + ix) : label; + // Write to Global Memory + @g_labels(iz*pY + iy*pX + ix) = label; + } + } + + __device__ __inline__ unsigned int find_root(@ARGS_DEF, unsigned int label) { + // Resolve Label + unsigned int next = @g_labels(label); + + // Follow chain + while(label != next) { + // Move to next + label = next; + next = @g_labels(label); + } + + // Return label + return label; + } + + __global__ void resolve_labels(@ARGS_DEF, const int cX, const int cY, const int cZ, const int pX, const int pY) { + @PRECALC + // Calculate index + const unsigned int id = ((blockIdx.z * blockDim.z) + threadIdx.z) * pY + + ((blockIdx.y * blockDim.y) + threadIdx.y) * pX + + ((blockIdx.x * blockDim.x) + threadIdx.x); + + // Check Thread Range + if(id < cX*cY*cZ) { + // Resolve Label + @g_labels(id) = find_root(@ARGS, @g_labels(id)); + } + } + + __global__ void label_equivalence(@ARGS_DEF, const int cX, const int cY, const int cZ, const int pX, const int pY) { + @PRECALC + // Calculate index + const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x; + const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y; + const unsigned int iz = (blockIdx.z * blockDim.z) + threadIdx.z; + + // Check Thread Range + if((ix < cX) && (iy < cY) && (iz < cZ)) { + // Get image and label values + const unsigned char pzyx = @g_image(iz*pY + iy*pX + ix); + + // Neighbouring indexes + const unsigned int xm1 = ix-1; + const unsigned int xp1 = ix+1; + const unsigned int ym1 = iy-1; + const unsigned int yp1 = iy+1; + const unsigned int zm1 = iz-1; + const unsigned int zp1 = iz+1; + + // Get neighbour labels + const unsigned int lzm1yx = (iz > 0) ? @g_labels(zm1*pY + iy*pX + ix) : 0; + const unsigned int lzym1x = (iy > 0) ? @g_labels( iz*pY + ym1*pX + ix) : 0; + const unsigned int lzyxm1 = (ix > 0) ? @g_labels( iz*pY + iy*pX + xm1) : 0; + const unsigned int lzyx = @g_labels( iz*pY + iy*pX + ix); + const unsigned int lzyxp1 = (ix < cX-1) ? @g_labels( iz*pY + iy*pX + xp1) : 0; + const unsigned int lzyp1x = (iy < cY-1) ? @g_labels( iz*pY + yp1*pX + ix) : 0; + const unsigned int lzp1yx = (iz < cZ-1) ? @g_labels(zp1*pY + iy*pX + ix) : 0; + + const bool nzm1yx = (iz > 0) ? (pzyx == @g_image(zm1*pY + iy*pX + ix)) : false; + const bool nzym1x = (iy > 0) ? (pzyx == @g_image( iz*pY + ym1*pX + ix)) : false; + const bool nzyxm1 = (ix > 0) ? (pzyx == @g_image( iz*pY + iy*pX + xm1)) : false; + const bool nzyxp1 = (ix < cX-1) ? (pzyx == @g_image( iz*pY + iy*pX + xp1)) : false; + const bool nzyp1x = (iy < cY-1) ? (pzyx == @g_image( iz*pY + yp1*pX + ix)) : false; + const bool nzp1yx = (iz < cZ-1) ? (pzyx == @g_image(zp1*pY + iy*pX + ix)) : false; + + // Lowest label + unsigned int label = lzyx; + + // Find lowest neighbouring label + label = ((nzm1yx) && (lzm1yx < label)) ? lzm1yx : label; + label = ((nzym1x) && (lzym1x < label)) ? lzym1x : label; + label = ((nzyxm1) && (lzyxm1 < label)) ? lzyxm1 : label; + label = ((nzyxp1) && (lzyxp1 < label)) ? lzyxp1 : label; + label = ((nzyp1x) && (lzyp1x < label)) ? lzyp1x : label; + label = ((nzp1yx) && (lzp1yx < label)) ? lzp1yx : label; + + // If labels are different, resolve them + if(label < lzyx) { + // Update label + // Nonatomic write may overwrite another label but on average seems to give faster results + @g_labels(lzyx) = label; + + // Record the change + @in1(0) = 1; + } + } + } + ''' + f''' + dim3 block(32, 4, 4); + const int cX= {cX}; + const int cY= {cY}; + const int cZ= {cZ}; + const int pX= cX; + const int pY= cX*cY;''' + ''' + dim3 grid(ceil(cX/(float)block.x), ceil(cY/(float)block.y), ceil(cZ/(float)block.z)); + + // Initialise labels + init_labels <<< grid, block >>>(@ARGS, cX, cY, cZ, pX, pY); + + // Resolve the labels + resolve_labels <<< grid, block >>>(@ARGS, cX, cY, cZ, pX, pY); + + // Changed Flag + int32 changed = 1; + + // While labels have changed + while(changed) { + // Copy changed to device + cudaMemsetAsync(in1_p, 0, 4); + + // Label image + label_equivalence <<< grid, block >>>(@ARGS, cX, cY, cZ, pX, pY); + + // Copy changed flag to host + cudaMemcpy(&changed, in1_p, sizeof(int32), cudaMemcpyDeviceToHost); + + // Resolve the labels + resolve_labels <<< grid, block>>>(@ARGS, cX, cY, cZ, pX, pY); + } + ''') + result = result.reshape((cX, cY, cZ)) * data_3d_copy + value = jt.unique(result) + value = value[value != 0] + + map_result = jt.zeros((int(value.max().numpy()[0]) + 1), dtype=jt.uint32) + map_result[value] = jt.index(value.shape)[0] + 1 + result = map_result[result] + + return result diff --git a/python/jittor/ccl/ccl_link.py b/python/jittor/ccl/ccl_link.py new file mode 100644 index 00000000..0e18cbc8 --- /dev/null +++ b/python/jittor/ccl/ccl_link.py @@ -0,0 +1,195 @@ +import jittor as jt + + +def ccl_link(score_map, link_map, result_comp_area_thresh=6): + """ + Find components in score map and link them with link map, original code from https://github.com/DanielPlayne/playne-equivalence-algorithm. + Args: + [in]param score_map: binary two-dimensional vector + type score_map: jittor array + [in]param link_map: two-dimensional vector with 8 channels + type link_map: jittor array + [in]param result_comp_area_thresh: threshold of component area + type result_comp_area_thresh: int + Returns: + [out]result: labeled two-dimensional vector + Example: + >>> import jittor as jt + >>> jt.flags.use_cuda = 1 + >>> import cv2 + >>> import numpy as np + >>> score_map = jt.Var(np.load("score_map.npy")) + >>> link_map = jt.Var(np.load("link_map.npy")) + >>> score_map = score_map >= 0.5 + >>> link_map = link_map >= 0.8 + >>> for i in range(8): + >>> link_map[:, :, i] = link_map[:, :, i] & score_map + + >>> result = ccl_link(score_map, link_map) + >>> cv2.imwrite('pixellink.png', result.numpy().astype(np.uint8) * 50) + """ + score_map = score_map.astype(jt.uint32) + link_map = link_map.astype(jt.uint32) + cY = score_map.shape[0] + cX = score_map.shape[1] + changed = jt.ones([1], dtype=jt.uint32) + score_map = score_map.reshape(cX * cY) + result = jt.code(score_map.shape, + score_map.dtype, [score_map, link_map, changed], + cuda_header=''' + @alias(score_map, in0) + @alias(link_map, in1) + @alias(g_labels, out) + ''', + cuda_src=r''' + __global__ void init_labels(@ARGS_DEF, const int cX, const int cY) { + @PRECALC + // Calculate index + const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x; + const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y; + @g_labels(iy*cX + ix) = iy*cX + ix; + } + + __device__ __inline__ unsigned int find_root(@ARGS_DEF, unsigned int label) { + // Resolve Label + unsigned int next = @g_labels(label); + + // Follow chain + while(label != next) { + // Move to next + label = next; + next = @g_labels(label); + } + + // Return label + return label; + } + + __global__ void resolve_labels(@ARGS_DEF, const int cX, const int cY) { + @PRECALC + // Calculate index + const unsigned int id = ((blockIdx.y * blockDim.y) + threadIdx.y) * cX + + ((blockIdx.x * blockDim.x) + threadIdx.x); + + // Check Thread Range + if(id < cX*cY) { + // Resolve Label + @g_labels(id) = find_root(@ARGS, @g_labels(id)); + } + } + + __global__ void label_equivalence(@ARGS_DEF, const int cX, const int cY) { + @PRECALC + // Calculate index + const unsigned int ix = (blockIdx.x * blockDim.x) + threadIdx.x; + const unsigned int iy = (blockIdx.y * blockDim.y) + threadIdx.y; + + // Check Thread Range + if((ix < cX) && (iy < cY)) { + // Get image and label values + const unsigned char cyx = @score_map( iy*cX + ix); + + // Get neighbour labels + const unsigned int lym1x = (iy > 0) ? @g_labels((iy-1)*cX + ix) : 0; + const unsigned int lyxm1 = (ix > 0) ? @g_labels(iy *cX + ix-1) : 0; + const unsigned int lyx = @g_labels(iy *cX + ix); + const unsigned int lyxp1 = (ix < cX-1) ? @g_labels(iy *cX + ix+1) : 0; + const unsigned int lyp1x = (iy < cY-1) ? @g_labels((iy+1)*cX + ix) : 0; + + const unsigned int lym1xm1 = (iy > 0 && ix > 0 ) ? @g_labels((iy-1)*cX + ix-1) : 0; + const unsigned int lym1xp1 = (iy > 0 && ix < cX-1) ? @g_labels((iy-1)*cX + ix+1) : 0; + const unsigned int lyp1xm1 = (iy < cY-1 && ix > 0 ) ? @g_labels((iy+1)*cX + ix-1) : 0; + const unsigned int lyp1xp1 = (iy < cY-1 && ix < cX-1) ? @g_labels((iy+1)*cX + ix+1) : 0; + bool nym1x, nyxm1, nyxp1, nyp1x, nym1xm1, nym1xp1, nyp1xm1, nyp1xp1; + if(cyx) { + nym1x = (iy > 0) ? ((cyx == (@score_map((iy-1)*cX + ix))) && (@link_map(iy, ix, 6) || @link_map(iy-1, ix, 7))) : false; // up + nyxm1 = (ix > 0) ? ((cyx == (@score_map(iy *cX + ix-1))) && (@link_map(iy, ix, 0) || @link_map(iy-1, ix-1, 3))) : false; // left + nyxp1 = (ix < cX-1) ? ((cyx == (@score_map(iy *cX + ix+1))) && (@link_map(iy, ix, 3) || @link_map(iy, ix+1, 0))) : false; // right + nyp1x = (iy > cY-1) ? ((cyx == (@score_map((iy+1)*cX + ix))) && (@link_map(iy, ix, 7) || @link_map(iy+1, ix, 6))) : false; // down + + nym1xm1 = (iy > 0 && ix > 0 ) ? ((cyx == (@score_map((iy-1)*cX + ix-1))) && (@link_map(iy, ix, 2) || @link_map(iy-1, ix-1, 4))) : false; // up-left + nym1xp1 = (iy > 0 && ix < cX-1) ? ((cyx == (@score_map((iy-1)*cX + ix+1))) && (@link_map(iy, ix, 5) || @link_map(iy-1, ix+1, 1))) : false; // up-right + nyp1xm1 = (iy < cY-1 && ix > 0 ) ? ((cyx == (@score_map((iy+1)*cX + ix-1))) && (@link_map(iy, ix, 1) || @link_map(iy+1, ix-1, 5))) : false; // down-left + nyp1xp1 = (iy < cY-1 && ix < cX-1) ? ((cyx == (@score_map((iy+1)*cX + ix+1))) && (@link_map(iy, ix, 4) || @link_map(iy+1, ix+1, 2))) : false; // down-right + } + else { + nym1x = (iy > 0) ? (cyx == (@score_map((iy-1)*cX + ix))) : false; // up + nyxm1 = (ix > 0) ? (cyx == (@score_map(iy *cX + ix-1))) : false; // left + nyxp1 = (ix < cX-1) ? (cyx == (@score_map(iy *cX + ix+1))) : false; // right + nyp1x = (iy > cY-1) ? (cyx == (@score_map((iy+1)*cX + ix))) : false; // down + + nym1xm1 = (iy > 0 && ix > 0 ) ? (cyx == (@score_map((iy-1)*cX + ix-1))) : false; // up-left + nym1xp1 = (iy > 0 && ix < cX-1) ? (cyx == (@score_map((iy-1)*cX + ix+1))) : false; // up-right + nyp1xm1 = (iy < cY-1 && ix > 0 ) ? (cyx == (@score_map((iy+1)*cX + ix-1))) : false; // down-left + nyp1xp1 = (iy < cY-1 && ix < cX-1) ? (cyx == (@score_map((iy+1)*cX + ix+1))) : false; // down-right + } + + // Lowest label + unsigned int label = lyx; + + // Find lowest neighbouring label + label = ((nym1x) && (lym1x < label)) ? lym1x : label; + label = ((nyxm1) && (lyxm1 < label)) ? lyxm1 : label; + label = ((nyxp1) && (lyxp1 < label)) ? lyxp1 : label; + label = ((nyp1x) && (lyp1x < label)) ? lyp1x : label; + + label = ((nym1xm1) && (lym1xm1 < label)) ? lym1xm1 : label; + label = ((nym1xp1) && (lym1xp1 < label)) ? lym1xp1 : label; + label = ((nyp1xm1) && (lyp1xm1 < label)) ? lyp1xm1 : label; + label = ((nyp1xp1) && (lyp1xp1 < label)) ? lyp1xp1 : label; + + // If labels are different, resolve them + if(label < lyx) { + // Update label + // Nonatomic write may overwrite another label but on average seems to give faster results + @g_labels(lyx) = label; + + // Record the change + @in2(0) = 1; + } + } + } + ''' + f''' + dim3 block(32, 32); + const int cX= {cX}; + const int cY= {cY};''' + ''' + dim3 grid(ceil(cX/(float)block.x), ceil(cY/(float)block.y)); + dim3 resolve_block(32, 32); + dim3 resolve_grid(ceil(cX/(float)resolve_block.x), ceil(cY/(float)resolve_block.y)); + + // Initialise labels + init_labels <<< grid, block >>>(@ARGS, cX, cY); + + // Resolve the labels + resolve_labels <<< resolve_grid, resolve_block >>>(@ARGS, cX, cY); + + // Changed Flag + int32 changed = 1; + + // While labels have changed + while(changed) { + // Copy changed to device + cudaMemsetAsync(in2_p, 0, 4); + + // Label image + label_equivalence <<< grid, block >>>(@ARGS, cX, cY); + + // Copy changed flag to host + cudaMemcpy(&changed, in2_p, sizeof(int32), cudaMemcpyDeviceToHost); + + // Resolve the labels + resolve_labels <<< resolve_grid, resolve_block >>>(@ARGS, cX, cY); + } + ''') + + result = result.reshape((cY, cX)) + + value, _, cnt = jt.unique(result, return_inverse=True, return_counts=True) + value = (cnt > result_comp_area_thresh) * value + value = value[value != 0] + + map_result = jt.zeros((int(value.max().numpy()[0]) + 1), dtype=jt.uint32) + map_result[value] = jt.index(value.shape)[0] + 1 + result = map_result[result] + + return result diff --git a/python/jittor/compatibility/__init__.py b/python/jittor/compatibility/__init__.py new file mode 100644 index 00000000..94d2e40b --- /dev/null +++ b/python/jittor/compatibility/__init__.py @@ -0,0 +1,430 @@ +# import os +# os.environ["FIX_TORCH_ERROR"] = "0" + +# import jittor as jt +# from jittor import * +# from typing import Tuple + +# org_int = int = type(1) +# org_float = float = type(1.0) +# org_bool = bool = type(True) + +# import jtorch.compiler + +# import jtorch_core +# from jtorch_core import * + +# device.__reduce__ = lambda self: (device, (self.type,)) +# device.__module__ = "jtorch" +# jt.jittor_core.device = device + +# def handle_dtype(args, kw, dtype): +# def convert(x): +# if isinstance(x, jt.Var): +# return x.cast(dtype) +# return x +# if dtype is not None: +# if args is not None: +# if isinstance(args, (tuple,list)): +# args = [ convert(a) for a in args ] +# else: +# args = convert(x) +# if kw is not None: +# kw = { k:convert(v) for k,v in kw.items() } +# return args, kw + +# def get_args_names(func): +# import inspect +# spec = inspect.getfullargspec(func) +# return spec[0] + spec[4] + +# def wrapper(func): +# has_dtype = False +# if hasattr(func, "__code__"): +# has_dtype = "dtype" in get_args_names(func) +# def inner(*args, **kw): +# requires_grad = None +# dtype = None +# if "requires_grad" in kw: +# requires_grad = kw["requires_grad"] +# del kw["requires_grad"] +# if not has_dtype and "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# if "device" in kw: +# del kw["device"] +# if 'pin_memory' in kw: +# del kw['pin_memory'] +# args, kw = handle_dtype(args, kw, dtype) +# ret = func(*args, **kw) +# if isinstance(ret, jt.Var): +# if requires_grad is not None: +# ret.requires_grad = requires_grad +# if dtype is not None: +# ret.astype(dtype) +# return ret +# return inner + + +# import inspect +# _wrapper_keys = set(["shape", "start", "size"]) +# _wrapper_keys.add("x") +# for k,v in list(globals().items()): +# if callable(v) and not isinstance(v, type): +# try: +# spec = inspect.getfullargspec(v) +# args_name = spec[0] +# if len(args_name) and args_name[0] in _wrapper_keys: +# globals()[k] = wrapper(v) +# elif spec.varargs in _wrapper_keys: +# globals()[k] = wrapper(v) +# except: +# pass + +# def empty(*size, dtype=jt.float32, device=None, requires_grad=False): +# if len(size) == 1 and not isinstance(size[0], org_int): +# size = size[0] +# return jt.empty(size, dtype) + +# Tensor = Var + +# Tensor.backward = lambda x: jtorch_core.backward(x) +# Tensor.grad = property(grad_get, grad_set, grad_del) +# Tensor.retains_grad = property(retain_grad_get, retain_grad_set) +# def retain_grad(x:Tensor, value:bool=True): +# x.retains_grad = value +# return value +# Tensor.retain_grad = retain_grad + +# Tensor.dim = lambda self: self.ndim +# Tensor.ndimension = lambda self: self.ndim +# Tensor.nelement = lambda self: self.numel() +# Tensor.cuda = lambda self: self +# def device_get(x:Tensor): +# return device("cpu") if not jt.has_cuda or not jt.flags.use_cuda else device("cuda") +# Tensor.device = property(device_get) + +# def argmax(x: Var, dim=None, keepdim: bool = False): +# return jt.argmax(x, dim, keepdim)[0] +# Tensor.argmax = argmax + +# def tensor_type(x: Var, dtype=None, **kwargs): +# if dtype: +# return x.astype(dtype) +# else: +# return x.dtype +# Tensor.type = tensor_type + +# def is_floating_point(x: Var): +# return "float" in str(x.dtype) +# Tensor.is_floating_point = is_floating_point + +# from . import autograd +# from .autograd import * + +# def tensor(data, *, dtype=None, device=None, requires_grad=False, pin_memory=False): +# if isinstance(data,list): +# data_list = [] +# check = True +# for p in data: +# if isinstance(p, Tensor) and p.numel()==1: +# data_list.append(p.item()) +# elif isinstance(p, (org_int,org_float)): +# data_list.append(p) +# else: +# check = False +# break +# if check: +# data = data_list +# return wrapper(array)(data, dtype=dtype, device=device, requires_grad=requires_grad, pin_memory=pin_memory) + +# # tensor = wrapper(array) +# from_numpy = wrapper(array) +# strided = None + +# def mod_zero_grad(self): +# for p in self.parameters(): +# p.grad = None +# Module.zero_grad = mod_zero_grad + +# class ModuleMisc: +# def parameters(self): +# return iter(super().parameters()) + +# def load_state_dict(self, state_dict, strict=False): +# return super().load_state_dict(state_dict) + +# def to(self, device=None,dtype=None): +# ''' do nothing but return its self''' +# return self +# def register_parameter(self,name,data): +# self.name = data + +# def buffers(self): +# for _, buf in self.named_buffers(): +# yield buf + + +# def make_module(cls): +# class TMod(ModuleMisc, cls): +# def __init__(self, *args, **kw): +# dtype = None +# if "dtype" in kw: +# dtype = kw["dtype"] +# del kw["dtype"] +# self._dtype = dtype +# with jt.flag_scope(th_mode=0): +# if "device" in kw: +# del kw["device"] +# super().__init__(*args, **kw) +# for k,v in self.__dict__.items(): +# if not k.startswith("_") and isinstance(v, Var) \ +# and v.requires_grad: +# v.retain_grad() +# if dtype is not None and isinstance(v, Var): +# v.assign(v.cast(dtype)) +# def __call__(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# # if forward is override by user, call forward +# if self.__class__.forward is not TMod.forward: +# return self.forward(*args, **kw) +# return self.execute(*args, **kw) +# def forward(self, *args, **kw): +# args, kw = handle_dtype(args, kw, self._dtype) +# return self.execute(*args, **kw) + +# @property +# def training(self): +# if not hasattr(self, "is_train"): +# self.is_train = True +# return self.is_train +# @training.setter +# def training(self, value): +# self.is_train = value + +# TMod.__name__ = cls.__name__ +# return TMod + +# import jtorch.cuda +# import jtorch.nn +# from jtorch.nn import Module, Parameter +# import jtorch.optim + +# from jtorch.utils.dtype import Dtype, get_string_dtype + +# def frombuffer(buffer: bytearray, +# *, +# dtype: Dtype, +# count: int = -1, +# offset: int = 0, +# requires_grad: bool = True) -> Tensor: +# dtype = get_string_dtype(dtype) +# tensor = jt.array(np.frombuffer(buffer, dtype, count=count, offset=offset)) +# if requires_grad and tensor.dtype.is_float(): +# tensor.requires_grad = True +# return tensor + +# def conflict_wrapper(origin_func, new_func): +# def wrapper(*args, **kw): +# if jt.flags.th_mode: +# return new_func(*args, **kw) +# else: +# return origin_func(*args, **kw) +# return wrapper + +# def min(*args, **kw): +# dim = None +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmin(*args, **kw) +# return v, k +# elif len(args) == 2 and isinstance(args[1], jt.Var): +# return jt.minimum(args[0], args[1]) +# else: +# return jt.min(*args, **kw) +# Tensor.min = conflict_wrapper(jt.min, min) + +# def max(*args, **kw): +# dim = None +# if "dim" in kw: +# x = kw["dim"] +# if len(args) >= 2 and isinstance(args[1], org_int): +# dim = args[1] +# elif "dim" in kw and isinstance(kw["dim"], org_int): +# dim = kw["dim"] +# if dim is not None: +# k, v = jt.argmax(*args, **kw) +# return v, k +# elif len(args) == 2 and isinstance(args[1], jt.Var): +# return jt.maximum(args[0], args[1]) +# else: +# return jt.max(*args, **kw) +# Tensor.max = conflict_wrapper(jt.max, max) + +# def argsort(*args, **kw): +# k, v = jt.argsort(*args, **kw) +# return k +# Tensor.argsort = conflict_wrapper(jt.argsort, argsort) + +# LongTensor = jt.int64 +# FloatTensor = jt.float +# HalfTensor = jt.float16 +# BoolTensor = jt.bool +# IntTensor = jt.int32 + +# class JDType: +# def __init__(self, func, str): +# self.func = func +# self.str = str +# self.__name__ = str.split(".")[-1] +# def __call__(self, *args, **kw): +# return self.func(*args, **kw) +# def __str__(self): +# return self.str +# def is_floating_point(self): +# return "float" in str(self.str) + +# int8 = JDType(jt.int8, "torch.int8") +# int16 = JDType(jt.int16, "torch.int16") +# int = int32 = JDType(jt.int32, "torch.int32") +# long = int64 = JDType(jt.int64, "torch.int64") + +# half = float16 = JDType(jt.float16, "torch.float16") +# float = float32 = JDType(jt.float32, "torch.float32") +# double = float64 = JDType(jt.float64, "torch.float64") +# bfloat16 = "bfloat16" # TODO +# complex64 = "complex64" # TODO +# complex128 = "complex128" # TODO +# def get_JDtype(dtype): +# if dtype=='float32' or dtype == jt.float32: +# return float32 +# elif dtype=='float64' or dtype == jt.float64: +# return float64 +# elif dtype=='float16' or dtype == jt.float16: +# return float16 +# elif dtype=='int32' or dtype == jt.int32: +# return int32 +# elif dtype=='int64' or dtype == jt.int64: +# return int64 +# elif dtype=='int16' or dtype == jt.int16: +# return int16 +# elif dtype=='int8' or dtype == jt.int8: +# return int8 +# else: +# raise Exception("dtype {} not supported".format(dtype)) + +# def load(path,**kwargs): +# def _to_jittor(data): +# if isinstance(data,dict): +# return {k:_to_jittor(d) for k,d in data.items()} +# if isinstance(data,list): +# return [_to_jittor(d) for d in data] +# if isinstance(data,np.ndarray): +# return jt.array(data) +# return data +# data = jt.load(path) + +# return _to_jittor(data) + +# def is_tensor(x): +# return isinstance(x, Tensor) + +# manual_seed = jt.set_global_seed +# jt.flags.amp_level = 3 +# Size = jt.NanoVector + +# class Generator: +# def __init__(self,*args,**kw) -> None: +# self.seed = None +# def manual_seed(self,seed): +# self.seed = seed + + + +# from . import fx + + +# _default_type = "float32" + +# def get_default_dtype(): +# return _default_type +# def set_default_dtype(dtype): +# global _default_type +# _default_type = dtype + +# dtype = JDType + +# def div(x,y,rounding_mode="floor"): +# assert rounding_mode == "floor" +# z = (x / y) +# if rounding_mode == "floor": +# z = z.floor() +# if x.dtype == "int32" and (isinstance(y,org_int) or y.dtype == "int32"): +# z = z.int32() +# return z + + +# def randn(*args,**kw): +# wrap_randn = wrapper(jt.randn) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_randn(*args,**kw) + +# def rand(*args,**kw): +# print("rand") +# wrap_rand = wrapper(jt.rand) +# generator = kw.get('generator',None) +# kw.pop('generator',None) +# if 'layout' in kw: +# del kw['layout'] +# if generator is not None and generator.seed is not None: +# jt.set_global_seed(generator.seed) +# return wrap_rand(*args,**kw) + + + +# def set_default_tensor_type(t: type or str): +# if isinstance(t, str): +# info = t.split(".") +# if len(info) == 3 and info[1] == 'cuda': +# jt.flags.use_cuda = 1 +# #TODO: type + + +# def clamp(x, min=None, max=None): +# return jt.clamp(x, min, max) + + +# def to(x,*args,**kw): +# device = None +# if len(args) == 1: +# device = args[0] +# if isinstance(device, jt.NanoString) or callable(device): +# return jt.to(x,*args,**kw) +# if 'cpu' in str(device): +# args = [] +# device = kw.get("device",None) +# if 'cpu' in str(device): +# kw.pop('device',None) +# print("to cpu") +# # print(kw) +# return jt.to(x,*args,**kw) +# Tensor.to = conflict_wrapper(jt.to, to) + +# mm = wrapper(jt.matmul) + +# def _data_get(x): +# return x + +# def _data_set(x, value): +# x.assign(value) + +# Tensor.data = property(_data_get, _data_set) +# Tensor.layout = None \ No newline at end of file diff --git a/python/jittor/compatibility/autograd.py b/python/jittor/compatibility/autograd.py new file mode 100644 index 00000000..5ed88dde --- /dev/null +++ b/python/jittor/compatibility/autograd.py @@ -0,0 +1,134 @@ +import jittor as jt +from jittor import Var +from collections.abc import Sequence, Mapping + +Variable = Var + +class FunctionContext: + def save_for_backward(self, *args): + self.saved_tensors = args + +class Function: + ''' Function Module for customized backward operations + +Example 1 (Function can have multiple input and multiple output, and user +can store value for backward computation):: + + import jtorch + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x + + a = jtorch.array(3.0) + a.requires_grad = True + b = jtorch.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + (c+d*3).backward() + assert a.grad.data == 4 + assert b.grad.data == 9 + +Example 2(Function can return None for no gradiant, and gradiant +can also be None):: + + import jtorch + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + assert grad1 is None + return grad0 * self.y, None + a = jt.array(3.0) + a.requires_grad = True + b = jt.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + d.stop_grad() + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 0 + + ''' + def __call__(self, *args): + backup = args + args = list(args) + taped_inputs = [] + taped_outputs = [] + input_mask = [-1] * len(args) + for i,v in enumerate(args): + if isinstance(v, Var): + if v.is_stop_grad(): + # -2 in input_mask represents it is stop_grad + input_mask[i] = -2 + continue + v = v.tape() + input_mask[i] = len(taped_inputs) + args[i] = v + taped_inputs.append(v) + ctx = FunctionContext() + ori_res = self.forward(ctx, *args) + # ori_res = self.execute(*args) + if not isinstance(ori_res, Sequence): + res = [ori_res] + else: + res = list(ori_res) + output_mask = [-1] * len(res) + for i,v in enumerate(res): + if isinstance(v, Var): + v = v.tape() + output_mask[i] = len(taped_outputs) + res[i] = v + taped_outputs.append(v) + ctx.input_mask = input_mask + ctx.output_mask = output_mask + # tape output and input together so + # backward treat them as one operator + jt.tape_together(taped_inputs, taped_outputs, + lambda *args: self._grad(ctx, self, *args)) + if isinstance(ori_res, Sequence): + return res + else: + return res[0] + + @staticmethod + def _grad(ctx, func, *args): + new_args = ( (args[i] if i>=0 else None) for i in ctx.output_mask ) + ret = func.backward(ctx, *new_args) + if not isinstance(ret, Sequence): + ret = (ret,) + new_ret = [] + for i, r in enumerate(ret): + j = ctx.input_mask[i] + if j<0: + # -2 in input_mask represents it is stop_grad + assert r is None or j==-2, f"{type(self)}'s {i}-th returned grad should be None, "\ + "because the input value is not jittor variable." + else: + new_ret.append(r) + return new_ret + + def dfs(self, parents, k, callback, callback_leave=None): + pass + + @classmethod + def apply(cls, *args, **kw): + func = cls() + return func(*args, **kw) diff --git a/python/jittor/compatibility/compiler.py b/python/jittor/compatibility/compiler.py new file mode 100644 index 00000000..77bab138 --- /dev/null +++ b/python/jittor/compatibility/compiler.py @@ -0,0 +1,39 @@ +import jittor as jt +import jittor_utils +import glob +import os +from jittor import pyjt_compiler +import sys +from jittor_utils import lock + + +jtorch_path = os.path.dirname(__file__) +cache_path = os.path.join(jt.compiler.cache_path, "jtorch") +# os.makedirs(cache_path, exist_ok=True) +os.makedirs(os.path.join(cache_path, "gen"), exist_ok=True) + +with lock.lock_scope(): + pyjt_gen_src = pyjt_compiler.compile(cache_path, jtorch_path) + +ext_args = 'c[cu]' if jt.has_cuda else 'cc' +files = glob.glob(jtorch_path+"/src/**/*."+ext_args, recursive=True) +files += pyjt_gen_src +cc_flags = " -I\""+os.path.join(jtorch_path, "src")+"\" " +if os.environ.get("use_data_o", "1") == "1": + files += glob.glob(jtorch_path+"/src/**/*.o", recursive=True) + files = [f for f in files if "__data__" not in f] + + +with lock.lock_scope(): + jt.compiler.compile( + jt.compiler.cc_path, + jt.compiler.cc_flags+jt.compiler.opt_flags+ cc_flags, + files, + "jtorch_core"+jt.compiler.extension_suffix, + obj_dirname="jtorch_objs") + + +with jittor_utils.import_scope(jt.compiler.import_flags): + import jtorch_core as core + +jt.flags.th_mode = 1 diff --git a/python/jittor/compatibility/cuda.py b/python/jittor/compatibility/cuda.py new file mode 100644 index 00000000..75665c7c --- /dev/null +++ b/python/jittor/compatibility/cuda.py @@ -0,0 +1,64 @@ +import jittor as jt +import jtorch + +def is_available(): + return jt.has_cuda + +def device_count(): + return int(jt.has_cuda) + +def set_device(device=None): + pass + +def get_rng_state(device=None): + pass + +def current_device(): + return jtorch.device("cuda") + +def mem_get_info(i): + return ("75GB",) + + +class Generator: + def __init__(self): + pass + + def set_state(self, state): + self.state = state + +default_generators = [Generator()] +_lazy_call = lambda func: func() +device = None + +LongTensor = jt.int64 +FloatTensor = jt.float +HalfTensor = jt.float16 +BoolTensor = jt.bool + +manual_seed = jt.set_global_seed +manual_seed_all = jt.set_global_seed + +def synchronize(): + jt.sync_all(True) + +class Event: + pass + +class Stream: + pass + +from typing import Any + +from .gradscaler import GradScaler + +class autocast: + def __init__(self,**kwargs): + pass + + def __enter__(self,): + pass + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): + pass + diff --git a/python/jittor/compatibility/distributed.py b/python/jittor/compatibility/distributed.py new file mode 100644 index 00000000..e39f559a --- /dev/null +++ b/python/jittor/compatibility/distributed.py @@ -0,0 +1,53 @@ +import datetime +from enum import Enum +import jittor as jt + + +class DistributedDataParallel: + def __new__(cls, model): + return model + +def is_initialized(): + return True + +def get_rank(group=None): + return 0 + +def get_world_size(group=None): + return 1 + +def get_backend(group=None): + return "nccl" + +def new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None): + return 1 + +def barrier(): + pass + +def is_available(): + return True + +def is_built(): + return True + +class ReduceOp: + SUM = 0 + +class GroupMember: + WORLD = 0 + +class ProcessGroup: + pass + +class Join: + pass + +dist_backend = Enum("dist_backend", ("GLOO", "MPI", "NCCL")) +_backend = dist_backend.NCCL + +def is_mpi_available(): + return jt.in_mpi + +def DistributedDataParallel(model, *args, **kw): + return model diff --git a/python/jittor/compatibility/distributions.py b/python/jittor/compatibility/distributions.py new file mode 100644 index 00000000..a98dfe29 --- /dev/null +++ b/python/jittor/compatibility/distributions.py @@ -0,0 +1,15 @@ +import jittor as jt + +class RelaxedBernoulli: + def __init__(self, temperature, probs=None, logits=None): + self.temperature = temperature + self.probs = probs + self.logits = logits + + def rsample(self): + noise = jt.rand_like(self.logits) + eps = 1e-20 + noise = jt.clamp(noise, eps, 1.0 - eps) + logit_noise = jt.log(noise) - jt.log(1 - noise) + sample = (self.logits + logit_noise) / self.temperature + return jt.sigmoid(sample) diff --git a/python/jittor/compatibility/fft/__init__.py b/python/jittor/compatibility/fft/__init__.py new file mode 100644 index 00000000..7a89fc9c --- /dev/null +++ b/python/jittor/compatibility/fft/__init__.py @@ -0,0 +1,5 @@ +#TODO: Implement FFT and IFFT +fftn = None +fftshift = None +ifftn = None +ifftshift = None \ No newline at end of file diff --git a/python/jittor/compatibility/fx.py b/python/jittor/compatibility/fx.py new file mode 100644 index 00000000..0f0eb4f8 --- /dev/null +++ b/python/jittor/compatibility/fx.py @@ -0,0 +1,2 @@ +class Proxy: + pass \ No newline at end of file diff --git a/python/jittor/compatibility/gradscaler.py b/python/jittor/compatibility/gradscaler.py new file mode 100644 index 00000000..087d6bb2 --- /dev/null +++ b/python/jittor/compatibility/gradscaler.py @@ -0,0 +1,519 @@ +from collections import defaultdict, abc +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, cast +import inspect +import warnings + +import jittor as jt +# import torch + +def _refresh_per_optimizer_state(): + return {} + + +class GradScaler: + _scale: Optional[jt.Var] + _grows_tracker: Optional[jt.Var] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + """ + def __init__(self, + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True): + self._enabled = enabled + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = self._init_scale + self._growth_tracker = self._init_growth_tracker + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + + + # Short-circuit for the common case. + if isinstance(outputs, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return outputs * self._scale + + def apply_scale(val): + if isinstance(val, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return val * self._scale + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + with jt.no_grad(): + optimizer.pre_step() + for group in optimizer.param_groups: + for to_unscale in group["grads"]: + if to_unscale is None or isinstance(to_unscale,(int,float)): + continue + if (not allow_fp16) and str(to_unscale.dtype) == "float16": + raise ValueError("Attempting to unscale FP16 gradients.") + + if not (to_unscale.isinf().any()): + if inv_scale != 1.0: + to_unscale.update(to_unscale*inv_scale) + else: + found_inf = 1.0 + + return found_inf + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if hasattr(optimizer,"get_find_inf"): + return + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = 1.0 / self._scale + found_inf = 0.0 + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument + # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` + # and `found_inf` to the passed optimizer so that the optimizer can utilize those + # to skip the parameter updates or unscale gradients before updating parameters in + # the fused kernel, e.g. `FusedAdamMathFunctor`. + # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, + # while the method is expected to be called by users side, i.e. their optimizers. + kwargs_ = kwargs + has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters + if has_grad_scaler_kwarg: + warnings.warn( + "GradScaler is going to stop passing itself as a keyword argument to the passed " + "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " + "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", + FutureWarning) + kwargs_.update({"grad_scaler": self}) + else: + if optimizer_state["stage"] is OptState.READY: + self._check_inf_per_device(optimizer) + scaler = self._get_scale_async() + found_inf = cast( + jt.Var, + sum([ + t for t in optimizer_state["found_inf_per_device"].values() + ]) + ) + optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler + optimizer.found_inf = found_inf + retval = optimizer.step(*args, **kwargs_) + optimizer_state["stage"] = OptState.STEPPED + if not has_grad_scaler_kwarg: + del optimizer.grad_scale + del optimizer.found_inf + return retval + + if hasattr(optimizer,"get_find_inf"): + optimizer.set_grad_scale(self._scale) + optimizer.step() + optimizer_state["found_inf_per_device"] = optimizer.get_find_inf() + return + + retval = None + if not optimizer_state["found_inf_per_device"]: + retval = optimizer.step(*args, **kwargs) + else: + optimizer.post_step() + + return retval + + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [state["found_inf_per_device"] + for state in self._per_optimizer_states.values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + + current_scale = _scale + if found_inf_combined: + current_scale *=self._backoff_factor + _growth_tracker = 0 + else: + successful = _growth_tracker+1 + if successful == self._growth_interval: + new_scale = current_scale*self._growth_factor + if new_scale < 1e9: + current_scale = new_scale + _growth_tracker = 0 + else: + _growth_tracker = successful + + self._scale, self._growth_tracker = current_scale,_growth_tracker + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return {"scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = 1.0 + found_inf = 0.0 + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/gradscaler_old.py b/python/jittor/compatibility/gradscaler_old.py new file mode 100644 index 00000000..389be2cf --- /dev/null +++ b/python/jittor/compatibility/gradscaler_old.py @@ -0,0 +1,556 @@ +from collections import defaultdict, abc +from enum import Enum +from typing import Any, Dict, List, Optional, Tuple, cast +import inspect +import warnings + +import jittor as jt +# import torch + + +__all__ = ["OptState", "GradScaler"] + + +# Defines default_factory for GradScaler's _per_optimizer_states defaultdict, +# as well as associated "enum" values. Prefers defining these at top level because +# - Lambdas can't be pickled, so we don't want to supply a lambda as the factory. +# - Defining READY, UNSCALED, STEPPED and _refresh_per_optimizer_state within GradScaler +# causes a circular reference, which we'd rather avoid. +class OptState(Enum): + READY = 0 + UNSCALED = 1 + STEPPED = 2 + + +def _refresh_per_optimizer_state(): + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +class GradScaler: + _scale: Optional[jt.Var] + _grows_tracker: Optional[jt.Var] + _per_optimizer_states: Dict[int, Dict[str, Any]] + """ + An instance ``scaler`` of :class:`GradScaler` helps perform the steps of gradient scaling + conveniently. + + * ``scaler.scale(loss)`` multiplies a given loss by ``scaler``'s current scale factor. + * ``scaler.step(optimizer)`` safely unscales gradients and calls ``optimizer.step()``. + * ``scaler.update()`` updates ``scaler``'s scale factor. + + Example:: + + # Creates a GradScaler once at the beginning of training. + scaler = GradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See the :ref:`Automatic Mixed Precision examples` for usage + (along with autocasting) in more complex cases like gradient clipping, gradient accumulation, gradient penalty, + and multiple losses/optimizers. + + ``scaler`` dynamically estimates the scale factor each iteration. To minimize gradient underflow, + a large scale factor should be used. However, ``float16`` values can "overflow" (become inf or NaN) if + the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used + without incurring inf or NaN gradient values. + ``scaler`` approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every + ``scaler.step(optimizer)`` (or optional separate ``scaler.unscale_(optimizer)``, see :meth:`unscale_`). + + * If infs/NaNs are found, ``scaler.step(optimizer)`` skips the underlying ``optimizer.step()`` (so the params + themselves remain uncorrupted) and ``update()`` multiplies the scale by ``backoff_factor``. + + * If no infs/NaNs are found, ``scaler.step(optimizer)`` runs the underlying ``optimizer.step()`` as usual. + If ``growth_interval`` unskipped iterations occur consecutively, ``update()`` multiplies the scale by + ``growth_factor``. + + The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its + value calibrates. ``scaler.step`` will skip the underlying ``optimizer.step()`` for these + iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations). + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + """ + def __init__(self, + init_scale=2.**16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True): + self._enabled = enabled + + if self._enabled: + assert growth_factor > 1.0, "The growth factor must be > 1.0." + assert backoff_factor < 1.0, "The backoff factor must be < 1.0." + + self._init_scale = init_scale + # self._scale will be lazily initialized during the first call to scale() + self._scale = None + self._growth_factor = growth_factor + self._backoff_factor = backoff_factor + self._growth_interval = growth_interval + self._init_growth_tracker = 0 + # self._growth_tracker will be lazily initialized during the first call to scale() + self._growth_tracker = None + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _check_scale_growth_tracker(self, funcname) -> Tuple[jt.Var, jt.Var]: + fix = "This may indicate your script did not use scaler.scale(loss or outputs) earlier in the iteration." + assert self._scale is not None, "Attempted {} but _scale is None. ".format(funcname) + fix + assert self._growth_tracker is not None, "Attempted {} but _growth_tracker is None. ".format(funcname) + fix + return (self._scale, self._growth_tracker) + + def _lazy_init_scale_growth_tracker(self): + assert self._growth_tracker is None, "_growth_tracker initialized before _scale" + self._scale = self._init_scale + self._growth_tracker = self._init_growth_tracker + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + + Args: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + print("scale") + if not self._enabled: + return outputs + + + # Short-circuit for the common case. + if isinstance(outputs, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return outputs * self._scale + + def apply_scale(val): + if isinstance(val, jt.Var): + assert jt.flags.use_cuda == 1 + if self._scale is None: + self._lazy_init_scale_growth_tracker() + assert self._scale is not None + return val * self._scale + elif isinstance(val, abc.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be hundreds of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + with jt.no_grad(): + optimizer.pre_step() + for group in optimizer.param_groups: + for to_unscale in group["grads"]: + if to_unscale is None or isinstance(to_unscale,(int,float)): + continue + if (not allow_fp16) and str(to_unscale.dtype) == "float16": + raise ValueError("Attempting to unscale FP16 gradients.") + + if not (to_unscale.isinf().any()): + if inv_scale != 1.0: + to_unscale.update(to_unscale*inv_scale) + else: + found_inf = 1.0 + + return found_inf + + def unscale_(self, optimizer): + """ + Divides ("unscales") the optimizer's gradient tensors by the scale factor. + + :meth:`unscale_` is optional, serving cases where you need to + :ref:`modify or inspect gradients` + between the backward pass(es) and :meth:`step`. + If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`. + + Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients:: + + ... + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) + scaler.step(optimizer) + scaler.update() + + Args: + optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled. + + .. note:: + :meth:`unscale_` does not incur a CPU-GPU sync. + + .. warning:: + :meth:`unscale_` should only be called once per optimizer per :meth:`step` call, + and only after all gradients for that optimizer's assigned parameters have been accumulated. + Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError. + + .. warning:: + :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute. + """ + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError("unscale_() has already been called on this optimizer since the last update().") + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = 1.0 / self._scale + found_inf = 0.0 + optimizer_state["found_inf_per_device"] = self._unscale_grads_(optimizer, inv_scale, found_inf, False) + optimizer_state["stage"] = OptState.UNSCALED + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + retval = None + if not optimizer_state["found_inf_per_device"]: + retval = optimizer.step(*args, **kwargs) + else: + optimizer.post_step() + + return retval + + def step(self, optimizer, *args, **kwargs): + """ + :meth:`step` carries out the following two operations: + + 1. Internally invokes ``unscale_(optimizer)`` (unless :meth:`unscale_` was explicitly called for ``optimizer`` + earlier in the iteration). As part of the :meth:`unscale_`, gradients are checked for infs/NaNs. + 2. If no inf/NaN gradients are found, invokes ``optimizer.step()`` using the unscaled + gradients. Otherwise, ``optimizer.step()`` is skipped to avoid corrupting the params. + + ``*args`` and ``**kwargs`` are forwarded to ``optimizer.step()``. + + Returns the return value of ``optimizer.step(*args, **kwargs)``. + + Args: + optimizer (torch.optim.Optimizer): Optimizer that applies the gradients. + args: Any arguments. + kwargs: Any keyword arguments. + + .. warning:: + Closure use is not currently supported. + """ + if (not self._enabled): + return optimizer.step(*args, **kwargs) + + if "closure" in kwargs: + raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.") + + self._check_scale_growth_tracker("step") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("step() has already been called since the last update().") + + retval = None + + if (hasattr(optimizer, "_step_supports_amp_scaling") and optimizer._step_supports_amp_scaling): + # This optimizer has customized scale-handling logic, so we can call optimizer.step() directly. + # The contract with custom optimizers is that their step() should accept an additional, + # optional grad_scaler kwarg. We append self to the kwargs so the custom optimizer has full information: + # it can query its own state, invoke unscale_ on itself, etc + # The contract above is being deprecated to avoid introducing `grad_scaler: GradScaler` argument + # to `Optimizer.step`. The new behavior is going to add two Tensor attributes of `grad_scale` + # and `found_inf` to the passed optimizer so that the optimizer can utilize those + # to skip the parameter updates or unscale gradients before updating parameters in + # the fused kernel, e.g. `FusedAdamMathFunctor`. + # In this behavior, `GradScaler._check_inf_per_device` is called if `OptState.READY`, + # while the method is expected to be called by users side, i.e. their optimizers. + kwargs_ = kwargs + has_grad_scaler_kwarg = "grad_scaler" in inspect.signature(optimizer.step).parameters + if has_grad_scaler_kwarg: + warnings.warn( + "GradScaler is going to stop passing itself as a keyword argument to the passed " + "optimizer. In the near future GradScaler registers `grad_scale: Tensor` and " + "`found_inf: Tensor` to the passed optimizer and let the optimizer use them directly.", + FutureWarning) + kwargs_.update({"grad_scaler": self}) + else: + if optimizer_state["stage"] is OptState.READY: + self._check_inf_per_device(optimizer) + scaler = self._get_scale_async() + found_inf = cast( + jt.Var, + sum([ + t for t in optimizer_state["found_inf_per_device"].values() + ]) + ) + optimizer.grad_scale = None if optimizer_state["stage"] == OptState.UNSCALED else scaler + optimizer.found_inf = found_inf + retval = optimizer.step(*args, **kwargs_) + optimizer_state["stage"] = OptState.STEPPED + if not has_grad_scaler_kwarg: + del optimizer.grad_scale + del optimizer.found_inf + return retval + + + if optimizer_state["stage"] is OptState.READY: + self.unscale_(optimizer) + + assert "found_inf_per_device" in optimizer_state, "No inf checks were recorded for this optimizer." + + retval = self._maybe_opt_step(optimizer, optimizer_state, *args, **kwargs) + + optimizer_state["stage"] = OptState.STEPPED + + return retval + + def update(self, new_scale=None): + """ + Updates the scale factor. + + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + + Args: + new_scale (float or :class:`torch.cuda.FloatTensor`, optional, default=None): New scale factor. + + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = "new_scale should be a float or a 1-element torch.cuda.FloatTensor with requires_grad=False." + assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [state["found_inf_per_device"] + for state in self._per_optimizer_states.values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + + current_scale = _scale + if found_inf_combined: + current_scale *=self._backoff_factor + _growth_tracker = 0 + else: + successful = _growth_tracker+1 + if successful == self._growth_interval: + new_scale = current_scale*self._growth_factor + if new_scale < 1e9: + current_scale = new_scale + _growth_tracker = 0 + else: + _growth_tracker = successful + + self._scale, self._growth_tracker = current_scale,_growth_tracker + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + def _get_scale_async(self): + return self._scale + + def get_scale(self): + """ + Returns a Python float containing the current scale, or 1.0 if scaling is disabled. + + .. warning:: + :meth:`get_scale` incurs a CPU-GPU sync. + """ + if self._enabled: + return self._init_scale if self._scale is None else self._get_scale_async() + else: + return 1.0 + + def get_growth_factor(self): + r""" + Returns a Python float containing the scale growth factor. + """ + return self._growth_factor + + def set_growth_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale growth factor. + """ + self._growth_factor = new_factor + + def get_backoff_factor(self): + r""" + Returns a Python float containing the scale backoff factor. + """ + return self._backoff_factor + + def set_backoff_factor(self, new_factor): + r""" + Args: + new_scale (float): Value to use as the new scale backoff factor. + """ + self._backoff_factor = new_factor + + def get_growth_interval(self): + r""" + Returns a Python int containing the growth interval. + """ + return self._growth_interval + + def set_growth_interval(self, new_interval): + r""" + Args: + new_interval (int): Value to use as the new growth interval. + """ + self._growth_interval = new_interval + + def _get_growth_tracker(self): + if self._enabled: + return self._init_growth_tracker if self._growth_tracker is None else self._growth_tracker.item() + else: + return 0 + + def is_enabled(self): + r""" + Returns a bool indicating whether this instance is enabled. + """ + return self._enabled + + def state_dict(self): + r""" + Returns the state of the scaler as a :class:`dict`. It contains five entries: + + * ``"scale"`` - a Python float containing the current scale + * ``"growth_factor"`` - a Python float containing the current growth factor + * ``"backoff_factor"`` - a Python float containing the current backoff factor + * ``"growth_interval"`` - a Python int containing the current growth interval + * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps. + + If this instance is not enabled, returns an empty dict. + + .. note:: + If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict` + should be called after :meth:`update`. + """ + return {"scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker()} if self._enabled else {} + + def load_state_dict(self, state_dict): + r""" + Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op. + + Args: + state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError("The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler.") + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + + def __getstate__(self): + state = self.__dict__.copy() + if self._enabled: + assert len(self._per_optimizer_states) == 0, "A GradScaler instance may only be pickled at the beginning "\ + "of an iteration, or at the end after scaler.update()." + # Pickling _scale and _growth_tracker Tensors directly triggers + # "warnings.warn("pickle support for Storage will be removed in 1.5..." + # so instead, we set the unpickled instance up to reinitialize them lazily. + state['_init_scale'] = self.get_scale() + state['_init_growth_tracker'] = self._get_growth_tracker() + state['_scale'] = None + state['_growth_tracker'] = None + return state + + def __setstate__(self, state): + self.__dict__.update(state) + + def _check_inf_per_device(self, optimizer): + _scale, _ = self._check_scale_growth_tracker("_check_inf_per_device") + + dummy_inv_scale = 1.0 + found_inf = 0.0 + + self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] = \ + self._unscale_grads_(optimizer, dummy_inv_scale, found_inf, True) + + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] + + def _found_inf_per_device(self, optimizer): + return self._per_optimizer_states[id(optimizer)]["found_inf_per_device"] diff --git a/python/jittor/compatibility/misc.py b/python/jittor/compatibility/misc.py new file mode 100644 index 00000000..8e9ed20d --- /dev/null +++ b/python/jittor/compatibility/misc.py @@ -0,0 +1,12 @@ +import math + +def _jit_set_profiling_mode(x): pass +def _jit_set_profiling_executor(x): pass +def _jit_override_can_fuse_on_cpu(x): pass +def _jit_override_can_fuse_on_gpu(x): pass + +def script(func): + return func + +inf = math.inf +nan = math.nan \ No newline at end of file diff --git a/python/jittor/compatibility/nn/__init__.py b/python/jittor/compatibility/nn/__init__.py new file mode 100644 index 00000000..ae0ff3ae --- /dev/null +++ b/python/jittor/compatibility/nn/__init__.py @@ -0,0 +1,281 @@ +import jtorch +from typing import List, Optional, Tuple, Iterable, Iterator, Mapping, Any, overload, TypeVar, Dict +from typing_extensions import Self +import jittor as jt +from jtorch import make_module, Tensor, ModuleMisc, wrapper +#from . import init +from jittor import Function +import operator +import warnings + +for k,v in jt.nn.__dict__.items(): + if callable(v): + globals()[k] = wrapper(v) + +for k,v in jt.nn.__dict__.items(): + if isinstance(v, type) and issubclass(v, jt.Module): + globals()[k] = make_module(v) + +from collections import OrderedDict +from collections import abc as container_abcs + +class Module(ModuleMisc, jt.Module): + + def __call__(self, *args, **kw): + return self.execute(*args, **kw) + + def execute(self, *args, **kw): + return self.forward(*args, **kw) + + def get_submodule(self, target: str): + if target == "": + return self + + atoms: List[str] = target.split(".") + mod: jt.nn.Module = self + + for item in atoms: + if not hasattr(mod, item): + raise AttributeError(mod._get_name() + " has no " + "attribute `" + item + "`") + + mod = getattr(mod, item) + + if not isinstance(mod, jt.nn.Module): + raise AttributeError("`" + item + "` is not " + "an nn.Module") + return mod + + + +def Parameter(x:Tensor, requires_grad:bool=True) -> Tensor: + x = x.clone() + x.requires_grad = requires_grad + x.retains_grad = requires_grad + return x + +def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False): + return jt.nn.embedding(input, weight) + +def dropout(x, p=0.5, training=False): + return jt.nn.dropout(x, p, training) + + +class Flatten(Module): + ''' Flattens the contiguous range of dimensions in a Var. + :param start_dim: the first dimension to be flattened. Defaults: 1. + :type start_dim: int + :param end_dim: the last dimension to be flattened. Defaults: -1. + :type end_dim: int + ''' + def __init__(self, start_dim=1, end_dim=-1): + self.start_dim = start_dim + self.end_dim = end_dim + + def forward(self, x) -> jt.Var: + return x.flatten(self.start_dim, self.end_dim) + +class _IncompatibleKeys: + def __init__(self, missing_keys, unexpected_keys): + self.missing_keys = missing_keys + self.unexpected_keys = unexpected_keys + +_BatchNorm = None + +#from . import utils +normalize = wrapper(jt.normalize) + +T = TypeVar('T', bound=Module) + +class ModuleDict(Module): + _modules: Dict[str, Module] # type: ignore[assignment] + + def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None: + super().__init__() + if modules is not None: + self.update(modules) + + def __getitem__(self, key: str) -> Module: + return self._modules[key] + + def __setitem__(self, key: str, module: Module) -> None: + self.add_module(key, module) + + def __delitem__(self, key: str) -> None: + del self._modules[key] + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[str]: + return iter(self._modules) + + def __contains__(self, key: str) -> bool: + return key in self._modules + + def clear(self) -> None: + """Remove all items from the ModuleDict.""" + self._modules.clear() + + def pop(self, key: str) -> Module: + r"""Remove key from the ModuleDict and return its module. + + Args: + key (str): key to pop from the ModuleDict + """ + v = self[key] + del self[key] + return v + + def keys(self) -> Iterable[str]: + r"""Return an iterable of the ModuleDict keys.""" + return self._modules.keys() + + def items(self) -> Iterable[Tuple[str, Module]]: + r"""Return an iterable of the ModuleDict key/value pairs.""" + return self._modules.items() + + def values(self) -> Iterable[Module]: + r"""Return an iterable of the ModuleDict values.""" + return self._modules.values() + + def update(self, modules: Mapping[str, Module]) -> None: + r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys. + + .. note:: + If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or + an iterable of key-value pairs, the order of new elements in it is preserved. + + Args: + modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`, + or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`) + """ + if not isinstance(modules, container_abcs.Iterable): + raise TypeError("ModuleDict.update should be called with an " + "iterable of key/value pairs, but got " + + type(modules).__name__) + + if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)): + for key, module in modules.items(): + self[key] = module + else: + # modules here can be a list with two items + for j, m in enumerate(modules): + if not isinstance(m, container_abcs.Iterable): + raise TypeError("ModuleDict update sequence element " + "#" + str(j) + " should be Iterable; is" + + type(m).__name__) + if not len(m) == 2: + raise ValueError("ModuleDict update sequence element " + "#" + str(j) + " has length " + str(len(m)) + + "; 2 is required") + # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)] + # that's too cumbersome to type correctly with overloads, so we add an ignore here + self[m[0]] = m[1] # type: ignore[assignment] + + # remove forward alltogether to fallback on Module's _forward_unimplemented + + +class ParameterList(Module): + + def __init__(self, values: Optional[Iterable[Any]] = None) -> None: + super().__init__() + self._size = 0 + if values is not None: + self += values + + def _get_abs_string_index(self, idx): + """Get the absolute index for the list of modules.""" + idx = operator.index(idx) + if not (-len(self) <= idx < len(self)): + raise IndexError(f'index {idx} is out of range') + if idx < 0: + idx += len(self) + return str(idx) + + @overload + def __getitem__(self, idx: int) -> Any: + ... + + @overload + def __getitem__(self: T, idx: slice) -> T: + ... + + def __getitem__(self, idx): + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + out = self.__class__() + for i in range(start, stop, step): + out.append(self[i]) + return out + else: + idx = self._get_abs_string_index(idx) + return getattr(self, str(idx)) + + def __setitem__(self, idx: int, param: Any) -> None: + # Note that all other function that add an entry to the list part of + # the ParameterList end up here. So this is the only place where we need + # to wrap things into Parameter if needed. + # Objects added via setattr() are not in the list part and thus won't + # call into this function. + idx = self._get_abs_string_index(idx) + if isinstance(param, jt.Var) and not isinstance(param, Parameter): + param = Parameter(param) + return setattr(self, str(idx), param) + + def __len__(self) -> int: + return self._size + + def __iter__(self) -> Iterator[Any]: + return iter(self[i] for i in range(len(self))) + + def __iadd__(self, parameters: Iterable[Any]) -> Self: + return self.extend(parameters) + + def __dir__(self): + keys = super().__dir__() + keys = [key for key in keys if not key.isdigit()] + return keys + + def append(self, value: Any) -> 'ParameterList': + """Append a given value at the end of the list. + + Args: + value (Any): value to append + """ + new_idx = len(self) + self._size += 1 + self[new_idx] = value + return self + + def extend(self, values: Iterable[Any]) -> Self: + """Append values from a Python iterable to the end of the list. + + Args: + values (iterable): iterable of values to append + """ + # Tensor is an iterable but we never want to unpack it here + if not isinstance(values, container_abcs.Iterable) or isinstance(values, jt.Var): + raise TypeError("ParameterList.extend should be called with an " + "iterable, but got " + type(values).__name__) + for value in values: + self.append(value) + return self + + def extra_repr(self) -> str: + child_lines = [] + for k, p in enumerate(self): + if isinstance(p, jt.Var): + size_str = 'x'.join(str(size) for size in p.size()) + parastr = '{} containing: [{} of size {}{}]'.format( + "Parameter" if isinstance(p, Parameter) else "Tensor", + p.dtype, size_str, "cuda" if jt.flags.use_cuda else "cpu") + child_lines.append(' (' + str(k) + '): ' + parastr) + else: + child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__) + + tmpstr = '\n'.join(child_lines) + return tmpstr + + def __call__(self, *args, **kwargs): + raise RuntimeError('ParameterList should not be called.') \ No newline at end of file diff --git a/python/jittor/compatibility/nn/init.py b/python/jittor/compatibility/nn/init.py new file mode 100644 index 00000000..3b9f0907 --- /dev/null +++ b/python/jittor/compatibility/nn/init.py @@ -0,0 +1,16 @@ +import jittor as jt + +for k,v in jt.nn.init.__dict__.items(): + if callable(v): + globals()[k] = v + + +normal = gauss +normal_ = gauss_ +xavier_normal = xavier_gauss +xavier_normal_ = xavier_gauss_ +zeros_ = zero_ + + +jt.Var.normal_ = normal_ + diff --git a/python/jittor/compatibility/nn/utils/__init__.py b/python/jittor/compatibility/nn/utils/__init__.py new file mode 100644 index 00000000..83409f5f --- /dev/null +++ b/python/jittor/compatibility/nn/utils/__init__.py @@ -0,0 +1 @@ +from . import rnn \ No newline at end of file diff --git a/python/jittor/compatibility/nn/utils/rnn.py b/python/jittor/compatibility/nn/utils/rnn.py new file mode 100644 index 00000000..b32da8c3 --- /dev/null +++ b/python/jittor/compatibility/nn/utils/rnn.py @@ -0,0 +1,20 @@ +import jittor as jt + +PackedSequence = None + +def pad_sequence(sequences,batch_first=False,padding_value=0.0): + max_f = max([len(s) for s in sequences]) + # max_f = 512 + b = len(sequences) + if batch_first: + ret = sequences[0].new_full([b,max_f,]+list(sequences[0].shape[1:]),padding_value) + for i,s in enumerate(sequences): + ret[i,:len(s)] = s + else: + ret = sequences[0].new_full([max_f,b,]+list(sequences[0].shape[1:]),padding_value) + for i,s in enumerate(sequences): + ret[:len(s),i] = s + # print(ret.shape) + # ret = ret[:,:406] + return ret + \ No newline at end of file diff --git a/python/jittor/compatibility/optim.py b/python/jittor/compatibility/optim.py new file mode 100644 index 00000000..2410917f --- /dev/null +++ b/python/jittor/compatibility/optim.py @@ -0,0 +1,1854 @@ +import jittor as jt +import math +from jittor.optim import * +from functools import partial + +class Optimizer(jt.optim.Optimizer): + def pre_step(self, loss=None, retain_graph=False): + jt.flags.node_order = 1 + params_has_grad = [] + for pg in self.param_groups: + pg["grads"] = [ jt.zeros_like(p) if p.grad is None else p.grad#.float32() + for p in pg["params"] ] + for p in pg["params"]: + if p.requires_grad: + params_has_grad.append(p) + jt.sync(params_has_grad) + self.n_step += 1 + + def zero_grad(self): + for pg in self.param_groups: + pg["grads"] = [ None for p in pg["params"] ] + for p in pg["params"]: p.grad = None + + def post_step(self): + jt.flags.node_order = 0 + + def clip_grad_norm(self, max_norm:float, norm_type:int=2): + r"""Clips gradient norm of this optimizer. + The norm is computed over all gradients together. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (int): 1-norm or 2-norm + + Example:: + + a = jt.ones(2) + opt = jt.optim.SGD([a], 0.1) + + loss = a*a + opt.zero_grad() + opt.backward(loss) + + print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83 + opt.clip_grad_norm(0.01, 2) + print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01 + + opt.step() + + """ + self.pre_step(None) + grads = [] + for pg in self.param_groups: + for p, g in zip(pg["params"], pg["grads"]): + if p.is_stop_grad(): continue + grads.append(g.flatten()) + if len(grads) == 0: return + total_norm = jt.norm(jt.concat(grads), norm_type) + clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0) + for pg in self.param_groups: + for p, g in zip(pg["params"], pg["grads"]): + if p.is_stop_grad(): continue + g.update(g*clip_coef) + + +class AdamW(Optimizer): + def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0,use_fp32=True): + print("lr:", lr) + super().__init__(params, lr) + self.eps = eps + self.betas = betas + self.weight_decay = weight_decay + + self.use_fp32 = use_fp32 + # assert weight_decay==0, "weight_decay is not supported yet" + + # initialize required arguments for each param_groups + for pg in self.param_groups: + values = pg["values"] = [] + m = pg["m"] = [] + mp = pg['masterparams'] = [] + for p in pg["params"]: + values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + if self.use_fp32: + mp.append(p.detach().clone().stop_grad()) + + def add_param_group(self, group): + values = group["values"] = [] + m = group["m"] = [] + mp = group['masterparams'] = [] + for p in group["params"]: + values.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, "float32" if self.use_fp32 else p.dtype).stop_grad()) + if self.use_fp32: + mp.append(p.detach().clone().stop_grad()) + self.param_groups.append(group) + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph) + if loss is None: + self.n_step += 1 + n = float(self.n_step) + for pg in self.param_groups: + # get arguments from each param_groups + lr = pg.get("lr", self.lr) + eps = pg.get("eps", self.eps) + weight_decay = pg.get("weight_decay", self.weight_decay) + b0, b1 = pg.get("betas", self.betas) + + for p, g, v, m,mp in zip(pg["params"], pg["grads"], pg["values"], pg["m"],pg['masterparams']): + if p.is_stop_grad(): continue + #if g.abs().sum().item() < 1e-8: continue + #import pdb; pdb.set_trace() + c_p = (mp * (1 - lr * weight_decay)) + mp.update(c_p) + if self.use_fp32: + g = g.float32() + bias_correction1 = 1 - b0 ** n + bias_correction2 = 1 - b1 ** n + m.update(b0 * m + (1-b0) * g) #exp_avg + v.update(b1 * v + (1-b1) * g * g) #exp_avg_sq + denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps + step_size = lr / bias_correction1 + new_p = (mp - step_size * m / denom) + mp.update(new_p) + p.update(mp.cast(p.dtype)) + self.post_step() + +for k,v in jt.optim.__dict__.items(): + if k == "AdamW":continue + if isinstance(v, type) and issubclass(v, jt.optim.Optimizer) and \ + not v is jt.optim.Optimizer: + class OptimWrap(v, Optimizer): + pass + globals()[k] = OptimWrap + + +class Adagrad(Optimizer): + pass + + + +import types +import math +from functools import wraps +import warnings +import weakref +from collections import Counter +from bisect import bisect_right + + +class LRScheduler: + + def __init__(self, optimizer, last_epoch=-1, verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Initialize epoch and base learning rates + if last_epoch == -1: + for group in optimizer.param_groups: + group.setdefault('initial_lr', group.get("lr",optimizer.lr)) + else: + for i, group in enumerate(optimizer.param_groups): + if 'initial_lr' not in group: + raise KeyError("param 'initial_lr' is not specified " + "in param_groups[{}] when resuming an optimizer".format(i)) + self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups] + self.last_epoch = last_epoch + + # Following https://github.com/pytorch/pytorch/issues/20124 + # We would like to ensure that `lr_scheduler.step()` is called after + # `optimizer.step()` + def with_counter(method): + if getattr(method, '_with_counter', False): + # `optimizer.step()` has already been replaced, return. + return method + + # Keep a weak reference to the optimizer instance to prevent + # cyclic references. + instance_ref = weakref.ref(method.__self__) + # Get the unbound method for the same purpose. + func = method.__func__ + cls = instance_ref().__class__ + del method + + @wraps(func) + def wrapper(*args, **kwargs): + instance = instance_ref() + instance._step_count += 1 + wrapped = func.__get__(instance, cls) + return wrapped(*args, **kwargs) + + # Note that the returned function here is no longer a bound method, + # so attributes like `__func__` and `__self__` no longer exist. + wrapper._with_counter = True + return wrapper + + self.optimizer.step = with_counter(self.optimizer.step) + self.verbose = verbose + + self._initial_step() + + def _initial_step(self): + """Initialize step counts and performs a step""" + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + """ + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + self.__dict__.update(state_dict) + + def get_last_lr(self): + """ Return last computed learning rate by current scheduler. + """ + return self._last_lr + + def get_lr(self): + # Compute learning rate using chainable form of the scheduler + raise NotImplementedError + + def print_lr(self, is_verbose, group, lr, epoch=None): + """Display the current learning rate. + """ + if is_verbose: + if epoch is None: + print('Adjusting learning rate' + ' of group {} to {:.4e}.'.format(group, lr)) + else: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: adjusting learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, group, lr)) + + + def step(self, epoch=None): + # Raise a warning if old pattern is detected + # https://github.com/pytorch/pytorch/issues/20124 + if self._step_count == 1: + if not hasattr(self.optimizer.step, "_with_counter"): + warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler " + "initialization. Please, make sure to call `optimizer.step()` before " + "`lr_scheduler.step()`. See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + + # Just check if there were two first lr_scheduler.step() calls before optimizer.step() + elif self.optimizer._step_count < 1: + warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. " + "In PyTorch 1.1.0 and later, you should call them in the opposite order: " + "`optimizer.step()` before `lr_scheduler.step()`. Failure to do this " + "will result in PyTorch skipping the first value of the learning rate schedule. " + "See more details at " + "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning) + self._step_count += 1 + + with _enable_get_lr_call(self): + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + if hasattr(self, "_get_closed_form_lr"): + values = self._get_closed_form_lr() + else: + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + +# Including _LRScheduler for backwards compatibility +# Subclass instead of assign because we want __name__ of _LRScheduler to be _LRScheduler (assigning would make it LRScheduler). +class _LRScheduler(LRScheduler): + pass + + +class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + + +class LambdaLR(LRScheduler): + """Sets the learning rate of each parameter group to the initial lr + times a given function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer has two groups. + >>> lambda1 = lambda epoch: epoch // 30 + >>> lambda2 = lambda epoch: 0.95 ** epoch + >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): + self.optimizer = optimizer + + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError("Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(lr_lambda))) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + """ + + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} + state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict['lr_lambdas'][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + When saving or loading the scheduler, please make sure to also save or load the state of the optimizer. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + + lr_lambdas = state_dict.pop('lr_lambdas') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['lr_lambdas'] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.") + return [base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] + + +class MultiplicativeLR(LRScheduler): + """Multiply the learning rate of each parameter group by the factor given + in the specified function. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + lr_lambda (function or list): A function which computes a multiplicative + factor given an integer parameter epoch, or a list of such + functions, one for each group in optimizer.param_groups. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> lmbda = lambda epoch: 0.95 + >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False): + self.optimizer = optimizer + + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError("Expected {} lr_lambdas, but got {}".format( + len(optimizer.param_groups), len(lr_lambda))) + self.lr_lambdas = list(lr_lambda) + super().__init__(optimizer, last_epoch, verbose) + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The learning rate lambda functions will only be saved if they are callable objects + and not if they are functions or lambdas. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', 'lr_lambdas')} + state_dict['lr_lambdas'] = [None] * len(self.lr_lambdas) + + for idx, fn in enumerate(self.lr_lambdas): + if not isinstance(fn, types.FunctionType): + state_dict['lr_lambdas'][idx] = fn.__dict__.copy() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + lr_lambdas = state_dict.pop('lr_lambdas') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['lr_lambdas'] = lr_lambdas + + for idx, fn in enumerate(lr_lambdas): + if fn is not None: + self.lr_lambdas[idx].__dict__.update(fn) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch > 0: + return [group['lr'] * lmbda(self.last_epoch) + for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)] + else: + return [group['lr'] for group in self.optimizer.param_groups] + + +class StepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every + step_size epochs. Notice that such decay can happen simultaneously with + other changes to the learning rate from outside this scheduler. When + last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + step_size (int): Period of learning rate decay. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 60 + >>> # lr = 0.0005 if 60 <= epoch < 90 + >>> # ... + >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False): + self.step_size = step_size + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if (self.last_epoch == 0) or (self.last_epoch % self.step_size != 0): + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma ** (self.last_epoch // self.step_size) + for base_lr in self.base_lrs] + + +class MultiStepLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma once the + number of epoch reaches one of the milestones. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside + this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + milestones (list): List of epoch indices. Must be increasing. + gamma (float): Multiplicative factor of learning rate decay. + Default: 0.1. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.05 if epoch < 30 + >>> # lr = 0.005 if 30 <= epoch < 80 + >>> # lr = 0.0005 if epoch >= 80 + >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False): + self.milestones = Counter(milestones) + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + milestones = sorted(self.milestones.elements()) + return [base_lr * self.gamma ** bisect_right(milestones, self.last_epoch) + for base_lr in self.base_lrs] + + +class ConstantLR(LRScheduler): + """Decays the learning rate of each parameter group by a small constant factor until the + number of epoch reaches a pre-defined milestone: total_iters. Notice that such decay can + happen simultaneously with other changes to the learning rate from outside this scheduler. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + factor (float): The number we multiply learning rate until the milestone. Default: 1./3. + total_iters (int): The number of steps that the scheduler decays the learning rate. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.025 if epoch == 1 + >>> # lr = 0.025 if epoch == 2 + >>> # lr = 0.025 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False): + if factor > 1.0 or factor < 0: + raise ValueError('Constant multiplicative factor expected to be between 0 and 1.') + + self.factor = factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.factor for group in self.optimizer.param_groups] + + if (self.last_epoch > self.total_iters or + (self.last_epoch != self.total_iters)): + return [group['lr'] for group in self.optimizer.param_groups] + + if (self.last_epoch == self.total_iters): + return [group['lr'] * (1.0 / self.factor) for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.factor + (self.last_epoch >= self.total_iters) * (1 - self.factor)) + for base_lr in self.base_lrs] + + +class LinearLR(LRScheduler): + """Decays the learning rate of each parameter group by linearly changing small + multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters. + Notice that such decay can happen simultaneously with other changes to the learning rate + from outside this scheduler. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + start_factor (float): The number we multiply learning rate in the first epoch. + The multiplication factor changes towards end_factor in the following epochs. + Default: 1./3. + end_factor (float): The number we multiply learning rate at the end of linear changing + process. Default: 1.0. + total_iters (int): The number of iterations that multiplicative factor reaches to 1. + Default: 5. + last_epoch (int): The index of the last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 0.05 for all groups + >>> # lr = 0.025 if epoch == 0 + >>> # lr = 0.03125 if epoch == 1 + >>> # lr = 0.0375 if epoch == 2 + >>> # lr = 0.04375 if epoch == 3 + >>> # lr = 0.05 if epoch >= 4 + >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1, + verbose=False): + if start_factor > 1.0 or start_factor <= 0: + raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.') + + if end_factor > 1.0 or end_factor < 0: + raise ValueError('Ending multiplicative factor expected to be between 0 and 1.') + + self.start_factor = start_factor + self.end_factor = end_factor + self.total_iters = total_iters + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] * self.start_factor for group in self.optimizer.param_groups] + + if self.last_epoch > self.total_iters: + return [group['lr'] for group in self.optimizer.param_groups] + + return [group['lr'] * (1. + (self.end_factor - self.start_factor) / + (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor))) + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * (self.start_factor + + (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters) + for base_lr in self.base_lrs] + + +class ExponentialLR(LRScheduler): + """Decays the learning rate of each parameter group by gamma every epoch. + When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + gamma (float): Multiplicative factor of learning rate decay. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + """ + + def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False): + self.gamma = gamma + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] for group in self.optimizer.param_groups] + return [group['lr'] * self.gamma + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [base_lr * self.gamma ** self.last_epoch + for base_lr in self.base_lrs] + + +class SequentialLR(LRScheduler): + """Receives the list of schedulers that is expected to be called sequentially during + optimization process and milestone points that provides exact intervals to reflect + which scheduler is supposed to be called at a given epoch. + + Args: + optimizer (Optimizer): Wrapped optimizer. + schedulers (list): List of chained schedulers. + milestones (list): List of integers that reflects milestone points. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): Does nothing. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.1 if epoch == 0 + >>> # lr = 0.1 if epoch == 1 + >>> # lr = 0.9 if epoch == 2 + >>> # lr = 0.81 if epoch == 3 + >>> # lr = 0.729 if epoch == 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False): + for scheduler_idx in range(len(schedulers)): + if schedulers[scheduler_idx].optimizer != optimizer: + raise ValueError( + "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " + f"got schedulers at index {scheduler_idx} to be different than the optimizer passed in." + ) + + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "Sequential Schedulers expects all schedulers to belong to the same optimizer, but " + f"got schedulers at index {0} and {scheduler_idx} to be different." + ) + if (len(milestones) != len(schedulers) - 1): + raise ValueError( + "Sequential Schedulers expects number of schedulers provided to be one more " + "than the number of milestone points, but got number of schedulers {} and the " + "number of milestones to be equal to {}".format(len(schedulers), len(milestones)) + ) + self._schedulers = schedulers + self._milestones = milestones + self.last_epoch = last_epoch + 1 + self.optimizer = optimizer + + # Reset learning rates back to initial values + for group in self.optimizer.param_groups: + group["lr"] = group["initial_lr"] + + # "Undo" the step performed by other schedulers + for scheduler in self._schedulers: + scheduler.last_epoch -= 1 + + # Perform the initial step for only the first scheduler + self._schedulers[0]._initial_step() + + self._last_lr = schedulers[0].get_last_lr() + + def step(self): + self.last_epoch += 1 + idx = bisect_right(self._milestones, self.last_epoch) + scheduler = self._schedulers[idx] + if idx > 0 and self._milestones[idx - 1] == self.last_epoch: + scheduler.step(0) + else: + scheduler.step() + + self._last_lr = scheduler.get_last_lr() + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class PolynomialLR(LRScheduler): + """Decays the learning rate of each parameter group using a polynomial function + in the given total_iters. When last_epoch=-1, sets initial lr as lr. + + Args: + optimizer (Optimizer): Wrapped optimizer. + total_iters (int): The number of steps that the scheduler decays the learning rate. Default: 5. + power (int): The power of the polynomial. Default: 1.0. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> # Assuming optimizer uses lr = 0.001 for all groups + >>> # lr = 0.001 if epoch == 0 + >>> # lr = 0.00075 if epoch == 1 + >>> # lr = 0.00050 if epoch == 2 + >>> # lr = 0.00025 if epoch == 3 + >>> # lr = 0.0 if epoch >= 4 + >>> scheduler = PolynomialLR(self.opt, total_iters=4, power=1.0) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + def __init__(self, optimizer, total_iters=5, power=1.0, last_epoch=-1, verbose=False): + self.total_iters = total_iters + self.power = power + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0 or self.last_epoch > self.total_iters: + return [group["lr"] for group in self.optimizer.param_groups] + + decay_factor = ((1.0 - self.last_epoch / self.total_iters) / (1.0 - (self.last_epoch - 1) / self.total_iters)) ** self.power + return [group["lr"] * decay_factor for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [ + ( + base_lr * (1.0 - min(self.total_iters, self.last_epoch) / self.total_iters) ** self.power + ) + for base_lr in self.base_lrs + ] + + +class CosineAnnealingLR(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr and + :math:`T_{cur}` is the number of epochs since the last restart in SGDR: + + .. math:: + \begin{aligned} + \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right), + & T_{cur} \neq (2k+1)T_{max}; \\ + \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min}) + \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right), + & T_{cur} = (2k+1)T_{max}. + \end{aligned} + + When last_epoch=-1, sets initial lr as lr. Notice that because the schedule + is defined recursively, the learning rate can be simultaneously modified + outside this scheduler by other operators. If the learning rate is set + solely by this scheduler, the learning rate at each step becomes: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right) + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only + implements the cosine annealing part of SGDR, and not the restarts. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False): + self.T_max = T_max + self.eta_min = eta_min + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == 0: + return [group['lr'] for group in self.optimizer.param_groups] + elif self._step_count == 1 and self.last_epoch > 0: + return [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos((self.last_epoch) * math.pi / self.T_max)) / 2 + for base_lr, group in + zip(self.base_lrs, self.optimizer.param_groups)] + elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return [group['lr'] + (base_lr - self.eta_min) * + (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in + zip(self.base_lrs, self.optimizer.param_groups)] + return [(1 + math.cos(math.pi * self.last_epoch / self.T_max)) / + (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + def _get_closed_form_lr(self): + return [self.eta_min + (base_lr - self.eta_min) * + (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2 + for base_lr in self.base_lrs] + + +class ChainedScheduler(LRScheduler): + """Chains list of learning rate schedulers. It takes a list of chainable learning + rate schedulers and performs consecutive step() functions belonging to them by just + one call. + + Args: + schedulers (list): List of chained schedulers. + + Example: + >>> # xdoctest: +SKIP + >>> # Assuming optimizer uses lr = 1. for all groups + >>> # lr = 0.09 if epoch == 0 + >>> # lr = 0.081 if epoch == 1 + >>> # lr = 0.729 if epoch == 2 + >>> # lr = 0.6561 if epoch == 3 + >>> # lr = 0.59049 if epoch >= 4 + >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2) + >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9) + >>> scheduler = ChainedScheduler([scheduler1, scheduler2]) + >>> for epoch in range(100): + >>> train(...) + >>> validate(...) + >>> scheduler.step() + """ + + def __init__(self, schedulers): + for scheduler_idx in range(1, len(schedulers)): + if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer): + raise ValueError( + "ChainedScheduler expects all schedulers to belong to the same optimizer, but " + "got schedulers at index {} and {} to be different".format(0, scheduler_idx) + ) + self._schedulers = list(schedulers) + self.optimizer = schedulers[0].optimizer + self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] + + def step(self): + for scheduler in self._schedulers: + scheduler.step() + self._last_lr = [group['lr'] for group in self._schedulers[-1].optimizer.param_groups] + + def state_dict(self): + """Returns the state of the scheduler as a :class:`dict`. + + It contains an entry for every variable in self.__dict__ which + is not the optimizer. + The wrapped scheduler states will also be saved. + """ + state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')} + state_dict['_schedulers'] = [None] * len(self._schedulers) + + for idx, s in enumerate(self._schedulers): + state_dict['_schedulers'][idx] = s.state_dict() + + return state_dict + + def load_state_dict(self, state_dict): + """Loads the schedulers state. + + Args: + state_dict (dict): scheduler state. Should be an object returned + from a call to :meth:`state_dict`. + """ + _schedulers = state_dict.pop('_schedulers') + self.__dict__.update(state_dict) + # Restore state_dict keys in order to prevent side effects + # https://github.com/pytorch/pytorch/issues/32756 + state_dict['_schedulers'] = _schedulers + + for idx, s in enumerate(_schedulers): + self._schedulers[idx].load_state_dict(s) + + +class ReduceLROnPlateau: + """Reduce learning rate when a metric has stopped improving. + Models often benefit from reducing the learning rate by a factor + of 2-10 once learning stagnates. This scheduler reads a metrics + quantity and if no improvement is seen for a 'patience' number + of epochs, the learning rate is reduced. + + Args: + optimizer (Optimizer): Wrapped optimizer. + mode (str): One of `min`, `max`. In `min` mode, lr will + be reduced when the quantity monitored has stopped + decreasing; in `max` mode it will be reduced when the + quantity monitored has stopped increasing. Default: 'min'. + factor (float): Factor by which the learning rate will be + reduced. new_lr = lr * factor. Default: 0.1. + patience (int): Number of epochs with no improvement after + which learning rate will be reduced. For example, if + `patience = 2`, then we will ignore the first 2 epochs + with no improvement, and will only decrease the LR after the + 3rd epoch if the loss still hasn't improved then. + Default: 10. + threshold (float): Threshold for measuring the new optimum, + to only focus on significant changes. Default: 1e-4. + threshold_mode (str): One of `rel`, `abs`. In `rel` mode, + dynamic_threshold = best * ( 1 + threshold ) in 'max' + mode or best * ( 1 - threshold ) in `min` mode. + In `abs` mode, dynamic_threshold = best + threshold in + `max` mode or best - threshold in `min` mode. Default: 'rel'. + cooldown (int): Number of epochs to wait before resuming + normal operation after lr has been reduced. Default: 0. + min_lr (float or list): A scalar or a list of scalars. A + lower bound on the learning rate of all param groups + or each group respectively. Default: 0. + eps (float): Minimal decay applied to lr. If the difference + between new and old lr is smaller than eps, the update is + ignored. Default: 1e-8. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = ReduceLROnPlateau(optimizer, 'min') + >>> for epoch in range(10): + >>> train(...) + >>> val_loss = validate(...) + >>> # Note that step should be called after validate() + >>> scheduler.step(val_loss) + """ + + def __init__(self, optimizer, mode='min', factor=0.1, patience=10, + threshold=1e-4, threshold_mode='rel', cooldown=0, + min_lr=0, eps=1e-8, verbose=False): + + if factor >= 1.0: + raise ValueError('Factor should be < 1.0.') + self.factor = factor + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + if isinstance(min_lr, (list, tuple)): + if len(min_lr) != len(optimizer.param_groups): + raise ValueError("expected {} min_lrs, got {}".format( + len(optimizer.param_groups), len(min_lr))) + self.min_lrs = list(min_lr) + else: + self.min_lrs = [min_lr] * len(optimizer.param_groups) + + self.patience = patience + self.verbose = verbose + self.cooldown = cooldown + self.cooldown_counter = 0 + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + self.best = None + self.num_bad_epochs = None + self.mode_worse = None # the worse value for the chosen mode + self.eps = eps + self.last_epoch = 0 + self._init_is_better(mode=mode, threshold=threshold, + threshold_mode=threshold_mode) + self._reset() + + def _reset(self): + """Resets num_bad_epochs counter and cooldown counter.""" + self.best = self.mode_worse + self.cooldown_counter = 0 + self.num_bad_epochs = 0 + + def step(self, metrics, epoch=None): + # convert `metrics` to float, in case it's a zero-dim Tensor + current = float(metrics) + if epoch is None: + epoch = self.last_epoch + 1 + else: + warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning) + self.last_epoch = epoch + + if self.is_better(current, self.best): + self.best = current + self.num_bad_epochs = 0 + else: + self.num_bad_epochs += 1 + + if self.in_cooldown: + self.cooldown_counter -= 1 + self.num_bad_epochs = 0 # ignore any bad epochs in cooldown + + if self.num_bad_epochs > self.patience: + self._reduce_lr(epoch) + self.cooldown_counter = self.cooldown + self.num_bad_epochs = 0 + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + def _reduce_lr(self, epoch): + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group['lr']) + new_lr = max(old_lr * self.factor, self.min_lrs[i]) + if old_lr - new_lr > self.eps: + param_group['lr'] = new_lr + if self.verbose: + epoch_str = ("%.2f" if isinstance(epoch, float) else + "%.5d") % epoch + print('Epoch {}: reducing learning rate' + ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr)) + + @property + def in_cooldown(self): + return self.cooldown_counter > 0 + + def is_better(self, a, best): + if self.mode == 'min' and self.threshold_mode == 'rel': + rel_epsilon = 1. - self.threshold + return a < best * rel_epsilon + + elif self.mode == 'min' and self.threshold_mode == 'abs': + return a < best - self.threshold + + elif self.mode == 'max' and self.threshold_mode == 'rel': + rel_epsilon = self.threshold + 1. + return a > best * rel_epsilon + + else: # mode == 'max' and epsilon_mode == 'abs': + return a > best + self.threshold + + def _init_is_better(self, mode, threshold, threshold_mode): + if mode not in {'min', 'max'}: + raise ValueError('mode ' + mode + ' is unknown!') + if threshold_mode not in {'rel', 'abs'}: + raise ValueError('threshold mode ' + threshold_mode + ' is unknown!') + + if mode == 'min': + self.mode_worse = inf + else: # mode == 'max': + self.mode_worse = -inf + + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + + def state_dict(self): + return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} + + def load_state_dict(self, state_dict): + self.__dict__.update(state_dict) + self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode) + + +class CyclicLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to + cyclical learning rate policy (CLR). The policy cycles the learning + rate between two boundaries with a constant frequency, as detailed in + the paper `Cyclical Learning Rates for Training Neural Networks`_. + The distance between the two boundaries can be scaled on a per-iteration + or per-cycle basis. + + Cyclical learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This class has three built-in policies, as put forth in the paper: + + * "triangular": A basic triangular cycle without amplitude scaling. + * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle. + * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}` + at each cycle iteration. + + This implementation was adapted from the github repo: `bckenstler/CLR`_ + + Args: + optimizer (Optimizer): Wrapped optimizer. + base_lr (float or list): Initial learning rate which is the + lower boundary in the cycle for each parameter group. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_lr - base_lr). + The lr at any cycle is the sum of base_lr + and some scaling of the amplitude; therefore + max_lr may not actually be reached depending on + scaling function. + step_size_up (int): Number of training iterations in the + increasing half of a cycle. Default: 2000 + step_size_down (int): Number of training iterations in the + decreasing half of a cycle. If step_size_down is None, + it is set to step_size_up. Default: None + mode (str): One of {triangular, triangular2, exp_range}. + Values correspond to policies detailed above. + If scale_fn is not None, this argument is ignored. + Default: 'triangular' + gamma (float): Constant in 'exp_range' scaling function: + gamma**(cycle iterations) + Default: 1.0 + scale_fn (function): Custom scaling policy defined by a single + argument lambda function, where + 0 <= scale_fn(x) <= 1 for all x >= 0. + If specified, then 'mode' is ignored. + Default: None + scale_mode (str): {'cycle', 'iterations'}. + Defines whether scale_fn is evaluated on + cycle number or cycle iterations (training + iterations since start of cycle). + Default: 'cycle' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.8 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + The momentum at any cycle is the difference of max_momentum + and some scaling of the amplitude; therefore + base_momentum may not actually be reached depending on + scaling function. Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.9 + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1) + >>> data_loader = torch.utils.data.DataLoader(...) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> scheduler.step() + + + .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186 + .. _bckenstler/CLR: https://github.com/bckenstler/CLR + """ + + def __init__(self, + optimizer, + base_lr, + max_lr, + step_size_up=2000, + step_size_down=None, + mode='triangular', + gamma=1., + scale_fn=None, + scale_mode='cycle', + cycle_momentum=True, + base_momentum=0.8, + max_momentum=0.9, + last_epoch=-1, + verbose=False): + + # Attach optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + base_lrs = self._format_param('base_lr', optimizer, base_lr) + if last_epoch == -1: + for lr, group in zip(base_lrs, optimizer.param_groups): + group['lr'] = lr + + self.max_lrs = self._format_param('max_lr', optimizer, max_lr) + + step_size_up = float(step_size_up) + step_size_down = float(step_size_down) if step_size_down is not None else step_size_up + self.total_size = step_size_up + step_size_down + self.step_ratio = step_size_up / self.total_size + + if mode not in ['triangular', 'triangular2', 'exp_range'] \ + and scale_fn is None: + raise ValueError('mode is invalid and scale_fn is None') + + self.mode = mode + self.gamma = gamma + + self._scale_fn_ref = None + self._scale_fn_custom = scale_fn + self.scale_mode = scale_mode + self._init_scale_fn() + + self.cycle_momentum = cycle_momentum + if cycle_momentum: + if 'momentum' not in optimizer.defaults: + raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') + + base_momentums = self._format_param('base_momentum', optimizer, base_momentum) + if last_epoch == -1: + for momentum, group in zip(base_momentums, optimizer.param_groups): + group['momentum'] = momentum + self.base_momentums = [group['momentum'] for group in optimizer.param_groups] + self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum) + + super().__init__(optimizer, last_epoch, verbose) + self.base_lrs = base_lrs + + def _init_scale_fn(self): + if self._scale_fn_custom is not None: + return + if self.mode == 'triangular': + self._scale_fn_ref = weakref.WeakMethod(self._triangular_scale_fn) + self.scale_mode = 'cycle' + elif self.mode == 'triangular2': + self._scale_fn_ref = weakref.WeakMethod(self._triangular2_scale_fn) + self.scale_mode = 'cycle' + elif self.mode == 'exp_range': + self._scale_fn_ref = weakref.WeakMethod(self._exp_range_scale_fn) + self.scale_mode = 'iterations' + + def _format_param(self, name, optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError("expected {} values for {}, got {}".format( + len(optimizer.param_groups), name, len(param))) + return param + else: + return [param] * len(optimizer.param_groups) + + def scale_fn(self, x): + if self._scale_fn_custom is not None: + return self._scale_fn_custom(x) + + else: + return self._scale_fn_ref()(x) + + def _triangular_scale_fn(self, x): + return 1. + + def _triangular2_scale_fn(self, x): + return 1 / (2. ** (x - 1)) + + def _exp_range_scale_fn(self, x): + return self.gamma**(x) + + def get_lr(self): + """Calculates the learning rate at batch index. This function treats + `self.last_epoch` as the last batch index. + + If `self.cycle_momentum` is ``True``, this function has a side effect of + updating the optimizer's momentum. + """ + + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + cycle = math.floor(1 + self.last_epoch / self.total_size) + x = 1. + self.last_epoch / self.total_size - cycle + if x <= self.step_ratio: + scale_factor = x / self.step_ratio + else: + scale_factor = (x - 1) / (self.step_ratio - 1) + + lrs = [] + for base_lr, max_lr in zip(self.base_lrs, self.max_lrs): + base_height = (max_lr - base_lr) * scale_factor + if self.scale_mode == 'cycle': + lr = base_lr + base_height * self.scale_fn(cycle) + else: + lr = base_lr + base_height * self.scale_fn(self.last_epoch) + lrs.append(lr) + + if self.cycle_momentum: + momentums = [] + for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums): + base_height = (max_momentum - base_momentum) * scale_factor + if self.scale_mode == 'cycle': + momentum = max_momentum - base_height * self.scale_fn(cycle) + else: + momentum = max_momentum - base_height * self.scale_fn(self.last_epoch) + momentums.append(momentum) + for param_group, momentum in zip(self.optimizer.param_groups, momentums): + param_group['momentum'] = momentum + + return lrs + + def state_dict(self): + state = super().state_dict() + # We are dropping the `_scale_fn_ref` attribute because it is a `weakref.WeakMethod` and can't be pickled + state.pop("_scale_fn_ref") + return state + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self._init_scale_fn() + + + +class CosineAnnealingWarmRestarts(LRScheduler): + r"""Set the learning rate of each parameter group using a cosine annealing + schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` + is the number of epochs since the last restart and :math:`T_{i}` is the number + of epochs between two warm restarts in SGDR: + + .. math:: + \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 + + \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right) + + When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. + When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`. + + It has been proposed in + `SGDR: Stochastic Gradient Descent with Warm Restarts`_. + + Args: + optimizer (Optimizer): Wrapped optimizer. + T_0 (int): Number of iterations for the first restart. + T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. + eta_min (float, optional): Minimum learning rate. Default: 0. + last_epoch (int, optional): The index of last epoch. Default: -1. + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: + https://arxiv.org/abs/1608.03983 + """ + + def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False): + if T_0 <= 0 or not isinstance(T_0, int): + raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) + if T_mult < 1 or not isinstance(T_mult, int): + raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult)) + self.T_0 = T_0 + self.T_i = T_0 + self.T_mult = T_mult + self.eta_min = eta_min + self.T_cur = last_epoch + super().__init__(optimizer, last_epoch, verbose) + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 + for base_lr in self.base_lrs] + + def step(self, epoch=None): + """Step could be called after every batch update + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> iters = len(dataloader) + >>> for epoch in range(20): + >>> for i, sample in enumerate(dataloader): + >>> inputs, labels = sample['inputs'], sample['labels'] + >>> optimizer.zero_grad() + >>> outputs = net(inputs) + >>> loss = criterion(outputs, labels) + >>> loss.backward() + >>> optimizer.step() + >>> scheduler.step(epoch + i / iters) + + This function can be called in an interleaved way. + + Example: + >>> # xdoctest: +SKIP("Undefined vars") + >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult) + >>> for epoch in range(20): + >>> scheduler.step() + >>> scheduler.step(26) + >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) + """ + + if epoch is None and self.last_epoch < 0: + epoch = 0 + + if epoch is None: + epoch = self.last_epoch + 1 + self.T_cur = self.T_cur + 1 + if self.T_cur >= self.T_i: + self.T_cur = self.T_cur - self.T_i + self.T_i = self.T_i * self.T_mult + else: + if epoch < 0: + raise ValueError("Expected non-negative epoch, but got {}".format(epoch)) + if epoch >= self.T_0: + if self.T_mult == 1: + self.T_cur = epoch % self.T_0 + else: + n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) + self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1) + self.T_i = self.T_0 * self.T_mult ** (n) + else: + self.T_i = self.T_0 + self.T_cur = epoch + self.last_epoch = math.floor(epoch) + + class _enable_get_lr_call: + + def __init__(self, o): + self.o = o + + def __enter__(self): + self.o._get_lr_called_within_step = True + return self + + def __exit__(self, type, value, traceback): + self.o._get_lr_called_within_step = False + return self + + with _enable_get_lr_call(self): + for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): + param_group, lr = data + param_group['lr'] = lr + self.print_lr(self.verbose, i, lr, epoch) + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + +class OneCycleLR(LRScheduler): + r"""Sets the learning rate of each parameter group according to the + 1cycle learning rate policy. The 1cycle policy anneals the learning + rate from an initial learning rate to some maximum learning rate and then + from that maximum learning rate to some minimum learning rate much lower + than the initial learning rate. + This policy was initially described in the paper `Super-Convergence: + Very Fast Training of Neural Networks Using Large Learning Rates`_. + + The 1cycle learning rate policy changes the learning rate after every batch. + `step` should be called after a batch has been used for training. + + This scheduler is not chainable. + + Note also that the total number of steps in the cycle can be determined in one + of two ways (listed in order of precedence): + + #. A value for total_steps is explicitly provided. + #. A number of epochs (epochs) and a number of steps per epoch + (steps_per_epoch) are provided. + In this case, the number of total steps is inferred by + total_steps = epochs * steps_per_epoch + + You must either provide a value for total_steps or provide a value for both + epochs and steps_per_epoch. + + The default behaviour of this scheduler follows the fastai implementation of 1cycle, which + claims that "unpublished work has shown even better results by using only two phases". To + mimic the behaviour of the original paper instead, set ``three_phase=True``. + + Args: + optimizer (Optimizer): Wrapped optimizer. + max_lr (float or list): Upper learning rate boundaries in the cycle + for each parameter group. + total_steps (int): The total number of steps in the cycle. Note that + if a value is not provided here, then it must be inferred by providing + a value for epochs and steps_per_epoch. + Default: None + epochs (int): The number of epochs to train for. This is used along + with steps_per_epoch in order to infer the total number of steps in the cycle + if a value for total_steps is not provided. + Default: None + steps_per_epoch (int): The number of steps per epoch to train for. This is + used along with epochs in order to infer the total number of steps in the + cycle if a value for total_steps is not provided. + Default: None + pct_start (float): The percentage of the cycle (in number of steps) spent + increasing the learning rate. + Default: 0.3 + anneal_strategy (str): {'cos', 'linear'} + Specifies the annealing strategy: "cos" for cosine annealing, "linear" for + linear annealing. + Default: 'cos' + cycle_momentum (bool): If ``True``, momentum is cycled inversely + to learning rate between 'base_momentum' and 'max_momentum'. + Default: True + base_momentum (float or list): Lower momentum boundaries in the cycle + for each parameter group. Note that momentum is cycled inversely + to learning rate; at the peak of a cycle, momentum is + 'base_momentum' and learning rate is 'max_lr'. + Default: 0.85 + max_momentum (float or list): Upper momentum boundaries in the cycle + for each parameter group. Functionally, + it defines the cycle amplitude (max_momentum - base_momentum). + Note that momentum is cycled inversely + to learning rate; at the start of a cycle, momentum is 'max_momentum' + and learning rate is 'base_lr' + Default: 0.95 + div_factor (float): Determines the initial learning rate via + initial_lr = max_lr/div_factor + Default: 25 + final_div_factor (float): Determines the minimum learning rate via + min_lr = initial_lr/final_div_factor + Default: 1e4 + three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the + learning rate according to 'final_div_factor' instead of modifying the second + phase (the first two phases will be symmetrical about the step indicated by + 'pct_start'). + last_epoch (int): The index of the last batch. This parameter is used when + resuming a training job. Since `step()` should be invoked after each + batch instead of after each epoch, this number represents the total + number of *batches* computed, not the total number of epochs computed. + When last_epoch=-1, the schedule is started from the beginning. + Default: -1 + verbose (bool): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + Example: + >>> # xdoctest: +SKIP + >>> data_loader = torch.utils.data.DataLoader(...) + >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10) + >>> for epoch in range(10): + >>> for batch in data_loader: + >>> train_batch(...) + >>> optimizer.step() + >>> scheduler.step() + + + .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates: + https://arxiv.org/abs/1708.07120 + """ + def __init__(self, + optimizer, + max_lr, + total_steps=None, + epochs=None, + steps_per_epoch=None, + pct_start=0.3, + anneal_strategy='cos', + cycle_momentum=True, + base_momentum=0.85, + max_momentum=0.95, + div_factor=25., + final_div_factor=1e4, + three_phase=False, + last_epoch=-1, + verbose=False): + + # Validate optimizer + if not isinstance(optimizer, Optimizer): + raise TypeError('{} is not an Optimizer'.format( + type(optimizer).__name__)) + self.optimizer = optimizer + + # Validate total_steps + if total_steps is None and epochs is None and steps_per_epoch is None: + raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)") + elif total_steps is not None: + if total_steps <= 0 or not isinstance(total_steps, int): + raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps)) + self.total_steps = total_steps + else: + if epochs <= 0 or not isinstance(epochs, int): + raise ValueError("Expected positive integer epochs, but got {}".format(epochs)) + if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int): + raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch)) + self.total_steps = epochs * steps_per_epoch + + if three_phase: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'max_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'base_momentum', + }, + { + 'end_step': float(2 * pct_start * self.total_steps) - 2, + 'start_lr': 'max_lr', + 'end_lr': 'initial_lr', + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum', + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'min_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'max_momentum', + }, + ] + else: + self._schedule_phases = [ + { + 'end_step': float(pct_start * self.total_steps) - 1, + 'start_lr': 'initial_lr', + 'end_lr': 'max_lr', + 'start_momentum': 'max_momentum', + 'end_momentum': 'base_momentum', + }, + { + 'end_step': self.total_steps - 1, + 'start_lr': 'max_lr', + 'end_lr': 'min_lr', + 'start_momentum': 'base_momentum', + 'end_momentum': 'max_momentum', + }, + ] + + # Validate pct_start + if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float): + raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start)) + + # Validate anneal_strategy + if anneal_strategy not in ['cos', 'linear']: + raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy)) + elif anneal_strategy == 'cos': + self.anneal_func = self._annealing_cos + elif anneal_strategy == 'linear': + self.anneal_func = self._annealing_linear + + # Initialize learning rate variables + max_lrs = self._format_param('max_lr', self.optimizer, max_lr) + if last_epoch == -1: + for idx, group in enumerate(self.optimizer.param_groups): + group['initial_lr'] = max_lrs[idx] / div_factor + group['max_lr'] = max_lrs[idx] + group['min_lr'] = group['initial_lr'] / final_div_factor + + # Initialize momentum variables + self.cycle_momentum = cycle_momentum + if self.cycle_momentum: + if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults: + raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled') + self.use_beta1 = 'betas' in self.optimizer.defaults + max_momentums = self._format_param('max_momentum', optimizer, max_momentum) + base_momentums = self._format_param('base_momentum', optimizer, base_momentum) + if last_epoch == -1: + for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups): + if self.use_beta1: + group['betas'] = (m_momentum, *group['betas'][1:]) + else: + group['momentum'] = m_momentum + group['max_momentum'] = m_momentum + group['base_momentum'] = b_momentum + + super().__init__(optimizer, last_epoch, verbose) + + def _format_param(self, name, optimizer, param): + """Return correctly formatted lr/momentum for each param group.""" + if isinstance(param, (list, tuple)): + if len(param) != len(optimizer.param_groups): + raise ValueError("expected {} values for {}, got {}".format( + len(optimizer.param_groups), name, len(param))) + return param + else: + return [param] * len(optimizer.param_groups) + + def _annealing_cos(self, start, end, pct): + "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + def _annealing_linear(self, start, end, pct): + "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0." + return (end - start) * pct + start + + def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + lrs = [] + step_num = self.last_epoch + + if step_num > self.total_steps: + raise ValueError("Tried to step {} times. The specified number of total steps is {}" + .format(step_num, self.total_steps)) + + for group in self.optimizer.param_groups: + start_step = 0 + for i, phase in enumerate(self._schedule_phases): + end_step = phase['end_step'] + if step_num <= end_step or i == len(self._schedule_phases) - 1: + pct = (step_num - start_step) / (end_step - start_step) + computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct) + if self.cycle_momentum: + computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct) + break + start_step = phase['end_step'] + + lrs.append(computed_lr) + if self.cycle_momentum: + if self.use_beta1: + group['betas'] = (computed_momentum, *group['betas'][1:]) + else: + group['momentum'] = computed_momentum + + return lrs \ No newline at end of file diff --git a/python/jittor/compatibility/src/jtorch_core.cc b/python/jittor/compatibility/src/jtorch_core.cc new file mode 100644 index 00000000..1102b107 --- /dev/null +++ b/python/jittor/compatibility/src/jtorch_core.cc @@ -0,0 +1,102 @@ + +#include "pyjt/py_obj_holder.h" +#include "utils/str_utils.h" +#include "jtorch_core.h" +#include "graph.h" +#include "grad.h" +#include "ops/op_register.h" + +namespace jittor { + +void pyjt_def_all(PyObject* m); + +EXTERN_LIB void setter_use_cuda(int value); + +Device::Device(const string& name, int ordinal) : name(name) { + if (startswith(name, "cpu")) + setter_use_cuda(0); + else + setter_use_cuda(1); +} + +unordered_map grad_backup; +EXTERN_LIB void (*_var_free_hook)(Var*); +EXTERN_LIB unordered_map* _grad_backup_ptr; + +void jtorch_var_free_hook(Var* v) { + auto iter = grad_backup.find(v->id); + if (iter != grad_backup.end()) { + grad_backup.erase(iter); + } +} + +void jtorch_init() { + _var_free_hook = &jtorch_var_free_hook; + _grad_backup_ptr = &grad_backup; +} + +inline static VarPtr& get_grad(Var* v) { + return grad_backup[v->id]; +} +static auto make_binary = get_op_info("binary") + .get_constructor(); + +inline static void add_grad(VarPtr& a, VarPtr&& b) { + if (!a) a = move(b); + else { + a = make_binary(a, b, ns_add); + } +} + + +void grad_set(VarHolder* x, Maybe v) { + if (!v) { + grad_del(x); + return; + } + grad_backup[x->var->id] = v.ptr->var; +} + +Maybe grad_get(VarHolder* x) { + auto iter = grad_backup.find(x->var->id); + if (iter != grad_backup.end()) { + if (!iter->second.ptr) return nullptr; + return new VarHolder(iter->second.ptr); + } + return nullptr; +} + +void grad_del(VarHolder* x) { + auto iter = grad_backup.find(x->var->id); + if (iter != grad_backup.end()) + grad_backup.erase(iter); +} + +void backward(VarHolder* x) { + vector gnodes({x->var}); + bfs_backward(gnodes, [&](Node* node) { + if (node->is_stop_grad()) + return false; + return true; + }); + vector targets; + for (auto* node : gnodes) { + if (node->is_var() && node->flags.get(NodeFlags::_th_require_grad)) + targets.push_back(node->var()); + } + auto grads = grad(x->var, targets); + for (int i=0; im_doc = "Inner c++ core of jtorch"; + jittor::pyjt_def_all(m); +} +PYJT_MODULE_INIT(jtorch_core); diff --git a/python/jittor/compatibility/src/jtorch_core.h b/python/jittor/compatibility/src/jtorch_core.h new file mode 100644 index 00000000..36de6522 --- /dev/null +++ b/python/jittor/compatibility/src/jtorch_core.h @@ -0,0 +1,40 @@ +#pragma once +#include "common.h" +#include "var_holder.h" +#include "misc/fast_shared_ptr.h" + +namespace jittor { + +// @pyjt(device) +// @attrs(heaptype) +struct Device { + string name; + + // @pyjt(__init__) + Device(const string& name, int ordinal=0); + // @pyjt(__get__type, __str__) + inline string get_type() {return name;} + // @pyjt(__get__index) + inline int index() {return 0;} +}; + +// @pyjt(backward) +void backward(VarHolder* x); + +// @pyjt(grad_set) +void grad_set(VarHolder* x, Maybe v); +// @pyjt(grad_get) +Maybe grad_get(VarHolder* x); +// @pyjt(grad_del) +void grad_del(VarHolder* x); + +// @pyjt(retain_grad_set) +inline void retain_grad_set(VarHolder* x, bool v) { + x->var->flags.set(NodeFlags::_th_require_grad, v); +} +// @pyjt(retain_grad_get) +inline bool retain_grad_get(VarHolder* x) { + return x->var->flags.get(NodeFlags::_th_require_grad); +} + +} \ No newline at end of file diff --git a/python/jittor/compatibility/test/test_conflict_func.py b/python/jittor/compatibility/test/test_conflict_func.py new file mode 100644 index 00000000..97bd7d8f --- /dev/null +++ b/python/jittor/compatibility/test/test_conflict_func.py @@ -0,0 +1,25 @@ +import unittest +import numpy as np +import torch +import jittor as jt + +class TestConflictFunc(unittest.TestCase): + def test_max(self): + a = torch.Tensor([1,4,2]) + assert a.max() == 4 + v, k = a.max(dim=0) + assert v==4 and k==1 + + def test_argsort(self): + a = torch.Tensor([1,4,2]) + k = a.argsort() + assert jt.all_equal(k, [0,2,1]) + + with jt.flag_scope(th_mode=0): + k, v = a.argsort() + assert jt.all_equal(k, [0,2,1]) + + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/compatibility/test/test_function.py b/python/jittor/compatibility/test/test_function.py new file mode 100644 index 00000000..9959dbae --- /dev/null +++ b/python/jittor/compatibility/test/test_function.py @@ -0,0 +1,58 @@ +import unittest +import numpy as np +import torch + +class TestFunction(unittest.TestCase): + def test_example1(self): + import jtorch + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x + + a = jtorch.array(3.0) + a.requires_grad = True + b = jtorch.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + (c+d*3).backward() + assert a.grad.data == 4 + assert b.grad.data == 9 + + def test_example2(self): + import jtorch as jt + from jtorch import Function + + class MyFunc(Function): + @staticmethod + def forward(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + @staticmethod + def backward(self, grad0, grad1): + assert grad1 is None + return grad0 * self.y, None + a = jt.array(3.0) + a.requires_grad = True + b = jt.array(4.0) + b.requires_grad = True + func = MyFunc.apply + c,d = func(a, b) + d.stop_grad() + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 0 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/compatibility/test/test_misc.py b/python/jittor/compatibility/test/test_misc.py new file mode 100644 index 00000000..00bf1b70 --- /dev/null +++ b/python/jittor/compatibility/test/test_misc.py @@ -0,0 +1,24 @@ +import unittest +import numpy as np +import torch + +class TestMisc(unittest.TestCase): + def test_update_grad(self): + class Net(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Parameter(torch.Tensor([1.0, 2.0])) + net = Net() + assert(net.a.requires_grad) + net.load_state_dict({"a": torch.Tensor([3.0, 4.0])}) + assert(net.a.requires_grad) + + def test_reshape(self): + a = torch.ones(3,3) + a.requires_grad = True + b = torch.reshape(a, [9]) + assert b.requires_grad == True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/compatibility/test/test_tutorial.py b/python/jittor/compatibility/test/test_tutorial.py new file mode 100644 index 00000000..92c087c7 --- /dev/null +++ b/python/jittor/compatibility/test/test_tutorial.py @@ -0,0 +1,56 @@ +import unittest +import numpy as np +import os +import subprocess as sp +import sys + +def check_two(cmd, parser=None, checker=None): + jtorch_out = sp.getoutput(cmd) + print("=========JTORCH OUT==========") + print(jtorch_out) + torch_out = sp.getoutput("PYTHONPATH= "+cmd) + print("=========TORCH OUT==========") + print(torch_out) + if parser: + torch_out = parser(torch_out) + jtorch_out = parser(jtorch_out) + if checker: + checker(torch_out, jtorch_out) + else: + assert torch_out == jtorch_out + return jtorch_out, torch_out + +jtorch_path = os.path.join(os.path.dirname(__file__), "..") +# come from https://pytorch.org/tutorials/beginner/pytorch_with_examples.html +class TestTutorial(unittest.TestCase): + def test_auto_grad1(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad1.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad2(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad2.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad3(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad3.py", + parser=lambda s: np.array(s.split())[[-9,-7,-4,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad4(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad4.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad5(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad5_optim.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) + def test_auto_grad6(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad6_module.py", + parser=lambda s: np.array(s.split())[[-10,-8,-5,-2]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-4)) + def test_auto_grad7(self): + check_two(f"{sys.executable} {jtorch_path}/tutorial/auto_grad7_dynet.py", + parser=lambda s: np.array(s.split())[[-13,-10,-7,-3]].astype(float), + checker=lambda a,b: np.testing.assert_allclose(a, b, atol=1e-2)) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad1.py b/python/jittor/compatibility/tutorial/auto_grad1.py new file mode 100644 index 00000000..60a090ad --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad1.py @@ -0,0 +1,44 @@ +import torch +import math + +dtype = torch.float +device = torch.device("cpu") +# device = torch.device("cuda:0") # Uncomment this to run on GPU + +# Create random input and output data +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) + +# Randomly initialize weights +a = torch.randn((), device=device, dtype=dtype) +b = torch.randn((), device=device, dtype=dtype) +c = torch.randn((), device=device, dtype=dtype) +d = torch.randn((), device=device, dtype=dtype) + +learning_rate = 1e-6 +for t in range(20000): + # Forward pass: compute predicted y + y_pred = a + b * x + c * x ** 2 + d * x ** 3 + + # Compute and print loss + loss = (y_pred - y).pow(2).sum().item() + if t % 1000 == 999: + print(t, loss) + + # Backprop to compute gradients of a, b, c, d with respect to loss + grad_y_pred = 2.0 * (y_pred - y) + grad_a = grad_y_pred.sum() + grad_b = (grad_y_pred * x).sum() + grad_c = (grad_y_pred * x ** 2).sum() + grad_d = (grad_y_pred * x ** 3).sum() + + # Update weights using gradient descent + a -= learning_rate * grad_a + b -= learning_rate * grad_b + c -= learning_rate * grad_c + d -= learning_rate * grad_d + # print(t, torch.liveness_info()) + # torch.sync_all() + + +print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad2.py b/python/jittor/compatibility/tutorial/auto_grad2.py new file mode 100644 index 00000000..a3bbc9a8 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad2.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- +import torch +import math + +dtype = torch.float +device = torch.device("cpu") +# device = torch.device("cuda:0") # Uncomment this to run on GPU + +# Create Tensors to hold input and outputs. +# By default, requires_grad=False, which indicates that we do not need to +# compute gradients with respect to these Tensors during the backward pass. +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) + +# Create random Tensors for weights. For a third order polynomial, we need +# 4 weights: y = a + b x + c x^2 + d x^3 +# Setting requires_grad=True indicates that we want to compute gradients with +# respect to these Tensors during the backward pass. +a = torch.randn((), device=device, dtype=dtype, requires_grad=True) +b = torch.randn((), device=device, dtype=dtype, requires_grad=True) +c = torch.randn((), device=device, dtype=dtype, requires_grad=True) +d = torch.randn((), device=device, dtype=dtype, requires_grad=True) + +learning_rate = 1e-6 +for t in range(20000): + # Forward pass: compute predicted y using operations on Tensors. + y_pred = a + b * x + c * x ** 2 + d * x ** 3 + # print(y_pred.requires_grad) + # y_pred.requires_grad = False + + # Compute and print loss using operations on Tensors. + # Now loss is a Tensor of shape (1,) + # loss.item() gets the scalar value held in the loss. + loss = (y_pred - y).pow(2).sum() + if t % 1000 == 990: + print(t, loss.item()) + + # Use autograd to compute the backward pass. This call will compute the + # gradient of loss with respect to all Tensors with requires_grad=True. + # After this call a.grad, b.grad. c.grad and d.grad will be Tensors holding + # the gradient of the loss with respect to a, b, c, d respectively. + # torch.backward(loss) + loss.backward() + + # Manually update weights using gradient descent. Wrap in torch.no_grad() + # because weights have requires_grad=True, but we don't need to track this + # in autograd. + with torch.no_grad(): + a -= learning_rate * a.grad + b -= learning_rate * b.grad + c -= learning_rate * c.grad + d -= learning_rate * d.grad + + # Manually zero the gradients after updating weights + a.grad = None + b.grad = None + c.grad = None + d.grad = None + +print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad3.py b/python/jittor/compatibility/tutorial/auto_grad3.py new file mode 100644 index 00000000..654ec447 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad3.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +class LegendrePolynomial3(torch.autograd.Function): + """ + We can implement our own custom autograd Functions by subclassing + torch.autograd.Function and implementing the forward and backward passes + which operate on Tensors. + """ + + @staticmethod + def forward(ctx, input): + """ + In the forward pass we receive a Tensor containing the input and return + a Tensor containing the output. ctx is a context object that can be used + to stash information for backward computation. You can cache arbitrary + objects for use in the backward pass using the ctx.save_for_backward method. + """ + ctx.save_for_backward(input) + return 0.5 * (5 * input ** 3 - 3 * input) + + @staticmethod + def backward(ctx, grad_output): + """ + In the backward pass we receive a Tensor containing the gradient of the loss + with respect to the output, and we need to compute the gradient of the loss + with respect to the input. + """ + input, = ctx.saved_tensors + return grad_output * 1.5 * (5 * input ** 2 - 1) + + +dtype = torch.float +device = torch.device("cpu") +# device = torch.device("cuda:0") # Uncomment this to run on GPU + +# Create Tensors to hold input and outputs. +# By default, requires_grad=False, which indicates that we do not need to +# compute gradients with respect to these Tensors during the backward pass. +x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype) +y = torch.sin(x) + +# Create random Tensors for weights. For this example, we need +# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized +# not too far from the correct result to ensure convergence. +# Setting requires_grad=True indicates that we want to compute gradients with +# respect to these Tensors during the backward pass. +a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) +b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True) +c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True) +d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True) + +learning_rate = 5e-6 +for t in range(2000): + # To apply our Function, we use Function.apply method. We alias this as 'P3'. + P3 = LegendrePolynomial3.apply + + # Forward pass: compute predicted y using operations; we compute + # P3 using our custom autograd operation. + y_pred = a + b * P3(c + d * x) + + # Compute and print loss + loss = (y_pred - y).pow(2).sum() + if t % 100 == 99: + print(t, loss.item()) + + # Use autograd to compute the backward pass. + loss.backward() + + # Update weights using gradient descent + with torch.no_grad(): + a -= learning_rate * a.grad + b -= learning_rate * b.grad + c -= learning_rate * c.grad + d -= learning_rate * d.grad + + # Manually zero the gradients after updating weights + a.grad = None + b.grad = None + c.grad = None + d.grad = None + +print(f'Result: y = {a.item()} + {b.item()} * P3( {c.item()} + {d.item()} x)') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad4.py b/python/jittor/compatibility/tutorial/auto_grad4.py new file mode 100644 index 00000000..062d0b0e --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad4.py @@ -0,0 +1,71 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# For this example, the output y is a linear function of (x, x^2, x^3), so +# we can consider it as a linear layer neural network. Let's prepare the +# tensor (x, x^2, x^3). +p = torch.tensor([1, 2, 3]) +xx = x.unsqueeze(-1).pow(p) + +# In the above code, x.unsqueeze(-1) has shape (2000, 1), and p has shape +# (3,), for this case, broadcasting semantics will apply to obtain a tensor +# of shape (2000, 3) + +# Use the nn package to define our model as a sequence of layers. nn.Sequential +# is a Module which contains other Modules, and applies them in sequence to +# produce its output. The Linear Module computes output from input using a +# linear function, and holds internal Tensors for its weight and bias. +# The Flatten layer flatens the output of the linear layer to a 1D tensor, +# to match the shape of `y`. +model = torch.nn.Sequential( + torch.nn.Linear(3, 1), + torch.nn.Flatten(0, 1) +) + +# The nn package also contains definitions of popular loss functions; in this +# case we will use Mean Squared Error (MSE) as our loss function. +loss_fn = torch.nn.MSELoss(reduction='sum') +# print(model[0].weight.requires_grad) + +learning_rate = 1e-6 +for t in range(8000): + + # Forward pass: compute predicted y by passing x to the model. Module objects + # override the __call__ operator so you can call them like functions. When + # doing so you pass a Tensor of input data to the Module and it produces + # a Tensor of output data. + y_pred = model(xx) + + # Compute and print loss. We pass Tensors containing the predicted and true + # values of y, and the loss function returns a Tensor containing the + # loss. + loss = loss_fn(y_pred, y) + if t % 1000 == 999: + print(t, loss.item()) + + # Zero the gradients before running the backward pass. + model.zero_grad() + + # Backward pass: compute gradient of the loss with respect to all the learnable + # parameters of the model. Internally, the parameters of each Module are stored + # in Tensors with requires_grad=True, so this call will compute gradients for + # all learnable parameters in the model. + loss.backward() + + # Update the weights using gradient descent. Each parameter is a Tensor, so + # we can access its gradients like we did before. + with torch.no_grad(): + for param in model.parameters(): + param -= learning_rate * param.grad + +# You can access the first layer of `model` like accessing the first item of a list +linear_layer = model[0] + +# For linear layer, its parameters are stored as `weight` and `bias`. +print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad5_optim.py b/python/jittor/compatibility/tutorial/auto_grad5_optim.py new file mode 100644 index 00000000..04949320 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad5_optim.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# Prepare the input tensor (x, x^2, x^3). +p = torch.tensor([1, 2, 3]) +xx = x.unsqueeze(-1).pow(p) + +# Use the nn package to define our model and loss function. +model = torch.nn.Sequential( + torch.nn.Linear(3, 1), + torch.nn.Flatten(0, 1) +) +loss_fn = torch.nn.MSELoss(reduction='sum') + +# Use the optim package to define an Optimizer that will update the weights of +# the model for us. Here we will use RMSprop; the optim package contains many other +# optimization algorithms. The first argument to the RMSprop constructor tells the +# optimizer which Tensors it should update. +learning_rate = 1e-3 +optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate) +for t in range(8000): + # Forward pass: compute predicted y by passing x to the model. + y_pred = model(xx) + + # Compute and print loss. + loss = loss_fn(y_pred, y) + if t % 1000 == 999: + print(t, loss.item()) + + # Before the backward pass, use the optimizer object to zero all of the + # gradients for the variables it will update (which are the learnable + # weights of the model). This is because by default, gradients are + # accumulated in buffers( i.e, not overwritten) whenever .backward() + # is called. Checkout docs of torch.autograd.backward for more details. + optimizer.zero_grad() + + # Backward pass: compute gradient of the loss with respect to model + # parameters + loss.backward() + + # Calling the step function on an Optimizer makes an update to its + # parameters + optimizer.step() + + +linear_layer = model[0] +print(f'Result: y = {linear_layer.bias.item()} + {linear_layer.weight[:, 0].item()} x + {linear_layer.weight[:, 1].item()} x^2 + {linear_layer.weight[:, 2].item()} x^3') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad6_module.py b/python/jittor/compatibility/tutorial/auto_grad6_module.py new file mode 100644 index 00000000..a240e2b5 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad6_module.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- +import torch +import math + + +class Polynomial3(torch.nn.Module): + def __init__(self): + """ + In the constructor we instantiate four parameters and assign them as + member parameters. + """ + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) + + def forward(self, x): + """ + In the forward function we accept a Tensor of input data and we must return + a Tensor of output data. We can use Modules defined in the constructor as + well as arbitrary operators on Tensors. + """ + return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 + + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3' + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# Construct our model by instantiating the class defined above +model = Polynomial3() + +# Construct our loss function and an Optimizer. The call to model.parameters() +# in the SGD constructor will contain the learnable parameters (defined +# with torch.nn.Parameter) which are members of the model. +criterion = torch.nn.MSELoss(reduction='sum') +optimizer = torch.optim.SGD(model.parameters(), lr=1e-6) +for t in range(8000): + # Forward pass: Compute predicted y by passing x to the model + y_pred = model(x) + + # Compute and print loss + loss = criterion(y_pred, y) + if t % 1000 == 999: + print(t, loss.item()) + + # Zero gradients, perform a backward pass, and update the weights. + optimizer.zero_grad() + loss.backward() + optimizer.step() + +print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/auto_grad7_dynet.py b/python/jittor/compatibility/tutorial/auto_grad7_dynet.py new file mode 100644 index 00000000..fa954771 --- /dev/null +++ b/python/jittor/compatibility/tutorial/auto_grad7_dynet.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +import random +import torch +import math + + +class DynamicNet(torch.nn.Module): + def __init__(self): + """ + In the constructor we instantiate five parameters and assign them as members. + """ + super().__init__() + self.a = torch.nn.Parameter(torch.randn(())) + self.b = torch.nn.Parameter(torch.randn(())) + self.c = torch.nn.Parameter(torch.randn(())) + self.d = torch.nn.Parameter(torch.randn(())) + self.e = torch.nn.Parameter(torch.randn(())) + + def forward(self, x): + """ + For the forward pass of the model, we randomly choose either 4, 5 + and reuse the e parameter to compute the contribution of these orders. + + Since each forward pass builds a dynamic computation graph, we can use normal + Python control-flow operators like loops or conditional statements when + defining the forward pass of the model. + + Here we also see that it is perfectly safe to reuse the same parameter many + times when defining a computational graph. + """ + y = self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3 + for exp in range(4, random.randint(4, 6)): + y = y + self.e * x ** exp + return y + + def string(self): + """ + Just like any class in Python, you can also define custom method on PyTorch modules + """ + return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3 + {self.e.item()} x^4 ? + {self.e.item()} x^5 ?' + + +# Create Tensors to hold input and outputs. +x = torch.linspace(-math.pi, math.pi, 2000) +y = torch.sin(x) + +# Construct our model by instantiating the class defined above +model = DynamicNet() + +# Construct our loss function and an Optimizer. Training this strange model with +# vanilla stochastic gradient descent is tough, so we use momentum +criterion = torch.nn.MSELoss(reduction='sum') +optimizer = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9) +for t in range(60000): + # Forward pass: Compute predicted y by passing x to the model + y_pred = model(x) + + # Compute and print loss + loss = criterion(y_pred, y) + if t % 2000 == 1999: + print(t, loss.item()) + + # Zero gradients, perform a backward pass, and update the weights. + optimizer.zero_grad() + loss.backward() + optimizer.step() + # print(torch.liveness_info()) + +print(f'Result: {model.string()}') \ No newline at end of file diff --git a/python/jittor/compatibility/tutorial/quickstart.py b/python/jittor/compatibility/tutorial/quickstart.py new file mode 100644 index 00000000..f0401a9b --- /dev/null +++ b/python/jittor/compatibility/tutorial/quickstart.py @@ -0,0 +1,106 @@ +import torch +from torch import nn +# from jtorch.utils import DataLoader +from torch.utils.data import DataLoader +from torchvision import datasets +from torchvision.transforms import ToTensor + +# Download training data from open datasets. +training_data = datasets.FashionMNIST( + root="data", + train=True, + download=True, + transform=ToTensor(), +) + +# Download test data from open datasets. +test_data = datasets.FashionMNIST( + root="data", + train=False, + download=True, + transform=ToTensor(), +) + +batch_size = 64 + +# Create data loaders. +train_dataloader = DataLoader(training_data, batch_size=batch_size) +test_dataloader = DataLoader(test_data, batch_size=batch_size) + +print(len(train_dataloader)) +for X, y in test_dataloader: + print(f"Shape of X [N, C, H, W]: {X.shape}") + print(f"Shape of y: {y.shape} {y.dtype}") + break + +# Get cpu or gpu device for training. +device = "cuda" if torch.cuda.is_available() else "cpu" +print(f"Using {device} device") + +# Define model +class NeuralNetwork(nn.Module): + def __init__(self): + super(NeuralNetwork, self).__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28*28, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 10) + ) + + def forward(self, x): + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits + +model = NeuralNetwork().to(device) +print(model) + + +loss_fn = nn.CrossEntropyLoss() +optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + +def train(dataloader, model, loss_fn, optimizer): + size = len(dataloader.dataset) + model.train() + for batch, (X, y) in enumerate(dataloader): + X, y = X.to(device), y.to(device) + + # Compute prediction error + pred = model(X) + loss = loss_fn(pred, y) + + # Backpropagation + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if batch % 100 == 0: + loss, current = loss.item(), batch * len(X) + print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]") + +def test(dataloader, model, loss_fn): + size = len(dataloader.dataset) + num_batches = len(dataloader) + model.eval() + test_loss, correct = 0, 0 + with torch.no_grad(): + for X, y in dataloader: + X, y = X.to(device), y.to(device) + pred = model(X) + test_loss += loss_fn(pred, y).item() + correct += (pred.argmax(1) == y).type(torch.float).sum().item() + test_loss /= num_batches + correct /= size + print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n") + + +epochs = 5 +test(test_dataloader, model, loss_fn) +for t in range(epochs): + print(f"Epoch {t+1}\n-------------------------------") + train(train_dataloader, model, loss_fn, optimizer) + test(test_dataloader, model, loss_fn) +print("Done!") \ No newline at end of file diff --git a/python/jittor/compatibility/utils/__init__.py b/python/jittor/compatibility/utils/__init__.py new file mode 100644 index 00000000..ac2c2bd8 --- /dev/null +++ b/python/jittor/compatibility/utils/__init__.py @@ -0,0 +1,5 @@ +cpp_extension = None +_flatten_dense_tensors = None +_unflatten_dense_tensors = None + +tensorboard = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/_pytree.py b/python/jittor/compatibility/utils/_pytree.py new file mode 100644 index 00000000..c3118964 --- /dev/null +++ b/python/jittor/compatibility/utils/_pytree.py @@ -0,0 +1,3 @@ +#TODO: Implement this +_register_pytree_node = None +_dict_flatten = None \ No newline at end of file diff --git a/python/jittor/compatibility/utils/checkpoint.py b/python/jittor/compatibility/utils/checkpoint.py new file mode 100644 index 00000000..ba3c3e8e --- /dev/null +++ b/python/jittor/compatibility/utils/checkpoint.py @@ -0,0 +1,8 @@ +detach_variable = None + + +def checkpoint( + *args, + **kwargs +): + pass diff --git a/python/jittor/compatibility/utils/data.py b/python/jittor/compatibility/utils/data.py new file mode 100644 index 00000000..5fcfcaa6 --- /dev/null +++ b/python/jittor/compatibility/utils/data.py @@ -0,0 +1,137 @@ +import jittor as jt +import jittor.dataset +from jittor.dataset import Dataset as JDataset + +from collections import namedtuple +from typing import Any, Callable, Iterable, Optional, Sequence, Union + + +class Dataset: + def __getitem__(self, index): + raise NotImplementedError + +class IterableDataset: + def __iter__(self): + raise NotImplementedError + + +class DataLoader(JDataset): + def __init__(self, dataset, + batch_size: Optional[int] = 1, + shuffle: Optional[bool] = False, + sampler = None, + batch_sampler = None, + num_workers: int = 0, + collate_fn = None, + pin_memory: bool = False, + drop_last: bool = False, + timeout: float = 0, + worker_init_fn = None, + multiprocessing_context=None, + generator=None, + *, prefetch_factor: int = 2, + persistent_workers: bool = False, + pin_memory_device: str = "") -> None: + super().__init__(batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + drop_last=drop_last) + + unsupported_kwargs = { + "batch_sampler": batch_sampler, + "pin_memory": pin_memory, + "timeout": timeout, + "worker_init_fn": worker_init_fn, + "multiprocessing_context": multiprocessing_context, + "generator": generator, + "persistent_workers": persistent_workers, + "pin_memory_device": pin_memory_device + } + for kwarg, value in unsupported_kwargs.items(): + if value: + jt.LOG.w(f"Not implemented Dataloader kwarg: {kwarg}") + + self.dataset = dataset + self.collate_fn = collate_fn + self.sampler = sampler + + if not isinstance(dataset, IterableDataset): + self.total_len = len(dataset) + else: + # TODO: support multiple worker for iterable dataset + assert(num_workers == 0) + + def collate_batch(self, batch): + if self.collate_fn is not None: + return self.collate_fn(batch) + else: + return super().collate_batch(batch) + + def __getitem__(self, i): + return self.dataset[i] + + def __iter__(self): + if isinstance(self.dataset, IterableDataset): + return self.inner_iter() + else: + return super().__iter__() + + def inner_iter(self): + current_batch = [] + + if jt.world_size > 1: + assert self.batch_size % jt.world_size == 0, \ + f"IterableDataset does not support a batch size ({self.batch_size}) that is not evenly divisible by the number of processes f{jt.world_size}" + real_batch_size = int(self.batch_size / jt.world_size) + else: + real_batch_size = self.batch_size + + for element in self.dataset: + current_batch.append(element) + + if len(current_batch) == real_batch_size: + current_batch = self.collate_batch(current_batch) + current_batch = self.to_jittor(current_batch) + yield current_batch + current_batch = [] + + if not self.drop_last and len(current_batch) > 0: + current_batch = self.collate_batch(current_batch) + yield self.to_jittor(current_batch) + +# def get_worker_info(): +# # always return the fake worker info +# return namedtuple('WorkerInfo', 'id num_workers')(0, 1) + +# class RandomSampler(jt.dataset.RandomSampler): +# def __init__(self, dataset, generator=None, **kwargs): +# super().__init__(dataset, **kwargs) + +# def __iter__(self): +# if getattr(self.dataset, "support_random_access", True): +# return super().__iter__() +# else: +# self.dataset.shuffle() +# return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) + +# class DistributedSampler(jt.dataset.Sampler): +# def __init__(self, sampler: RandomSampler): +# assert(isinstance(sampler, RandomSampler)) +# self.sampler = sampler + +# def set_epoch(self, epoch: int): +# ### do nothing, let jittor's inner dataset handle +# pass + +# def __iter__(self): +# return self.sampler.__iter__() + +# def __len__(self): +# return self.sampler.__len__() + +# BatchSampler = jt.dataset.BatchSampler +# Sampler = jt.dataset.Sampler +# SequentialSampler = jt.dataset.SequentialSampler +# SubsetRandomSampler = jt.dataset.SubsetRandomSampler + +# TensorDataset = Dataset diff --git a/python/jittor/compatibility/utils/dtype.py b/python/jittor/compatibility/utils/dtype.py new file mode 100644 index 00000000..41728383 --- /dev/null +++ b/python/jittor/compatibility/utils/dtype.py @@ -0,0 +1,9 @@ +from typing import Callable, Union +Dtype = Union[Callable, str] + +def get_string_dtype(dtype): + if callable(dtype): + dtype = dtype.__name__ + if not isinstance(dtype, str): + raise ValueError(f"dtype is expected to be str, python type function, or jittor type function, but got {dtype}.") + return dtype \ No newline at end of file diff --git a/python/jittor/compatibility/utils/hooks.py b/python/jittor/compatibility/utils/hooks.py new file mode 100644 index 00000000..e69de29b diff --git a/python/jittor/compatibility/utils/pip_publish.py b/python/jittor/compatibility/utils/pip_publish.py new file mode 100644 index 00000000..72ff245f --- /dev/null +++ b/python/jittor/compatibility/utils/pip_publish.py @@ -0,0 +1,34 @@ +import os +import glob +import shutil +import sys + +home_path = os.path.join(os.path.dirname(__file__), "..", "..", "..") +home_path = os.path.abspath(home_path) + +def callback(func, path, exc_info): + print(f"remove \"{path}\" failed.") + +def rmtree(path): + if os.path.isdir(path): + print(f"remove \"{path}\" recursive.") + shutil.rmtree(path, onerror=callback) + +def remove_tmpfile(): + dist_file = home_path+"/dist" + egg_file = glob.glob(home_path+"/**/*egg-info") + rmtree(dist_file) + for e in egg_file: + rmtree(e) + +def run_cmd(cmd): + print("[CMD]", cmd) + assert os.system(cmd)==0 + +os.chdir(home_path) +remove_tmpfile() + +run_cmd(f"{sys.executable} ./setup.py sdist") +run_cmd(f"{sys.executable} -m twine upload dist/*") + +remove_tmpfile() \ No newline at end of file diff --git a/python/jittor/compatibility/vision/_internally_replaced_utils.py b/python/jittor/compatibility/vision/_internally_replaced_utils.py new file mode 100644 index 00000000..748fa2ea --- /dev/null +++ b/python/jittor/compatibility/vision/_internally_replaced_utils.py @@ -0,0 +1,46 @@ +import importlib.machinery +import os + + +def _download_file_from_remote_location(fpath: str, url: str) -> None: + pass + + +def _is_remote_location_available() -> bool: + return False + + +def _get_extension_path(lib_name): + + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' + raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec(lib_name) + if ext_specs is None: + raise ImportError + + return ext_specs.origin \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/__init__.py b/python/jittor/compatibility/vision/datasets/__init__.py new file mode 100644 index 00000000..d04187f1 --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/__init__.py @@ -0,0 +1,9 @@ +from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST + +__all__ = ( + "EMNIST", + "FashionMNIST", + "QMNIST", + "MNIST", + "KMNIST", +) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/mnist.py b/python/jittor/compatibility/vision/datasets/mnist.py new file mode 100644 index 00000000..dfc3787b --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/mnist.py @@ -0,0 +1,558 @@ +import codecs +import os +import os.path +import shutil +import string +import sys +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple +from urllib.error import URLError + +import numpy as np +import torch +from PIL import Image + +from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg +from .vision import VisionDataset + + +class MNIST(VisionDataset): + """`MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``MNIST/raw/train-images-idx3-ubyte`` + and ``MNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = [ + "http://yann.lecun.com/exdb/mnist/", + "https://ossci-datasets.s3.amazonaws.com/mnist/", + ] + + resources = [ + ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), + ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), + ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"), + ] + + training_file = "training.pt" + test_file = "test.pt" + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] + + @property + def train_labels(self): + warnings.warn("train_labels has been renamed targets") + return self.targets + + @property + def test_labels(self): + warnings.warn("test_labels has been renamed targets") + return self.targets + + @property + def train_data(self): + warnings.warn("train_data has been renamed data") + return self.data + + @property + def test_data(self): + warnings.warn("test_data has been renamed data") + return self.data + + def __init__( + self, + root: str, + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super().__init__(root, transform=transform, target_transform=target_transform) + self.train = train # training set or test set + + if self._check_legacy_exist(): + self.data, self.targets = self._load_legacy_data() + return + + if download: + self.download() + + if not self._check_exists(): + raise RuntimeError("Dataset not found. You can use download=True to download it") + + self.data, self.targets = self._load_data() + + def _check_legacy_exist(self): + processed_folder_exists = os.path.exists(self.processed_folder) + if not processed_folder_exists: + return False + + return all( + check_integrity(os.path.join(self.processed_folder, file)) for file in (self.training_file, self.test_file) + ) + + def _load_legacy_data(self): + # This is for BC only. We no longer cache the data in a custom binary, but simply read from the raw data + # directly. + data_file = self.training_file if self.train else self.test_file + return torch.load(os.path.join(self.processed_folder, data_file)) + + def _load_data(self): + image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte" + data = read_image_file(os.path.join(self.raw_folder, image_file)) + + label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte" + targets = read_label_file(os.path.join(self.raw_folder, label_file)) + + return data, targets + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], int(self.targets[index]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img.numpy(), mode="L") + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self) -> int: + return len(self.data) + + @property + def raw_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "raw") + + @property + def processed_folder(self) -> str: + return os.path.join(self.root, self.__class__.__name__, "processed") + + @property + def class_to_idx(self) -> Dict[str, int]: + return {_class: i for i, _class in enumerate(self.classes)} + + def _check_exists(self) -> bool: + return all( + check_integrity(os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0])) + for url, _ in self.resources + ) + + def download(self) -> None: + """Download the MNIST data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + + # download files + for filename, md5 in self.resources: + for mirror in self.mirrors: + url = f"{mirror}{filename}" + try: + print(f"Downloading {url}") + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) + except URLError as error: + print(f"Failed to download (trying next):\n{error}") + continue + finally: + print() + break + else: + raise RuntimeError(f"Error downloading {filename}") + + def extra_repr(self) -> str: + split = "Train" if self.train is True else "Test" + return f"Split: {split}" + + +class FashionMNIST(MNIST): + """`Fashion-MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``FashionMNIST/raw/train-images-idx3-ubyte`` + and ``FashionMNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = ["http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"] + + resources = [ + ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), + ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), + ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), + ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310"), + ] + classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"] + + +class KMNIST(MNIST): + """`Kuzushiji-MNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``KMNIST/raw/train-images-idx3-ubyte`` + and ``KMNIST/raw/t10k-images-idx3-ubyte`` exist. + train (bool, optional): If True, creates dataset from ``train-images-idx3-ubyte``, + otherwise from ``t10k-images-idx3-ubyte``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + mirrors = ["http://codh.rois.ac.jp/kmnist/dataset/kmnist/"] + + resources = [ + ("train-images-idx3-ubyte.gz", "bdb82020997e1d708af4cf47b453dcf7"), + ("train-labels-idx1-ubyte.gz", "e144d726b3acfaa3e44228e80efcd344"), + ("t10k-images-idx3-ubyte.gz", "5c965bf0a639b31b8f53240b1b52f4d7"), + ("t10k-labels-idx1-ubyte.gz", "7320c461ea6c1c855c0b718fb2a4b134"), + ] + classes = ["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"] + + +class EMNIST(MNIST): + """`EMNIST `_ Dataset. + + Args: + root (string): Root directory of dataset where ``EMNIST/raw/train-images-idx3-ubyte`` + and ``EMNIST/raw/t10k-images-idx3-ubyte`` exist. + split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``, + ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies + which one to use. + train (bool, optional): If True, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download (bool, optional): If True, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + url = "https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" + md5 = "58c8d27c78d21e728a6bc7b3cc06412e" + splits = ("byclass", "bymerge", "balanced", "letters", "digits", "mnist") + # Merged Classes assumes Same structure for both uppercase and lowercase version + _merged_classes = {"c", "i", "j", "k", "l", "m", "o", "p", "s", "u", "v", "w", "x", "y", "z"} + _all_classes = set(string.digits + string.ascii_letters) + classes_split_dict = { + "byclass": sorted(list(_all_classes)), + "bymerge": sorted(list(_all_classes - _merged_classes)), + "balanced": sorted(list(_all_classes - _merged_classes)), + "letters": ["N/A"] + list(string.ascii_lowercase), + "digits": list(string.digits), + "mnist": list(string.digits), + } + + def __init__(self, root: str, split: str, **kwargs: Any) -> None: + self.split = verify_str_arg(split, "split", self.splits) + self.training_file = self._training_file(split) + self.test_file = self._test_file(split) + super().__init__(root, **kwargs) + self.classes = self.classes_split_dict[self.split] + + @staticmethod + def _training_file(split) -> str: + return f"training_{split}.pt" + + @staticmethod + def _test_file(split) -> str: + return f"test_{split}.pt" + + @property + def _file_prefix(self) -> str: + return f"emnist-{self.split}-{'train' if self.train else 'test'}" + + @property + def images_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-images-idx3-ubyte") + + @property + def labels_file(self) -> str: + return os.path.join(self.raw_folder, f"{self._file_prefix}-labels-idx1-ubyte") + + def _load_data(self): + return read_image_file(self.images_file), read_label_file(self.labels_file) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + + def download(self) -> None: + """Download the EMNIST data if it doesn't exist already.""" + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + + download_and_extract_archive(self.url, download_root=self.raw_folder, md5=self.md5) + gzip_folder = os.path.join(self.raw_folder, "gzip") + for gzip_file in os.listdir(gzip_folder): + if gzip_file.endswith(".gz"): + extract_archive(os.path.join(gzip_folder, gzip_file), self.raw_folder) + shutil.rmtree(gzip_folder) + + +class QMNIST(MNIST): + """`QMNIST `_ Dataset. + + Args: + root (string): Root directory of dataset whose ``raw`` + subdir contains binary files of the datasets. + what (string,optional): Can be 'train', 'test', 'test10k', + 'test50k', or 'nist' for respectively the mnist compatible + training set, the 60k qmnist testing set, the 10k qmnist + examples that match the mnist testing set, the 50k + remaining qmnist testing examples, or all the nist + digits. The default is to select 'train' or 'test' + according to the compatibility argument 'train'. + compat (bool,optional): A boolean that says whether the target + for each example is class number (for compatibility with + the MNIST dataloader) or a torch vector containing the + full qmnist information. Default=True. + download (bool, optional): If True, downloads the dataset from + the internet and puts it in root directory. If dataset is + already downloaded, it is not downloaded again. + transform (callable, optional): A function/transform that + takes in an PIL image and returns a transformed + version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform + that takes in the target and transforms it. + train (bool,optional,compatibility): When argument 'what' is + not specified, this boolean decides whether to load the + training set ot the testing set. Default: True. + """ + + subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"} + resources: Dict[str, List[Tuple[str, str]]] = { # type: ignore[assignment] + "train": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz", + "ed72d4157d28c017586c42bc6afe6370", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz", + "0058f8dd561b90ffdd0f734c6a30e5e4", + ), + ], + "test": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-images-idx3-ubyte.gz", + "1394631089c404de565df7b7aeaf9412", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-test-labels-idx2-int.gz", + "5b5b05890a5e13444e108efe57b788aa", + ), + ], + "nist": [ + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-images-idx3-ubyte.xz", + "7f124b3b8ab81486c9d8c2749c17f834", + ), + ( + "https://raw.githubusercontent.com/facebookresearch/qmnist/master/xnist-labels-idx2-int.xz", + "5ed0e788978e45d4a8bd4b7caec3d79d", + ), + ], + } + classes = [ + "0 - zero", + "1 - one", + "2 - two", + "3 - three", + "4 - four", + "5 - five", + "6 - six", + "7 - seven", + "8 - eight", + "9 - nine", + ] + + def __init__( + self, root: str, what: Optional[str] = None, compat: bool = True, train: bool = True, **kwargs: Any + ) -> None: + if what is None: + what = "train" if train else "test" + self.what = verify_str_arg(what, "what", tuple(self.subsets.keys())) + self.compat = compat + self.data_file = what + ".pt" + self.training_file = self.data_file + self.test_file = self.data_file + super().__init__(root, train, **kwargs) + + @property + def images_file(self) -> str: + (url, _), _ = self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + @property + def labels_file(self) -> str: + _, (url, _) = self.resources[self.subsets[self.what]] + return os.path.join(self.raw_folder, os.path.splitext(os.path.basename(url))[0]) + + def _check_exists(self) -> bool: + return all(check_integrity(file) for file in (self.images_file, self.labels_file)) + + def _load_data(self): + data = read_sn3_pascalvincent_tensor(self.images_file) + if data.dtype != torch.uint8: + raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}") + if data.ndimension() != 3: + raise ValueError("data should have 3 dimensions instead of {data.ndimension()}") + + targets = read_sn3_pascalvincent_tensor(self.labels_file).long() + if targets.ndimension() != 2: + raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}") + + if self.what == "test10k": + data = data[0:10000, :, :].clone() + targets = targets[0:10000, :].clone() + elif self.what == "test50k": + data = data[10000:, :, :].clone() + targets = targets[10000:, :].clone() + + return data, targets + + def download(self) -> None: + """Download the QMNIST data if it doesn't exist already. + Note that we only download what has been asked for (argument 'what'). + """ + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + split = self.resources[self.subsets[self.what]] + + for url, md5 in split: + download_and_extract_archive(url, self.raw_folder, md5=md5) + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + # redefined to handle the compat flag + img, target = self.data[index], self.targets[index] + img = Image.fromarray(img.numpy(), mode="L") + if self.transform is not None: + img = self.transform(img) + if self.compat: + target = int(target[0]) + if self.target_transform is not None: + target = self.target_transform(target) + return img, target + + def extra_repr(self) -> str: + return f"Split: {self.what}" + + +def get_int(b: bytes) -> int: + return int(codecs.encode(b, "hex"), 16) + + +SN3_PASCALVINCENT_BITSMAP = { + 8: torch.uint8, + 9: torch.int8, + 11: torch.int16, + 12: torch.int32, + 13: torch.float32, + 14: torch.float64, +} + +TORCH_TYPE_BITS = { + torch.uint8: 8, + torch.int8: 8, + torch.int16: 16, + torch.int32: 32, + torch.float32: 32, + torch.float64: 64, +} + + +def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor: + """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). + Argument may be a filename, compressed filename, or file object. + """ + # read + with open(path, "rb") as f: + data = f.read() + # parse + magic = get_int(data[0:4]) + nd = magic % 256 + ty = magic // 256 + assert 1 <= nd <= 3 + assert 8 <= ty <= 14 + torch_type = SN3_PASCALVINCENT_BITSMAP[ty] + s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)] + + num_bytes_per_value = TORCH_TYPE_BITS[torch_type] // 8 + # The MNIST format uses the big endian byte order. If the system uses little endian byte order by default, + # we need to reverse the bytes before we can read them with torch.frombuffer(). + needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1 + parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1))) + if needs_byte_reversal: + parsed = parsed.flip(0) + + assert parsed.shape[0] == np.prod(s) or not strict + return parsed.view(*s) + + +def read_label_file(path: str) -> torch.Tensor: + x = read_sn3_pascalvincent_tensor(path, strict=False) + if x.dtype != torch.uint8: + raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") + if x.ndimension() != 1: + raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}") + return x.long() + + +def read_image_file(path: str) -> torch.Tensor: + x = read_sn3_pascalvincent_tensor(path, strict=False) + if x.dtype != torch.uint8: + raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}") + if x.ndimension() != 3: + raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}") + return x \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/utils.py b/python/jittor/compatibility/vision/datasets/utils.py new file mode 100644 index 00000000..f9ae1a89 --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/utils.py @@ -0,0 +1,522 @@ +import bz2 +import contextlib +import gzip +import hashlib +import itertools +import lzma +import os +import os.path +import pathlib +import re +import sys +import tarfile +import urllib +import urllib.error +import urllib.request +import warnings +import zipfile +from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar +from urllib.parse import urlparse + +import numpy as np +import requests +import torch +from tqdm import tqdm + +from .._internally_replaced_utils import _download_file_from_remote_location, _is_remote_location_available + +USER_AGENT = "pytorch/vision" + + +def _save_response_content( + content: Iterator[bytes], + destination: str, + length: Optional[int] = None, +) -> None: + with open(destination, "wb") as fh, tqdm(total=length) as pbar: + for chunk in content: + # filter out keep-alive new chunks + if not chunk: + continue + + fh.write(chunk) + pbar.update(len(chunk)) + + +def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + _save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length) + + +def gen_bar_updater() -> Callable[[int, int, int], None]: + warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.") + pbar = tqdm(total=None) + + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str: + # Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are + # not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without + # it torchvision.datasets is unusable in these environments since we perform a MD5 check everywhere. + if sys.version_info >= (3, 9): + md5 = hashlib.md5(usedforsecurity=False) + else: + md5 = hashlib.md5() + with open(fpath, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool: + return md5 == calculate_md5(fpath, **kwargs) + + +def check_integrity(fpath: str, md5: Optional[str] = None) -> bool: + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def _get_redirect_url(url: str, max_hops: int = 3) -> str: + initial_url = url + headers = {"Method": "HEAD", "User-Agent": USER_AGENT} + + for _ in range(max_hops + 1): + with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response: + if response.url == url or response.url is None: + return url + + url = response.url + else: + raise RecursionError( + f"Request to {initial_url} exceeded {max_hops} redirects. The last redirect points to {url}." + ) + + +def _get_google_drive_file_id(url: str) -> Optional[str]: + parts = urlparse(url) + + if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None: + return None + + match = re.match(r"/file/d/(?P[^/]*)", parts.path) + if match is None: + return None + + return match.group("id") + + +def download_url( + url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None, max_redirect_hops: int = 3 +) -> None: + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the basename of the URL + md5 (str, optional): MD5 checksum of the download. If None, do not check + max_redirect_hops (int, optional): Maximum number of redirect hops allowed + """ + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + # check if file is already present locally + if check_integrity(fpath, md5): + print("Using downloaded and verified file: " + fpath) + return + + if _is_remote_location_available(): + _download_file_from_remote_location(fpath, url) + else: + # expand redirect chain if needed + url = _get_redirect_url(url, max_hops=max_redirect_hops) + + # check if file is located on Google Drive + file_id = _get_google_drive_file_id(url) + if file_id is not None: + return download_file_from_google_drive(file_id, root, filename, md5) + + # download the file + try: + print("Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + except (urllib.error.URLError, OSError) as e: # type: ignore[attr-defined] + if url[:5] == "https": + url = url.replace("https:", "http:") + print("Failed download. Trying https -> http instead. Downloading " + url + " to " + fpath) + _urlretrieve(url, fpath) + else: + raise e + + # check integrity of downloaded file + if not check_integrity(fpath, md5): + raise RuntimeError("File not found or corrupted.") + + +def list_dir(root: str, prefix: bool = False) -> List[str]: + """List all directories at a given root + + Args: + root (str): Path to directory whose folders need to be listed + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the directories found + """ + root = os.path.expanduser(root) + directories = [p for p in os.listdir(root) if os.path.isdir(os.path.join(root, p))] + if prefix is True: + directories = [os.path.join(root, d) for d in directories] + return directories + + +def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]: + """List all files ending with a suffix at a given root + + Args: + root (str): Path to directory whose folders need to be listed + suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). + It uses the Python "str.endswith" method and is passed directly + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the files found + """ + root = os.path.expanduser(root) + files = [p for p in os.listdir(root) if os.path.isfile(os.path.join(root, p)) and p.endswith(suffix)] + if prefix is True: + files = [os.path.join(root, d) for d in files] + return files + + +def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]: + content = response.iter_content(chunk_size) + first_chunk = None + # filter out keep-alive new chunks + while not first_chunk: + first_chunk = next(content) + content = itertools.chain([first_chunk], content) + + try: + match = re.search("Google Drive - (?P<api_response>.+?)", first_chunk.decode()) + api_response = match["api_response"] if match is not None else None + except UnicodeDecodeError: + api_response = None + return api_response, content + + +def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None): + """Download a Google Drive file from and place it in root. + + Args: + file_id (str): id of file to be downloaded + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the id of the file. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url + + root = os.path.expanduser(root) + if not filename: + filename = file_id + fpath = os.path.join(root, filename) + + os.makedirs(root, exist_ok=True) + + if check_integrity(fpath, md5): + print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}") + return + + url = "https://drive.google.com/uc" + params = dict(id=file_id, export="download") + with requests.Session() as session: + response = session.get(url, params=params, stream=True) + + for key, value in response.cookies.items(): + if key.startswith("download_warning"): + token = value + break + else: + api_response, content = _extract_gdrive_api_response(response) + token = "t" if api_response == "Virus scan warning" else None + + if token is not None: + response = session.get(url, params=dict(params, confirm=token), stream=True) + api_response, content = _extract_gdrive_api_response(response) + + if api_response == "Quota exceeded": + raise RuntimeError( + f"The daily quota of the file {filename} is exceeded and it " + f"can't be downloaded. This is a limitation of Google Drive " + f"and can only be overcome by trying again later." + ) + + _save_response_content(content, fpath) + + # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text + if os.stat(fpath).st_size < 10 * 1024: + with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh: + text = fh.read() + # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604 + if re.search(r"]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text): + warnings.warn( + f"We detected some HTML elements in the downloaded file. " + f"This most likely means that the download triggered an unhandled API response by GDrive. " + f"Please report this to torchvision at https://github.com/pytorch/vision/issues including " + f"the response:\n\n{text}" + ) + + if md5 and not check_md5(fpath, md5): + raise RuntimeError( + f"The MD5 checksum of the download file {fpath} does not match the one on record." + f"Please delete the file and try again. " + f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues." + ) + + +def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: + with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: + tar.extractall(to_path) + + +_ZIP_COMPRESSION_MAP: Dict[str, int] = { + ".bz2": zipfile.ZIP_BZIP2, + ".xz": zipfile.ZIP_LZMA, +} + + +def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> None: + with zipfile.ZipFile( + from_path, "r", compression=_ZIP_COMPRESSION_MAP[compression] if compression else zipfile.ZIP_STORED + ) as zip: + zip.extractall(to_path) + + +_ARCHIVE_EXTRACTORS: Dict[str, Callable[[str, str, Optional[str]], None]] = { + ".tar": _extract_tar, + ".zip": _extract_zip, +} +_COMPRESSED_FILE_OPENERS: Dict[str, Callable[..., IO]] = { + ".bz2": bz2.open, + ".gz": gzip.open, + ".xz": lzma.open, +} +_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = { + ".tbz": (".tar", ".bz2"), + ".tbz2": (".tar", ".bz2"), + ".tgz": (".tar", ".gz"), +} + + +def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: + """Detect the archive type and/or compression of a file. + + Args: + file (str): the filename + + Returns: + (tuple): tuple of suffix, archive type, and compression + + Raises: + RuntimeError: if file has no suffix or suffix is not supported + """ + suffixes = pathlib.Path(file).suffixes + if not suffixes: + raise RuntimeError( + f"File '{file}' has no suffixes that could be used to detect the archive type and compression." + ) + suffix = suffixes[-1] + + # check if the suffix is a known alias + if suffix in _FILE_TYPE_ALIASES: + return (suffix, *_FILE_TYPE_ALIASES[suffix]) + + # check if the suffix is an archive type + if suffix in _ARCHIVE_EXTRACTORS: + return suffix, suffix, None + + # check if the suffix is a compression + if suffix in _COMPRESSED_FILE_OPENERS: + # check for suffix hierarchy + if len(suffixes) > 1: + suffix2 = suffixes[-2] + + # check if the suffix2 is an archive type + if suffix2 in _ARCHIVE_EXTRACTORS: + return suffix2 + suffix, suffix2, suffix + + return suffix, None, suffix + + valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)) + raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.") + + +def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + r"""Decompress a file. + + The compression is automatically detected from the file name. + + Args: + from_path (str): Path to the file to be decompressed. + to_path (str): Path to the decompressed file. If omitted, ``from_path`` without compression extension is used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the decompressed file. + """ + suffix, archive_type, compression = _detect_file_type(from_path) + if not compression: + raise RuntimeError(f"Couldn't detect a compression from suffix {suffix}.") + + if to_path is None: + to_path = from_path.replace(suffix, archive_type if archive_type is not None else "") + + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + compressed_file_opener = _COMPRESSED_FILE_OPENERS[compression] + + with compressed_file_opener(from_path, "rb") as rfh, open(to_path, "wb") as wfh: + wfh.write(rfh.read()) + + if remove_finished: + os.remove(from_path) + + return to_path + + +def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str: + """Extract an archive. + + The archive type and a possible compression is automatically detected from the file name. If the file is compressed + but not an archive the call is dispatched to :func:`decompress`. + + Args: + from_path (str): Path to the file to be extracted. + to_path (str): Path to the directory the file will be extracted to. If omitted, the directory of the file is + used. + remove_finished (bool): If ``True``, remove the file after the extraction. + + Returns: + (str): Path to the directory the file was extracted to. + """ + if to_path is None: + to_path = os.path.dirname(from_path) + + suffix, archive_type, compression = _detect_file_type(from_path) + if not archive_type: + return _decompress( + from_path, + os.path.join(to_path, os.path.basename(from_path).replace(suffix, "")), + remove_finished=remove_finished, + ) + + # We don't need to check for a missing key here, since this was already done in _detect_file_type() + extractor = _ARCHIVE_EXTRACTORS[archive_type] + + extractor(from_path, to_path, compression) + if remove_finished: + os.remove(from_path) + + return to_path + + +def download_and_extract_archive( + url: str, + download_root: str, + extract_root: Optional[str] = None, + filename: Optional[str] = None, + md5: Optional[str] = None, + remove_finished: bool = False, +) -> None: + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url(url, download_root, filename, md5) + + archive = os.path.join(download_root, filename) + print(f"Extracting {archive} to {extract_root}") + extract_archive(archive, extract_root, remove_finished) + + +def iterable_to_str(iterable: Iterable) -> str: + return "'" + "', '".join([str(item) for item in iterable]) + "'" + + +T = TypeVar("T", str, bytes) + + +def verify_str_arg( + value: T, + arg: Optional[str] = None, + valid_values: Optional[Iterable[T]] = None, + custom_msg: Optional[str] = None, +) -> T: + if not isinstance(value, torch._six.string_classes): + if arg is None: + msg = "Expected type str, but got type {type}." + else: + msg = "Expected type str for argument {arg}, but got type {type}." + msg = msg.format(type=type(value), arg=arg) + raise ValueError(msg) + + if valid_values is None: + return value + + if value not in valid_values: + if custom_msg is not None: + msg = custom_msg + else: + msg = "Unknown value '{value}' for argument {arg}. Valid values are {{{valid_values}}}." + msg = msg.format(value=value, arg=arg, valid_values=iterable_to_str(valid_values)) + raise ValueError(msg) + + return value + + +def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray: + """Read file in .pfm format. Might contain either 1 or 3 channels of data. + + Args: + file_name (str): Path to the file. + slice_channels (int): Number of channels to slice out of the file. + Useful for reading different data formats stored in .pfm files: Optical Flows, Stereo Disparity Maps, etc. + """ + + with open(file_name, "rb") as f: + header = f.readline().rstrip() + if header not in [b"PF", b"Pf"]: + raise ValueError("Invalid PFM file") + + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", f.readline()) + if not dim_match: + raise Exception("Malformed PFM header.") + w, h = (int(dim) for dim in dim_match.groups()) + + scale = float(f.readline().rstrip()) + if scale < 0: # little-endian + endian = "<" + scale = -scale + else: + endian = ">" # big-endian + + data = np.fromfile(f, dtype=endian + "f") + + pfm_channels = 3 if header == b"PF" else 1 + + data = data.reshape(h, w, pfm_channels).transpose(2, 0, 1) + data = np.flip(data, axis=1) # flip on h dimension + data = data[:slice_channels, :, :] + return data.astype(np.float32) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/datasets/vision.py b/python/jittor/compatibility/vision/datasets/vision.py new file mode 100644 index 00000000..d71dc2a5 --- /dev/null +++ b/python/jittor/compatibility/vision/datasets/vision.py @@ -0,0 +1,104 @@ +import os +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.utils.data as data + +from ..utils import _log_api_usage_once + + +class VisionDataset(data.Dataset): + """ + Base Class For making datasets which are compatible with torchvision. + It is necessary to override the ``__getitem__`` and ``__len__`` method. + Args: + root (string): Root directory of dataset. + transforms (callable, optional): A function/transforms that takes in + an image and a label and returns the transformed versions of both. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + .. note:: + :attr:`transforms` and the combination of :attr:`transform` and :attr:`target_transform` are mutually exclusive. + """ + + _repr_indent = 4 + + def __init__( + self, + root: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + self.root = root + + has_transforms = transforms is not None + has_separate_transform = transform is not None or target_transform is not None + if has_transforms and has_separate_transform: + raise ValueError("Only transforms or transform/target_transform can be passed as argument") + + # for backwards-compatibility + self.transform = transform + self.target_transform = target_transform + + if has_separate_transform: + transforms = StandardTransform(transform, target_transform) + self.transforms = transforms + + def __getitem__(self, index: int) -> Any: + """ + Args: + index (int): Index + Returns: + (Any): Sample and meta data, optionally transformed by the respective transforms. + """ + raise NotImplementedError + + def __len__(self) -> int: + raise NotImplementedError + + def __repr__(self) -> str: + head = "Dataset " + self.__class__.__name__ + body = [f"Number of datapoints: {self.__len__()}"] + if self.root is not None: + body.append(f"Root location: {self.root}") + body += self.extra_repr().splitlines() + if hasattr(self, "transforms") and self.transforms is not None: + body += [repr(self.transforms)] + lines = [head] + [" " * self._repr_indent + line for line in body] + return "\n".join(lines) + + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: + lines = transform.__repr__().splitlines() + return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] + + def extra_repr(self) -> str: + return "" + + +class StandardTransform: + def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: + self.transform = transform + self.target_transform = target_transform + + def __call__(self, input: Any, target: Any) -> Tuple[Any, Any]: + if self.transform is not None: + input = self.transform(input) + if self.target_transform is not None: + target = self.target_transform(target) + return input, target + + def _format_transform_repr(self, transform: Callable, head: str) -> List[str]: + lines = transform.__repr__().splitlines() + return [f"{head}{lines[0]}"] + ["{}{}".format(" " * len(head), line) for line in lines[1:]] + + def __repr__(self) -> str: + body = [self.__class__.__name__] + if self.transform is not None: + body += self._format_transform_repr(self.transform, "Transform: ") + if self.target_transform is not None: + body += self._format_transform_repr(self.target_transform, "Target transform: ") + + return "\n".join(body) \ No newline at end of file diff --git a/python/jittor/compatibility/vision/transforms.py b/python/jittor/compatibility/vision/transforms.py new file mode 100644 index 00000000..416057c7 --- /dev/null +++ b/python/jittor/compatibility/vision/transforms.py @@ -0,0 +1 @@ +from jittor.transform import * \ No newline at end of file diff --git a/python/jittor/compatibility/vision/utils.py b/python/jittor/compatibility/vision/utils.py new file mode 100644 index 00000000..4be36c64 --- /dev/null +++ b/python/jittor/compatibility/vision/utils.py @@ -0,0 +1,582 @@ +import collections +import math +import pathlib +import warnings +from itertools import repeat +from types import FunctionType +from typing import Any, BinaryIO, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image, ImageColor, ImageDraw, ImageFont + +__all__ = [ + "make_grid", + "save_image", + "draw_bounding_boxes", + "draw_segmentation_masks", + "draw_keypoints", + "flow_to_image", +] + + +@torch.no_grad() +def make_grid( + tensor: Union[torch.Tensor, List[torch.Tensor]], + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + value_range: Optional[Tuple[int, int]] = None, + scale_each: bool = False, + pad_value: float = 0.0, + **kwargs, +) -> torch.Tensor: + """ + Make a grid of images. + + Args: + tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) + or a list of images all of the same size. + nrow (int, optional): Number of images displayed in each row of the grid. + The final grid size is ``(B / nrow, nrow)``. Default: ``8``. + padding (int, optional): amount of padding. Default: ``2``. + normalize (bool, optional): If True, shift the image to the range (0, 1), + by the min and max values specified by ``value_range``. Default: ``False``. + value_range (tuple, optional): tuple (min, max) where min and max are numbers, + then these numbers are used to normalize the image. By default, min and max + are computed from the tensor. + range (tuple. optional): + .. warning:: + This parameter was deprecated in ``0.12`` and will be removed in ``0.14``. Please use ``value_range`` + instead. + scale_each (bool, optional): If ``True``, scale each image in the batch of + images separately rather than the (min, max) over all images. Default: ``False``. + pad_value (float, optional): Value for the padded pixels. Default: ``0``. + + Returns: + grid (Tensor): the tensor containing grid of images. + """ + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(make_grid) + if not torch.is_tensor(tensor): + if isinstance(tensor, list): + for t in tensor: + if not torch.is_tensor(t): + raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}") + else: + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") + + if "range" in kwargs.keys(): + warnings.warn( + "The parameter 'range' is deprecated since 0.12 and will be removed in 0.14. " + "Please use 'value_range' instead." + ) + value_range = kwargs["range"] + + # if list of tensors, convert to a 4D mini-batch Tensor + if isinstance(tensor, list): + tensor = torch.stack(tensor, dim=0) + + if tensor.dim() == 2: # single image H x W + tensor = tensor.unsqueeze(0) + if tensor.dim() == 3: # single image + if tensor.size(0) == 1: # if single-channel, convert to 3-channel + tensor = torch.cat((tensor, tensor, tensor), 0) + tensor = tensor.unsqueeze(0) + + if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images + tensor = torch.cat((tensor, tensor, tensor), 1) + + if normalize is True: + tensor = tensor.clone() # avoid modifying tensor in-place + if value_range is not None and not isinstance(value_range, tuple): + raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers") + + def norm_ip(img, low, high): + img.clamp_(min=low, max=high) + img.sub_(low).div_(max(high - low, 1e-5)) + + def norm_range(t, value_range): + if value_range is not None: + norm_ip(t, value_range[0], value_range[1]) + else: + norm_ip(t, float(t.min()), float(t.max())) + + if scale_each is True: + for t in tensor: # loop over mini-batch dimension + norm_range(t, value_range) + else: + norm_range(tensor, value_range) + + if not isinstance(tensor, torch.Tensor): + raise TypeError("tensor should be of type torch.Tensor") + if tensor.size(0) == 1: + return tensor.squeeze(0) + + # make the mini-batch of images into a grid + nmaps = tensor.size(0) + xmaps = min(nrow, nmaps) + ymaps = int(math.ceil(float(nmaps) / xmaps)) + height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) + num_channels = tensor.size(1) + grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + # Tensor.copy_() is a valid method but seems to be missing from the stubs + # https://pytorch.org/docs/stable/tensors.html#torch.Tensor.copy_ + grid.narrow(1, y * height + padding, height - padding).narrow( # type: ignore[attr-defined] + 2, x * width + padding, width - padding + ).copy_(tensor[k]) + k = k + 1 + return grid + + +@torch.no_grad() +def save_image( + tensor: Union[torch.Tensor, List[torch.Tensor]], + fp: Union[str, pathlib.Path, BinaryIO], + format: Optional[str] = None, + **kwargs, +) -> None: + """ + Save a given Tensor into an image file. + + Args: + tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, + saves the tensor as a grid of images by calling ``make_grid``. + fp (string or file object): A filename or a file object + format(Optional): If omitted, the format to use is determined from the filename extension. + If a file object was used instead of a filename, this parameter should always be used. + **kwargs: Other arguments are documented in ``make_grid``. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(save_image) + grid = make_grid(tensor, **kwargs) + # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer + ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() + im = Image.fromarray(ndarr) + im.save(fp, format=format) + + +@torch.no_grad() +def draw_bounding_boxes( + image: torch.Tensor, + boxes: torch.Tensor, + labels: Optional[List[str]] = None, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, + fill: Optional[bool] = False, + width: int = 1, + font: Optional[str] = None, + font_size: Optional[int] = None, +) -> torch.Tensor: + + """ + Draws bounding boxes on given image. + The values of the input image should be uint8 between 0 and 255. + If fill is True, Resulting Tensor should be saved as PNG image. + + Args: + image (Tensor): Tensor of shape (C x H x W) and dtype uint8. + boxes (Tensor): Tensor of size (N, 4) containing bounding boxes in (xmin, ymin, xmax, ymax) format. Note that + the boxes are absolute coordinates with respect to the image. In other words: `0 <= xmin < xmax < W` and + `0 <= ymin < ymax < H`. + labels (List[str]): List containing the labels of bounding boxes. + colors (color or list of colors, optional): List containing the colors + of the boxes or single color for all boxes. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for boxes. + fill (bool): If `True` fills the bounding box with specified color. + width (int): Width of bounding box. + font (str): A filename containing a TrueType font. If the file is not found in this filename, the loader may + also search in other directories, such as the `fonts/` directory on Windows or `/Library/Fonts/`, + `/System/Library/Fonts/` and `~/Library/Fonts/` on macOS. + font_size (int): The requested font size in points. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with bounding boxes plotted. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_bounding_boxes) + if not isinstance(image, torch.Tensor): + raise TypeError(f"Tensor expected, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"Tensor uint8 expected, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size(0) not in {1, 3}: + raise ValueError("Only grayscale and RGB images are supported") + elif (boxes[:, 0] > boxes[:, 2]).any() or (boxes[:, 1] > boxes[:, 3]).any(): + raise ValueError( + "Boxes need to be in (xmin, ymin, xmax, ymax) format. Use torchvision.ops.box_convert to convert them" + ) + + num_boxes = boxes.shape[0] + + if num_boxes == 0: + warnings.warn("boxes doesn't contain any box. No box was drawn") + return image + + if labels is None: + labels: Union[List[str], List[None]] = [None] * num_boxes # type: ignore[no-redef] + elif len(labels) != num_boxes: + raise ValueError( + f"Number of boxes ({num_boxes}) and labels ({len(labels)}) mismatch. Please specify labels for each box." + ) + + if colors is None: + colors = _generate_color_palette(num_boxes) + elif isinstance(colors, list): + if len(colors) < num_boxes: + raise ValueError(f"Number of colors ({len(colors)}) is less than number of boxes ({num_boxes}). ") + else: # colors specifies a single color for all boxes + colors = [colors] * num_boxes + + colors = [(ImageColor.getrgb(color) if isinstance(color, str) else color) for color in colors] + + if font is None: + if font_size is not None: + warnings.warn("Argument 'font_size' will be ignored since 'font' is not set.") + txt_font = ImageFont.load_default() + else: + txt_font = ImageFont.truetype(font=font, size=font_size or 10) + + # Handle Grayscale images + if image.size(0) == 1: + image = torch.tile(image, (3, 1, 1)) + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + img_boxes = boxes.to(torch.int64).tolist() + + if fill: + draw = ImageDraw.Draw(img_to_draw, "RGBA") + else: + draw = ImageDraw.Draw(img_to_draw) + + for bbox, color, label in zip(img_boxes, colors, labels): # type: ignore[arg-type] + if fill: + fill_color = color + (100,) + draw.rectangle(bbox, width=width, outline=color, fill=fill_color) + else: + draw.rectangle(bbox, width=width, outline=color) + + if label is not None: + margin = width + 1 + draw.text((bbox[0] + margin, bbox[1] + margin), label, fill=color, font=txt_font) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +@torch.no_grad() +def draw_segmentation_masks( + image: torch.Tensor, + masks: torch.Tensor, + alpha: float = 0.8, + colors: Optional[Union[List[Union[str, Tuple[int, int, int]]], str, Tuple[int, int, int]]] = None, +) -> torch.Tensor: + + """ + Draws segmentation masks on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + masks (Tensor): Tensor of shape (num_masks, H, W) or (H, W) and dtype bool. + alpha (float): Float number between 0 and 1 denoting the transparency of the masks. + 0 means full transparency, 1 means no transparency. + colors (color or list of colors, optional): List containing the colors + of the masks or single color for all masks. The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + By default, random colors are generated for each mask. + + Returns: + img (Tensor[C, H, W]): Image Tensor, with segmentation masks drawn on top. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_segmentation_masks) + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + if masks.ndim == 2: + masks = masks[None, :, :] + if masks.ndim != 3: + raise ValueError("masks must be of shape (H, W) or (batch_size, H, W)") + if masks.dtype != torch.bool: + raise ValueError(f"The masks must be of dtype bool. Got {masks.dtype}") + if masks.shape[-2:] != image.shape[-2:]: + raise ValueError("The image and the masks must have the same height and width") + + num_masks = masks.size()[0] + if colors is not None and num_masks > len(colors): + raise ValueError(f"There are more masks ({num_masks}) than colors ({len(colors)})") + + if num_masks == 0: + warnings.warn("masks doesn't contain any mask. No mask was drawn") + return image + + if colors is None: + colors = _generate_color_palette(num_masks) + + if not isinstance(colors, list): + colors = [colors] + if not isinstance(colors[0], (tuple, str)): + raise ValueError("colors must be a tuple or a string, or a list thereof") + if isinstance(colors[0], tuple) and len(colors[0]) != 3: + raise ValueError("It seems that you passed a tuple of colors instead of a list of colors") + + out_dtype = torch.uint8 + + colors_ = [] + for color in colors: + if isinstance(color, str): + color = ImageColor.getrgb(color) + colors_.append(torch.tensor(color, dtype=out_dtype)) + + img_to_draw = image.detach().clone() + # TODO: There might be a way to vectorize this + for mask, color in zip(masks, colors_): + img_to_draw[:, mask] = color[:, None] + + out = image * (1 - alpha) + img_to_draw * alpha + return out.to(out_dtype) + + +@torch.no_grad() +def draw_keypoints( + image: torch.Tensor, + keypoints: torch.Tensor, + connectivity: Optional[List[Tuple[int, int]]] = None, + colors: Optional[Union[str, Tuple[int, int, int]]] = None, + radius: int = 2, + width: int = 3, +) -> torch.Tensor: + + """ + Draws Keypoints on given RGB image. + The values of the input image should be uint8 between 0 and 255. + + Args: + image (Tensor): Tensor of shape (3, H, W) and dtype uint8. + keypoints (Tensor): Tensor of shape (num_instances, K, 2) the K keypoints location for each of the N instances, + in the format [x, y]. + connectivity (List[Tuple[int, int]]]): A List of tuple where, + each tuple contains pair of keypoints to be connected. + colors (str, Tuple): The color can be represented as + PIL strings e.g. "red" or "#FF00FF", or as RGB tuples e.g. ``(240, 10, 157)``. + radius (int): Integer denoting radius of keypoint. + width (int): Integer denoting width of line connecting keypoints. + + Returns: + img (Tensor[C, H, W]): Image Tensor of dtype uint8 with keypoints drawn. + """ + + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + _log_api_usage_once(draw_keypoints) + if not isinstance(image, torch.Tensor): + raise TypeError(f"The image must be a tensor, got {type(image)}") + elif image.dtype != torch.uint8: + raise ValueError(f"The image dtype must be uint8, got {image.dtype}") + elif image.dim() != 3: + raise ValueError("Pass individual images, not batches") + elif image.size()[0] != 3: + raise ValueError("Pass an RGB image. Other Image formats are not supported") + + if keypoints.ndim != 3: + raise ValueError("keypoints must be of shape (num_instances, K, 2)") + + ndarr = image.permute(1, 2, 0).cpu().numpy() + img_to_draw = Image.fromarray(ndarr) + draw = ImageDraw.Draw(img_to_draw) + img_kpts = keypoints.to(torch.int64).tolist() + + for kpt_id, kpt_inst in enumerate(img_kpts): + for inst_id, kpt in enumerate(kpt_inst): + x1 = kpt[0] - radius + x2 = kpt[0] + radius + y1 = kpt[1] - radius + y2 = kpt[1] + radius + draw.ellipse([x1, y1, x2, y2], fill=colors, outline=None, width=0) + + if connectivity: + for connection in connectivity: + start_pt_x = kpt_inst[connection[0]][0] + start_pt_y = kpt_inst[connection[0]][1] + + end_pt_x = kpt_inst[connection[1]][0] + end_pt_y = kpt_inst[connection[1]][1] + + draw.line( + ((start_pt_x, start_pt_y), (end_pt_x, end_pt_y)), + width=width, + ) + + return torch.from_numpy(np.array(img_to_draw)).permute(2, 0, 1).to(dtype=torch.uint8) + + +# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization +@torch.no_grad() +def flow_to_image(flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a flow to an RGB image. + + Args: + flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. + + Returns: + img (Tensor): Image Tensor of dtype uint8 where each color corresponds + to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. + """ + + if flow.dtype != torch.float: + raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") + + orig_shape = flow.shape + if flow.ndim == 3: + flow = flow[None] # Add batch dim + + if flow.ndim != 4 or flow.shape[1] != 2: + raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") + + max_norm = torch.sum(flow**2, dim=1).sqrt().max() + epsilon = torch.finfo((flow).dtype).eps + normalized_flow = flow / (max_norm + epsilon) + img = _normalized_flow_to_image(normalized_flow) + + if len(orig_shape) == 3: + img = img[0] # Remove batch dim + return img + + +@torch.no_grad() +def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a batch of normalized flow to an RGB image. + + Args: + normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) + Returns: + img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. + """ + + N, _, H, W = normalized_flow.shape + device = normalized_flow.device + flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) + colorwheel = _make_colorwheel().to(device) # shape [55x3] + num_cols = colorwheel.shape[0] + norm = torch.sum(normalized_flow**2, dim=1).sqrt() + a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi + fk = (a + 1) / 2 * (num_cols - 1) + k0 = torch.floor(fk).to(torch.long) + k1 = k0 + 1 + k1[k1 == num_cols] = 0 + f = fk - k0 + + for c in range(colorwheel.shape[1]): + tmp = colorwheel[:, c] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + col = 1 - norm * (1 - col) + flow_image[:, c, :, :] = torch.floor(255 * col) + return flow_image + + +def _make_colorwheel() -> torch.Tensor: + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. + + Returns: + colorwheel (Tensor[55, 3]): Colorwheel Tensor. + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = torch.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = torch.floor(255 * torch.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = torch.floor(255 * torch.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - torch.floor(255 * torch.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = torch.floor(255 * torch.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - torch.floor(255 * torch.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel + + +def _generate_color_palette(num_objects: int): + palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1]) + return [tuple((i * palette) % 255) for i in range(num_objects)] + + +def _log_api_usage_once(obj: Any) -> None: + + """ + Logs API usage(module and name) within an organization. + In a large ecosystem, it's often useful to track the PyTorch and + TorchVision APIs usage. This API provides the similar functionality to the + logging module in the Python stdlib. It can be used for debugging purpose + to log which methods are used and by default it is inactive, unless the user + manually subscribes a logger via the `SetAPIUsageLogger method `_. + Please note it is triggered only once for the same API call within a process. + It does not collect any data from open-source users since it is no-op by default. + For more information, please refer to + * PyTorch note: https://pytorch.org/docs/stable/notes/large_scale_deployments.html#api-usage-logging; + * Logging policy: https://github.com/pytorch/vision/issues/5052; + + Args: + obj (class instance or method): an object to extract info from. + """ + pass + + +def _make_ntuple(x: Any, n: int) -> Tuple[Any, ...]: + """ + Make n-tuple from input x. If x is an iterable, then we just convert it to tuple. + Otherwise we will make a tuple of length n, all with value of x. + reference: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/utils.py#L8 + + Args: + x (Any): input value + n (int): length of the resulting tuple + """ + if isinstance(x, collections.abc.Iterable): + return tuple(x) + return tuple(repeat(x, n)) \ No newline at end of file diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py new file mode 100644 index 00000000..f6c93b54 --- /dev/null +++ b/python/jittor/compile_extern.py @@ -0,0 +1,715 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os, sys, shutil +import platform +from .compiler import * +from jittor_utils import run_cmd, get_version, get_int_version +from jittor_utils.misc import download_url_to_local +import jittor_utils as jit_utils + +def search_file(dirs, name, prefer_version=()): + if os.name == 'nt': + if name.startswith("lib"): + name = name[3:].replace(".so", "64*.dll") + for d in dirs: + fname = os.path.join(d, name) + if os.name == 'nt': + lname = os.path.join(d, name) + names = glob.glob(lname) + if len(names): + return names[0] + continue + prefer_version = tuple( str(p) for p in prefer_version ) + for i in range(len(prefer_version),-1,-1): + vname = ".".join((fname,)+prefer_version[:i]) + if os.path.isfile(vname): + LOG.v(f"found {vname}") + return vname + LOG.f(f"file {name} not found in {dirs}") + +def install_mkl(root_folder): + # origin url is + # url = "https://github.com/intel/mkl-dnn/releases/download/v1.0.2/mkldnn_lnx_1.0.2_cpu_gomp.tgz" + import platform + url = None + if platform.system()=="Linux": + if platform.machine()=='x86_64': + filename = "dnnl_lnx_2.2.0_cpu_gomp.tgz" + md5 = "35bbbdf550a9d8ad54db798e372000f6" + elif platform.machine()=='aarch64': + filename = "dnnl_lnx_2.2.0_cpu_gomp_aarch64.tgz" + md5 = "72cf9b0b8fd6c3c786d35a9daaee22b8" + else: + raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet," + " Please contact us on https://github.com/jittor/jittor ") + elif os.name == "nt": + # url = "https://github.com/oneapi-src/oneDNN/releases/download/v2.2/dnnl_win_2.2.0_cpu_iomp.zip" + # url = "https://github.com/oneapi-src/oneDNN/releases/download/v2.2/dnnl_win_2.2.0_cpu_vcomp.zip" + filename = "dnnl_win_2.2.0_cpu_vcomp.zip" + md5 = "fa12c693b2ec07700d174e1e99d60a7e" + elif platform.system() == "Darwin": + if platform.machine() == "arm64": + filename = "dnnl_mac_2.2.0_cpu_omp_arm64.tgz" + md5 = "d8fdf56d3cf618685d22d18f08119f88" + else: + filename = "dnnl_mac_2.2.0_cpu_omp_x86_64.tgz" + md5 = "6e2f065d6a589c82081536b684768fe6" + else: + raise RuntimeError(f"platform.machine()=={platform.machine()} not support yet," + " Please contact us on https://github.com/jittor/jittor ") + + if not url: + url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, filename.rsplit(".",1)[0]) + + if not (os.path.isfile(os.path.join(dirname, "lib", "libmkldnn.so")) or + os.path.isfile(os.path.join(dirname, "bin", "dnnl.dll")) or + os.path.isfile(os.path.join(dirname, "lib", "libmkldnn.dylib"))): + LOG.i("Downloading mkl...") + download_url_to_local(url, filename, root_folder, md5) + if fullname.endswith(".zip"): + import zipfile + with zipfile.ZipFile(fullname, "r") as f: + f.extractall(root_folder) + else: + import tarfile + with tarfile.open(fullname, "r") as tar: + tar.extractall(root_folder) + if os.name == 'nt': + # this env is used for execute example/text + bin_path = os.path.join(dirname, "bin") + sys.path.append(bin_path) + os.environ["PATH"] = os.environ.get("PATH", "") + ";" + bin_path + cmd = f"cd /d {dirname}/examples && {cc_path} {dirname}/examples/cnn_inference_f32.cpp -I{dirname}/include -Fe: {dirname}/examples/test.exe {fix_cl_flags(cc_flags).replace('-LD', '')} {dirname}/lib/mkldnn.lib" + + assert 0 == os.system(cmd) + assert 0 == os.system(f"{dirname}/examples/test") + elif platform.system() == "Darwin": + assert 0 == os.system(f"cd {dirname}/examples && " + f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && DYLD_LIBRARY_PATH=../lib/ ./test") + else: + assert 0 == os.system(f"cd {dirname}/examples && " + f"{cc_path} -std=c++14 cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test") + +def setup_mkl(): + global mkl_ops, use_mkl + use_mkl = os.environ.get("use_mkl", "1")=="1" + mkl_ops = None + if not use_mkl: return + + # pytorch mkl is conflict with jittor mkl + # yield error "free: invalide size" or + # "mmap error" + # import pytorch(>1.8) first can fix this problem + # try: + # # jt.dirty_fix_pytorch_runtime_error() + # import torch + # from torch import nn + # except: + # torch = None + + mkl_include_path = os.environ.get("mkl_include_path") + mkl_lib_path = os.environ.get("mkl_lib_path") + + if mkl_lib_path is None or mkl_include_path is None: + LOG.v("setup mkl...") + # mkl_path = os.path.join(cache_path, "mkl") + # mkl_path decouple with cc_path + mkl_path = os.path.join(jit_utils.home(), ".cache", "jittor", "mkl") + + make_cache_dir(mkl_path) + install_mkl(mkl_path) + mkl_home = "" + for name in os.listdir(mkl_path): + if name.startswith("dnnl") and os.path.isdir(os.path.join(mkl_path, name)): + mkl_home = os.path.join(mkl_path, name) + break + assert mkl_home!="" + mkl_include_path = os.path.join(mkl_home, "include") + mkl_lib_path = os.path.join(mkl_home, "lib") + + mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.so") + extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -lmkldnn " + if os.name == 'nt': + mkl_lib_name = os.path.join(mkl_home, 'bin', 'dnnl.dll') + mkl_bin_path = os.path.join(mkl_home, 'bin') + extra_flags = f" -I\"{mkl_include_path}\" -L\"{mkl_lib_path}\" -L\"{mkl_bin_path}\" -ldnnl " + elif platform.system() == "Darwin": + mkl_lib_name = os.path.join(mkl_lib_path, "libmkldnn.dylib") + + assert os.path.isdir(mkl_include_path) + assert os.path.isdir(mkl_lib_path) + assert os.path.isfile(mkl_lib_name) + LOG.v(f"mkl_include_path: {mkl_include_path}") + LOG.v(f"mkl_lib_path: {mkl_lib_path}") + LOG.v(f"mkl_lib_name: {mkl_lib_name}") + # We do not link manualy, link in custom ops + # ctypes.CDLL(mkl_lib_name, dlopen_flags) + + mkl_op_dir = os.path.join(jittor_path, "extern", "mkl", "ops") + mkl_op_files = [os.path.join(mkl_op_dir, name) for name in os.listdir(mkl_op_dir)] + mkl_ops = compile_custom_ops(mkl_op_files, extra_flags=extra_flags) + LOG.vv("Get mkl_ops: "+str(dir(mkl_ops))) + + +def install_cub(root_folder): + url = "https://github.com/NVIDIA/cub/archive/1.11.0.tar.gz" + url = "https://codeload.github.com/NVIDIA/cub/tar.gz/1.11.0" + filename = "cub-1.11.0.tgz" + md5 = "97196a885598e40592100e1caaf3d5ea" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, filename.replace(".tgz","")) + + if not os.path.isfile(os.path.join(dirname, "examples", "device/example_device_radix_sort.cu")): + LOG.i("Downloading cub...") + download_url_to_local(url, filename, root_folder, md5) + import tarfile + + with tarfile.open(fullname, "r") as tar: + tar.extractall(root_folder) + # assert 0 == os.system(f"cd {dirname}/examples && " + # f"{nvcc_path} --cudart=shared -ccbin=\"{cc_path}\" device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test") + # if core.get_device_count(): + # assert 0 == os.system(f"cd {dirname}/examples && ./test") + return dirname + +def setup_cub(): + global cub_home + cub_home = "" + cub_path = os.path.join(jit_utils.home(), ".cache", "jittor", "cub") + cuda_version = int(get_version(nvcc_path)[1:-1].split('.')[0]) + extra_flags = "" + if cuda_version < 11: + cub_home = install_cub(cub_path) + extra_flags = f"-I{cub_home}" + cub_home += "/" + setup_cuda_lib("cub", link=False, extra_flags=extra_flags) + +def setup_cuda_extern(): + if not has_cuda: return + def split(a): return a.replace(";",":").split(":") + check_ld_path = split(os.environ.get("LD_LIBRARY_PATH", "")) + \ + split(os.environ.get("PATH", "")) + for cp in check_ld_path: + cp = cp.lower() + if "cuda" in cp and \ + "lib" in cp and \ + "jtcuda" not in cp: + LOG.w(f"CUDA related path found in LD_LIBRARY_PATH or PATH, " + "This path may cause jittor found the wrong libs, " + "please unset LD_LIBRARY_PATH and remove cuda lib path in Path. \n" + "Or you can let jittor install cuda for you: `python3.x -m jittor_utils.install_cuda`") + break + LOG.vv("setup cuda extern...") + cache_path_cuda = os.path.join(cache_path, "cuda") + cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc") + make_cache_dir(cache_path_cuda) + cuda_extern_src = os.path.join(jittor_path, "extern", "cuda", "src") + cuda_extern_files = [os.path.join(cuda_extern_src, name) + for name in os.listdir(cuda_extern_src)] + so_name = os.path.join(cache_path_cuda, "libcuda_extern"+so) + compile(cc_path, cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name) + link_cuda_extern = f" -L\"{cache_path_cuda}\" -llibcuda_extern " + ctypes.CDLL(so_name, dlopen_flags) + + try: + setup_cub() + except Exception as e: + import traceback + line = traceback.format_exc() + LOG.w(f"CUDA found but cub is not loaded:\n{line}") + + libs = ["cublas", "cudnn", "curand", "cufft"] + # in cuda 11.4, module memory comsumptions: + # default context: 259 MB + # cublas: 340 MB + # cudnn: 340 MB + if int(os.environ.get("conv_opt", "0")): + libs = ["cublas", "curand"] + for lib_name in libs: + try: + setup_cuda_lib(lib_name, extra_flags=link_cuda_extern) + except Exception as e: + msg = f"CUDA found but {lib_name} is not loaded:\n" + if lib_name == "cudnn": + msg += """Develop version of CUDNN not found, +please refer to CUDA offical tar file installation: +https://docs.nvidia.com/deeplearning/cudnn/install-guide/index.html#installlinux-tar""" + if platform.machine() in ["x86_64", "AMD64"]: + msg += f""" +or you can let jittor install cuda and cudnn for you: +>>> python3.{sys.version_info.minor} -m jittor_utils.install_cuda +""" + LOG.f(msg) + +def setup_cuda_lib(lib_name, link=True, extra_flags=""): + arch_key = "x86_64" + if platform.machine() not in ["x86_64", "AMD64"]: + arch_key = "aarch64" + globals()[lib_name+"_ops"] = None + globals()[lib_name] = None + if not has_cuda: return + LOG.v(f"setup {lib_name}...") + + culib_path = os.path.join(cuda_lib, f"lib{lib_name}.so") + jt_cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc") + jt_culib_include = os.path.join(jittor_path, "extern", "cuda", lib_name, "inc") + + link_flags = "" + if link: + extra_include_path = os.path.abspath(os.path.join(cuda_include, "..", f"targets/{arch_key}-linux/include")) + extra_lib_path = os.path.abspath(os.path.join(cuda_lib, "..", f"targets/{arch_key}-linux/lib")) + cuda_include_name = search_file([cuda_include, extra_include_path, "/usr/include"], lib_name+".h") + # cuda11 prefer cudnn 8 + nvcc_version = get_int_version(nvcc_path) + if has_corex: + nvcc_version = (10,2,89) + prefer_version = () + if nvcc_version[0] == 11: + prefer_version = ("8",) + culib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"lib{lib_name}.so", prefer_version) + + if lib_name == "cublas" and nvcc_version[0] >= 10: + # manual link libcublasLt.so + try: + cublas_lt_lib_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], f"libcublasLt.so", nvcc_version) + ctypes.CDLL(cublas_lt_lib_path, dlopen_flags) + except: + # some aarch64 os, such as uos with FT2000 cpu, + # it's cuda 10 doesn't have libcublasLt.so + pass + + + + if lib_name == "cudnn": + # cudnn cannot found libcudnn_cnn_train.so.8, we manual link for it. + if nvcc_version >= (11,0,0): + libs = ["libcudnn_ops_infer.so", "libcudnn_ops_train.so", "libcudnn_cnn_infer.so", "libcudnn_cnn_train.so"] + for l in libs: + ex_cudnn_path = search_file([cuda_bin, cuda_lib, extra_lib_path, f"/usr/lib/{arch_key}-linux-gnu", "/usr/lib"], l, prefer_version) + ctypes.CDLL(ex_cudnn_path, dlopen_flags) + + # dynamic link cuda library + # ctypes.CDLL(culib_path, dlopen_flags) + # link_flags = f"-l{lib_name} -L\"{cuda_lib}\"" + link_flags = f"-l{lib_name} -L\"{os.path.dirname(culib_path)}\"" + # print("link_flags", link_flags, culib_path) + + # find all source files + culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name) + culib_src_files = [] + for r, _, f in os.walk(culib_src_dir): + for fname in f: + culib_src_files.append(os.path.join(r, fname)) + if len(culib_src_files) == 0: + return + + # compile and get operators + culib = compile_custom_ops(culib_src_files, return_module=True, + extra_flags=f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" {link_flags} {extra_flags} ") + culib_ops = culib.ops + globals()[lib_name+"_ops"] = culib_ops + globals()[lib_name] = culib + LOG.vv(f"Get {lib_name}_ops: "+str(dir(culib_ops))) + + +def _setup_fake_cuda_lib(lib_name=None, link=True, extra_flags=""): + if lib_name is None: + lib_names = ["cudnn", "cublas", "curand", "cufft", "cub", "cutt", "cutlass"] + for lib_name in lib_names: + _setup_fake_cuda_lib(lib_name, link, extra_flags) + return + arch_key = "x86_64" + if platform.machine() not in ["x86_64", "AMD64"]: + arch_key = "aarch64" + globals()[lib_name+"_ops"] = None + globals()[lib_name] = None + LOG.v(f"setup {lib_name}...") + + jt_cuda_include = os.path.join(jittor_path, "extern", "cuda", "inc") + jt_culib_include = os.path.join(jittor_path, "extern", "cuda", lib_name, "inc") + + # find all source files + culib_src_dir = os.path.join(jittor_path, "extern", "cuda", lib_name, "ops") + culib_src_files = [] + for r, _, f in os.walk(culib_src_dir): + for fname in f: + if fname.endswith("op.cc") or fname.endswith("op.h"): + culib_src_files.append(os.path.join(r, fname)) + if len(culib_src_files) == 0: + return + + # compile and get operators + culib = compile_custom_ops(culib_src_files, return_module=True, + extra_flags=f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" {extra_flags} ") + culib_ops = culib.ops + globals()[lib_name+"_ops"] = culib_ops + globals()[lib_name] = culib + LOG.vv(f"Get {lib_name}_ops: "+str(dir(culib_ops))) + +if setup_fake_cuda_lib: + _setup_fake_cuda_lib() + +def install_cutt(root_folder): + # Modified from: https://github.com/ap-hynninen/cutt + url = "https://codeload.github.com/Jittor/cutt/zip/v1.2" + + filename = "cutt-1.2.zip" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, filename.replace(".zip","")) + true_md5 = "14d0fd1132c8cd657dc3cf29ce4db931" + + if os.path.exists(fullname): + from jittor_utils.misc import calculate_md5 + md5 = calculate_md5(fullname) + if md5 != true_md5: + os.remove(fullname) + shutil.rmtree(dirname) + CUTT_PATH = os.environ.get("CUTT_PATH", "") + if not os.path.isfile(os.path.join(cache_path, "libcutt"+so)) or CUTT_PATH: + if CUTT_PATH: + dirname = CUTT_PATH + else: + LOG.i("Downloading cutt...") + download_url_to_local(url, filename, root_folder, true_md5) + + import zipfile + + zf = zipfile.ZipFile(fullname) + try: + zf.extractall(path=root_folder) + except RuntimeError as e: + print(e) + raise + zf.close() + + LOG.i("installing cutt...") + # -Xptxas -dlcm=ca actually not work + arch_flag = " -Xptxas -dlcm=ca " + if len(flags.cuda_archs): + arch_flag = f" -arch=compute_{min(flags.cuda_archs)} " + arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs)) + cutt_include = f" -I\"{dirname}/include\" -I\"{dirname}/src\" " + files = glob.glob(dirname+"/src/*.c*", recursive=True) + files2 = [] + for f in files: + if f.endswith("cutt_bench.cpp") or \ + f.endswith("cutt_test.cpp"): + continue + files2.append(f) + cutt_flags = cc_flags+opt_flags+cutt_include + compile(cc_path, cutt_flags, files2, cache_path+"/libcutt"+so, cuda_flags=arch_flag) + return dirname + +def setup_cutt(): + global cutt_ops, use_cutt + if not has_cuda: + use_cutt = False + return + use_cutt = os.environ.get("use_cutt", "1")=="1" + cutt_ops = None + if not use_cutt: return + cutt_include_path = os.environ.get("cutt_include_path") + cutt_lib_path = os.environ.get("cutt_lib_path") + + if cutt_lib_path is None or cutt_include_path is None: + LOG.v("setup cutt...") + # cutt_path decouple with cc_path + cutt_path = os.path.join(jit_utils.home(), ".cache", "jittor", "cutt") + + make_cache_dir(cutt_path) + install_cutt(cutt_path) + cutt_home = os.path.join(cutt_path, "cutt-1.2") + cutt_include_path = os.path.join(cutt_home, "src") + cutt_lib_path = cache_path + + cutt_lib_name = os.path.join(cutt_lib_path, "libcutt"+so) + assert os.path.isdir(cutt_include_path) + assert os.path.isdir(cutt_lib_path) + assert os.path.isfile(cutt_lib_name), cutt_lib_name + LOG.v(f"cutt_include_path: {cutt_include_path}") + LOG.v(f"cutt_lib_path: {cutt_lib_path}") + LOG.v(f"cutt_lib_name: {cutt_lib_name}") + # We do not link manualy, link in custom ops + ctypes.CDLL(cutt_lib_name, dlopen_flags) + + cutt_op_dir = os.path.join(jittor_path, "extern", "cuda", "cutt", "ops") + cutt_op_files = [os.path.join(cutt_op_dir, name) for name in os.listdir(cutt_op_dir)] + cutt_ops = compile_custom_ops(cutt_op_files, + extra_flags=f" -I\"{cutt_include_path}\" -L\"{cutt_lib_path}\" -llibcutt ") + LOG.vv("Get cutt_ops: "+str(dir(cutt_ops))) + +def install_cutlass(root_folder): + # Modified from: https://github.com/ap-hynninen/cutlass + url = "https://cloud.tsinghua.edu.cn/f/171e49e5825549548bc4/?dl=1" + + filename = "cutlass.zip" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, filename.replace(".zip","")) + true_md5 = "999ecb7e217e40c497bc3d0ded6643f0" + + if os.path.exists(fullname): + from jittor_utils.misc import calculate_md5 + md5 = calculate_md5(fullname) + if md5 != true_md5: + os.remove(fullname) + shutil.rmtree(dirname) + CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "") + if not os.path.isfile(os.path.join(jit_utils.home(), ".cache/jittor/cutlass/cutlass/include/cutlass/cutlass.h")) or CUTLASS_PATH: + if CUTLASS_PATH: + dirname = CUTLASS_PATH + else: + LOG.i("Downloading cutlass...") + download_url_to_local(url, filename, root_folder, true_md5) + + import zipfile + + zf = zipfile.ZipFile(fullname) + try: + zf.extractall(path=root_folder) + except RuntimeError as e: + print(e) + raise + zf.close() + + # LOG.i("installing cutlass...") + # # -Xptxas -dlcm=ca actually not work + # arch_flag = " -Xptxas -dlcm=ca " + # if len(flags.cuda_archs): + # arch_flag = f" -arch=compute_{min(flags.cuda_archs)} " + # arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs)) + # cutlass_include = f" -I\"{dirname}/include\" -I\"{dirname}/src\" " + # files = glob.glob(dirname+"/src/*.c*", recursive=True) + # files2 = [] + # for f in files: + # if f.endswith("cutlass_bench.cpp") or \ + # f.endswith("cutlass_test.cpp"): + # continue + # files2.append(f) + # cutlass_flags = cc_flags+opt_flags+cutlass_include + # compile(cc_path, cutlass_flags, files2, cache_path+"/libcutlass"+so, cuda_flags=arch_flag) + return dirname + +def setup_cutlass(): + global cutlass_ops, use_cutlass + if not has_cuda: + use_cutlass = False + return + use_cutlass = os.environ.get("use_cutlass", "1")=="1" + cutlass_ops = None + if not use_cutlass: return + cutlass_include_path = os.environ.get("cutlass_include_path") + + if cutlass_include_path is None: + LOG.v("setup cutlass...") + # cutlass_path decouple with cc_path + cutlass_path = os.path.join(jit_utils.home(), ".cache", "jittor", "cutlass") + + make_cache_dir(cutlass_path) + install_cutlass(cutlass_path) + + +def install_nccl(root_folder): + url = "https://github.com/NVIDIA/nccl/archive/v2.8.4-1.tar.gz" + url = "https://codeload.github.com/NVIDIA/nccl/tar.gz/v2.8.4-1" + + filename = "nccl.tgz" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, "nccl-2.8.4-1") + true_md5 = "900666558c5bc43e0a5e84045b88a06f" + + if os.path.exists(fullname): + md5 = run_cmd('md5sum '+fullname).split()[0] + if md5 != true_md5: + os.remove(fullname) + if os.path.isdir(dirname): + shutil.rmtree(dirname) + if not os.path.isfile(os.path.join(dirname, "build", "lib", "libnccl.so")): + if not os.path.isfile(os.path.join(root_folder, filename)): + LOG.i("Downloading nccl...") + download_url_to_local(url, filename, root_folder, true_md5) + + if core.get_device_count() == 0: + return + if not inside_mpi(): + return + + import tarfile + with tarfile.open(fullname, "r") as tar: + tar.extractall(root_folder) + + LOG.i("installing nccl...") + arch_flag = "" + if len(flags.cuda_archs): + arch_flag = f" -arch=compute_{min(flags.cuda_archs)} " + arch_flag += ''.join(map(lambda x:f' -code=sm_{x} ', flags.cuda_archs)) + run_cmd(f"CC=\"{cc_path}\" CXX=\"{cc_path}\" make -j8 src.build CUDA_HOME='{cuda_home}' NVCC_GENCODE='{arch_flag} --cudart=shared ' ", cwd=dirname) + return dirname + +def setup_nccl(): + global nccl, nccl_ops, use_nccl + use_nccl = os.environ.get("use_nccl", "1")=="1" + nccl = None + nccl_ops = None + if not has_cuda or not has_mpi: + use_nccl = False + return + if not use_nccl: return + nccl_include_path = os.environ.get("nccl_include_path") + nccl_lib_path = os.environ.get("nccl_lib_path") + + if nccl_lib_path is None or nccl_include_path is None: + LOG.v("setup nccl...") + # nccl_path decouple with cc_path + nccl_path = os.path.join(jit_utils.home(), ".cache", "jittor", "nccl") + + make_cache_dir(nccl_path) + nccl_home = install_nccl(nccl_path) + if nccl_home is None: return + nccl_include_path = os.path.join(nccl_home, "build", "include") + nccl_lib_path = os.path.join(nccl_home, "build", "lib") + + if not inside_mpi(): + return + + nccl_lib_name = os.path.join(nccl_lib_path, "libnccl.so") + assert os.path.isdir(nccl_include_path) + assert os.path.isdir(nccl_lib_path) + assert os.path.isfile(nccl_lib_name), nccl_lib_name + LOG.v(f"nccl_include_path: {nccl_include_path}") + LOG.v(f"nccl_lib_path: {nccl_lib_path}") + LOG.v(f"nccl_lib_name: {nccl_lib_name}") + # We do not link manualy, link in custom ops + ctypes.CDLL(nccl_lib_name, dlopen_flags) + + nccl_src_dir = os.path.join(jittor_path, "extern", "cuda", "nccl") + nccl_src_files = [] + for r, _, f in os.walk(nccl_src_dir): + for fname in f: + nccl_src_files.append(os.path.join(r, fname)) + + nccl = compile_custom_ops(nccl_src_files, + extra_flags=f" -I\"{nccl_include_path}\" {mpi_compile_flags} ", + return_module=True, dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW, + gen_name_="jittor_nccl_core") + nccl_ops = nccl.ops + LOG.vv("Get nccl_ops: "+str(dir(nccl_ops))) + +def manual_link(flags): + lib_dirs = [] + libs = [] + for f in flags.split(): + if f.startswith("-l"): + libs.append(f[2:]) + elif f.startswith("-L"): + lib_dirs.append(f[2:]) + LOG.v("manual_link:", flags) + LOG.v("lib_dirs:", lib_dirs) + LOG.v("libs:", libs) + for lib in libs: + for d in lib_dirs: + libname = os.path.join(d, f"lib{lib}.so") + if os.path.isfile(libname): + LOG.v("link:", libname) + ctypes.CDLL(libname, dlopen_flags) + break + +def inside_mpi(): + return "OMPI_COMM_WORLD_SIZE" in os.environ + +def setup_mpi(): + global mpi_ops, mpi, use_mpi + global mpicc_path, has_mpi + use_mpi = os.environ.get("use_mpi", "1")=="1" + mpi_ops = None + mpi = None + has_mpi = False + if not use_mpi: return + mpicc_path = env_or_try_find('mpicc_path', 'mpicc') + if mpicc_path == "": + # LOG.i("mpicc not found, distribution disabled.") + use_mpi = False + else: + use_mpi = True + has_mpi = True + if not use_mpi: + return + + global mpi_compile_flags, mpi_link_flags, mpi_flags + mpi_compile_flags = run_cmd(mpicc_path+" --showme:compile") + mpi_link_flags = run_cmd(mpicc_path+" --showme:link") + mpi_flags = mpi_compile_flags + " " + mpi_link_flags + LOG.v("mpi_flags: "+mpi_flags) + + # find all source files + mpi_src_dir = os.path.join(jittor_path, "extern", "mpi") + mpi_src_files = [] + for r, _, f in os.walk(mpi_src_dir): + for fname in f: + mpi_src_files.append(os.path.join(r, fname)) + + # mpi compile flags add for nccl + mpi_compile_flags += f" -I\"{os.path.join(mpi_src_dir, 'inc')}\" " + mpi_compile_flags = mpi_compile_flags.replace("-pthread", "") + + mpi_version = get_version(mpicc_path) + if mpi_version.startswith("(1.") or mpi_version.startswith("(2."): + # mpi version 1.x need to link like this + manual_link(mpi_flags) + # mpi(4.x) cannot use deepbind, it need to + # share the 'environ' symbol. + mpi = compile_custom_ops(mpi_src_files, + extra_flags=f" {mpi_flags} ", return_module=True, + dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW, gen_name_="jittor_mpi_core") + mpi_ops = mpi.ops + LOG.vv("Get mpi: "+str(mpi.__dict__.keys())) + LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys())) + def wrapper(func): + def inner(self, *args, **kw): + return func(self, *args, **kw) + inner.__doc__ = func.__doc__ + return inner + for k in mpi_ops.__dict__: + if not k.startswith("mpi_"): continue + if k == "mpi_test": continue + setattr(core.Var, k, wrapper(mpi_ops.__dict__[k])) + +in_mpi = inside_mpi() +FIX_TORCH_ERROR = 0 +if os.name != 'nt' and not in_mpi: + FIX_TORCH_ERROR = 1 +if "FIX_TORCH_ERROR" in os.environ: + FIX_TORCH_ERROR = os.environ["FIX_TORCH_ERROR"] != "0" +if FIX_TORCH_ERROR: + try: + import torch + from jittor_utils import dirty_fix_pytorch_runtime_error + dirty_fix_pytorch_runtime_error() + except: + pass + +cudnn = cublas = curand = cufft = None +setup_mpi() +rank = mpi.world_rank() if in_mpi else 0 +world_size = mpi.world_size() if in_mpi else 1 +setup_nccl() + +setup_cutt() +setup_cutlass() + +# try: +setup_mkl() +# except Exception as e: +# LOG.w("MKL install failed, msg:", e) + +setup_cuda_extern() + +# install backend extern library +for mod in jit_utils.backends: + if mod.install_extern(): + break diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py new file mode 100644 index 00000000..decce72d --- /dev/null +++ b/python/jittor/compiler.py @@ -0,0 +1,1431 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import subprocess as sp +import os +import re +import sys +import glob +import inspect +import datetime +import threading +import platform +import ctypes +import platform +from ctypes import cdll +from ctypes.util import find_library + +import jittor_utils as jit_utils +from jittor_utils import LOG, run_cmd, find_exe, cc_path, cc_type, cache_path +from . import pyjt_compiler +from jittor_utils import lock +from jittor_utils import install_cuda +from jittor import __version__ +import hashlib + +def find_jittor_path(): + return os.path.dirname(__file__) + +def make_cache_dir(cache_path): + if not os.path.isdir(cache_path): + LOG.i(f"Create cache dir: {cache_path}") + os.mkdir(cache_path) + +def shsplit(s): + s1 = s.split(' ') + s2 = [] + count = 0 + for s in s1: + nc = s.count('"') + s.count('\'') + if count&1: + count += nc + s2[-1] += " " + s2[-1] += s + else: + count = nc + s2.append(s) + return s2 + + +def remove_flags(flags, rm_flags): + flags = shsplit(flags) + output = [] + for s in flags: + ss = s.replace("\"", "") + for rm in rm_flags: + if ss.startswith(rm) or ss.endswith(rm): + break + else: + output.append(s) + return " ".join(output) + +def moveback_flags(flags, rm_flags): + flags = shsplit(flags) + output = [] + output2 = [] + for s in flags: + ss = s.replace("\"", "") + for rm in rm_flags: + if ss.startswith(rm) or ss.endswith(rm): + output2.append(s) + break + else: + output.append(s) + return " ".join(output+output2) + +def map_flags(flags, func): + flags = shsplit(flags) + output = [] + for s in flags: + output.append(func(s)) + return " ".join(output) + +def compile(compiler, flags, inputs, output, combind_build=False, cuda_flags="", obj_dirname="obj_files"): + def do_compile(cmd): + if jit_utils.cc: + return jit_utils.cc.cache_compile(cmd, cache_path, jittor_path) + else: + run_cmd(cmd) + return True + base_output = os.path.basename(output).split('.')[0] + if os.name == 'nt': + # windows do not combind build, need gen def + combind_build = False + # windows need xxxx.lib + afile = output.rsplit('.', 1)[0] + ".lib" + afile = os.path.join(cache_path, afile) + if cc_type != 'cl': + # initialize order in windows seems reversed + inputs = list(inputs[::-1]) + link = link + f' -Wl,--export-all-symbols,--out-implib,"{afile}" ' + + if not os.path.isabs(output): + output = os.path.join(cache_path, output) + # don't recompile object file in inputs + obj_files = [] + ex_obj_files = [] + new_inputs = [] + obj_dir = os.path.join(cache_path, obj_dirname) + os.makedirs(obj_dir, exist_ok=True) + for name in inputs: + if name[-1] in 'oab': + ex_obj_files.append(name) + else: + new_inputs.append(os.path.join(jittor_path, name)) + obj_files.append(os.path.join( + obj_dir, os.path.basename(name)+".o")) + inputs = new_inputs + cm = lambda s: f"\"{s}\"" + cms = lambda arr: [f"\"{s}\"" for s in arr ] + + if len(inputs) == 1 or combind_build: + cmd = f"\"{compiler}\" {' '.join(cms(inputs))} {flags} -o {cm(output)}" + return do_compile(fix_cl_flags(cmd)) + # split compile object file and link + # remove -l -L flags when compile object files + oflags = remove_flags(flags, ['-l', '-L', '-Wl,', '.lib', '-shared']) + cmds = [] + for input, obj_file in zip(inputs, obj_files): + cc = compiler + nflags = oflags + cmd = f"{cm(input)} {nflags} {lto_flags} -c -o {cm(obj_file)}" + if input.endswith(".cu"): + if has_cuda or has_rocm: + cmd = f"\"{nvcc_path}\" {cuda_flags} {cmd}" + cmd = convert_nvcc_flags(fix_cl_flags(cmd)) + else: + continue + else: + cmd = f"\"{cc}\" {cmd}" + cmd = fix_cl_flags(cmd) + if "nan_checker" in input: + # nan checker needs to disable fast_math + if "--use_fast_math" in cmd: + cmd = cmd.replace("--use_fast_math", "") + if "-Ofast" in cmd: + cmd = cmd.replace("-Ofast", "-O2") + cmds.append(cmd) + jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output) + obj_files += ex_obj_files + if os.name == 'nt': + dumpdef_path = os.path.join(jittor_path, "utils", "dumpdef.py") + cmd = f"\"{sys.executable}\" \"{dumpdef_path}\" {' '.join(cms(obj_files))} -Fo: \"{output}.def\"" + do_compile(fix_cl_flags(cmd)) + cmd = f"\"{compiler}\" {' '.join(cms(obj_files))} -o {cm(output)} {flags} {lto_flags}" + return do_compile(fix_cl_flags(cmd)) + +def gen_jit_tests(): + all_src = glob.glob(jittor_path+"/src/**/*.cc", recursive=True) + jit_declares = [] + re_def = re.compile("JIT_TEST\\((.*?)\\)") + names = set() + test_defs = [] + + for src_name in all_src: + with open(src_name, 'rb') as f: + src = f.read().decode('utf8') + defs = re_def.findall(src) + for name in defs: + LOG.vv(f"Find test {name} from {src_name}") + assert name not in names, f"Conflict test name {name}" + names.add(name) + jit_declares.append(f"JIT_TEST({name});") + test_defs.append(f""" + /* From {src_name} */ + // @pyjt({name}) + static inline void test_{name}() {{ jit_test_{name}(); }} + """) + + jit_declares = "\n ".join(jit_declares) + jit_src = f""" + #pragma once + #include "common.h" + + void expect_error(std::function func) {{ + try {{ func(); }} + catch (...) {{ return; }} + CHECK(0) << "Missing error"; + }} + + namespace jittor {{ + + {jit_declares} + + // @pyjt(tests) + // @attrs(submodule) + namespace tests {{ + {"".join(test_defs)} + }} + + }} // jittor + """ + LOG.vvvv(jit_src) + with open(os.path.join(cache_path, "gen", "jit_tests.h"), 'w', encoding='utf8') as f: + f.write(jit_src) + +def gen_jit_flags(): + all_src = glob.glob(jittor_path+"/src/**/*.cc", recursive=True) + jit_declares = [] + re_def = re.compile("DEFINE_FLAG(_WITH_SETTER)?\\((.*?)\\);", re.DOTALL) + + flags_defs = [] + visit = {} + + for src_name in all_src: + with open(src_name, 'rb') as f: + src = f.read().decode("utf8") + defs = re_def.findall(src) + for _, args in defs: + args = args.split(",") + type = args[0].strip() + name = args[1].strip() + if not has_cuda and "cuda" in name and name!="use_cuda": + if name != "use_cuda_host_allocator": + continue + default = args[2].strip() + doc = ",".join(args[3:]) + doc = eval(f"({doc})") + LOG.vv(f"Find define {name} from {src_name}") + if name in visit: + continue + visit[name] = 1 + jit_declares.append(f"DECLARE_FLAG({type}, {name});") + alias = [] + if name == "use_cuda": + alias = ["use_device", "use_acl", "use_rocm", "use_corex"] + elif name == "auto_mixed_precision_level": + alias = ["amp_level"] + get_names = ",".join(["__get__"+a for a in [name]+alias]) + set_names = ",".join(["__set__"+a for a in [name]+alias]) + flags_defs.append(f""" + /* {name}(type:{type}, default:{default}): {doc} */ + // @pyjt({get_names}) + {type} _get_{name}() {{ return {name}; }} + // @pyjt({set_names}) + void _set_{name}({type} v) {{ set_{name}(v); }} + {f'''// @pyjt({set_names}) + void _set_{name}(bool v) {{ set_{name}(v); }} + ''' if type=="int" else ""} + """) + + jit_declares = "\n ".join(jit_declares) + jit_src = f""" + #include "utils/flags.h" + + namespace jittor {{ + + {jit_declares} + + // @pyjt(Flags) + struct _Flags {{ + // @pyjt(__init__) + _Flags() {{}} + {"".join(flags_defs)} + }}; + + }} // jittor + """ + LOG.vvvv(jit_src) + with open(os.path.join(cache_path, "gen", "jit_flags.h"), 'w', encoding='utf8') as f: + f.write(jit_src) + +def gen_jit_op_maker(op_headers, export=False, extra_flags=""): + def add_src( + cc_func_name, + cc_args, + op_name, + op_args, + src, + pybind_name, + py_args, + jit_cc_src, + doc_string, + attrs + ): + has_ir = set(["add", "sub", "mul", "matmul", "truediv", "floordiv", "mod", "divmod", "pow", "lshift", "rshift", "and", "xor", "or"]) + pybind_names = [ s.strip() for s in pybind_name.split(",")] + cc_make_args = [ arg.replace("VarHolder*", "Var*") for arg in cc_args ] + op_make_args = [ arg.replace("->var", "") for arg in op_args ] + py_args = [ arg.replace("Var*", "VarHolder*") for arg in py_args ] + op_args = [] + cc_args_with_default = [] + for i, arg in enumerate(cc_args): + pre_arg = arg.split()[-1].split('=')[0] + op_arg = None + if arg.startswith("VarHolder*"): + op_arg = pre_arg+"->var" + elif arg.startswith("vector"): + op_arg = f"convert({pre_arg})" + if "&&" in arg: + if op_arg == None: + op_arg = "move("+pre_arg+")" + op_make_args[i] = "move("+pre_arg+")" + if op_arg==None: op_arg = pre_arg + op_args.append(op_arg) + py_arg = py_args[i] + if "_a=" not in py_arg: + cc_args_with_default.append(arg) + continue + py_arg = py_arg.split("_a=")[1] + cc_args_with_default.append(arg + "=" + py_arg) + cc_args = cc_args_with_default + # steps of Op creation: + # 1. new op + # 2. new output var (create_output in op constructor) + # 3. take over op's output VarPtr from outputs_holder + # 4. set op's output + # 5. set op's input + # 6. infer shape(op->init()) + if "multiple_outputs" not in attrs: + jit_cc_src.append(f""" + VarPtr make_{cc_func_name}({", ".join(cc_make_args)}) {{ + auto _op = new {op_name}({", ".join(op_make_args)}); + if (_op->outputs_holder.size() != 1) {{ + delete _op; + LOGf << "Wrong output size of" << \"{op_name}\"; + }} + if (_op->flags.get(NodeFlags::_forwarded)) {{ + VarPtr _out(move(_op->outputs_holder[0])); + delete _op; + return _out; + }} + _op->outputs_holder[0]->set_inputs({{_op}}); + VarPtr _out(move(_op->outputs_holder[0])); + {src.replace("->var","")}; + _op->init(); + return _out; + }} + """) + else: + jit_cc_src.append(f""" + vector make_{cc_func_name}({", ".join(cc_make_args)}) {{ + auto _op = new {op_name}({", ".join(op_make_args)}); + if (_op->flags.get(NodeFlags::_forwarded)) {{ + vector _outs = move(_op->outputs_holder); + delete _op; + return _outs; + }} + vector _outs = move(_op->outputs_holder); + for (uint i=0; i<_outs.size(); i++) + _outs[i]->set_inputs({{_op}}); + {src.replace("->var","")}; + _op->init(); + return _outs; + }} + """) + if pybind_name == 'None': + return + pyjt_names = [] + for pybind_name in pybind_names: + if pybind_name.startswith("__"): + pyjt_names.append("Var."+pybind_name) + else: + pyjt_names.append(pybind_name) + if len(cc_args)>0 and cc_args[0].startswith("VarHolder* "): + pyjt_names.append("Var."+pybind_name) + if "multiple_outputs" in attrs: + jit_cc_src.append(f""" + /*{doc_string}*/ + // @pyjt({",".join(pyjt_names)}) + vector_to_tuple {cc_func_name}({", ".join(cc_args)}) {{ + { f'return make_vh_vector(make_{cc_func_name}({", ".join(op_args)}));' + if "replace_outputs" not in attrs else + f'''auto rt = make_vh_vector(make_{cc_func_name}({", ".join(op_args)})); + ASSERT(rt.size() == outputs.size()); + for (int i=0; iassign(rt[i]); + return rt; + '''} + }} + """) + else: + jit_cc_src.append(f""" + /*{doc_string}*/ + // @pyjt({",".join(pyjt_names)}) + VarHolder* {cc_func_name}({", ".join(cc_args)}) {{ + return new VarHolder(make_{cc_func_name}({", ".join(op_args)})); + }} + """) + need_ir_define = False + ir_name = None + for pybind_name in pybind_names: + if pybind_name.startswith("__") and pybind_name[2:-2] in has_ir: + need_ir_define = True + assert ir_name is None + ir_name = pybind_name[2:-2] + if need_ir_define: + assert len(cc_args)>0 and cc_args[0].startswith("VarHolder* ") + this = cc_args[0].split()[-1] + jit_cc_src.append(f""" + // @pyjt(Var.__i{ir_name}__) + // @attrs(return_self) + VarHolder* i{cc_func_name}({", ".join(cc_args)}) {{ + *{this} = make_{cc_func_name}({", ".join(op_args)}); + return {this}; + }} + """) + assert len(cc_args)>1 and cc_args[1].startswith("VarHolder* "), cc_args + r_cc_args = [cc_args[1], cc_args[0]] + cc_args[2:] + r_py_args = [py_args[1], py_args[0]] + py_args[2:] + jit_cc_src.append(f""" + VarHolder* r{cc_func_name}({", ".join(r_cc_args)}) {{ + return new VarHolder(make_{cc_func_name}({", ".join(op_args)})); + }} + """) + + jit_cc_src = [] + jit_headers = "" + initer = [] + pybind_reg = '(/\\*(.*?)\\*/\\s*)?(//\\s*@pybind\\(([^\\n]*)\\)\\s*)?' + pybind_attrs_reg = pybind_reg + '(//\\s*@attrs\\(([^\\n]*)\\)\\s*)?' + for header in op_headers: + # xxx_xxx_op + name = os.path.basename(header) + name = os.path.splitext(name)[0] + # xxx_xxx + assert name.endswith("_op") + func_name = name[:-3] + # XxxXxxOp + name2 = map(lambda s:s[:1].upper() + s[1:], name.split('_')) + name2 = "".join(name2) + with open(header, encoding='utf8') as f: + src = f.read() + # XxxXxxOp(args) + res = re.findall(pybind_attrs_reg + '[^~]('+name2+"\\([^\\n]*\\))", src, re.S) + assert len(res) >= 1, "Wrong op args in " + header + # registe op + cc_name = header[:-2] + ".cc" + constructors = [] + for i in range(len(res)): + name = 'make_'+func_name+'_'*i + constructors.append(f"{{ &typeid(&{name}), (void*)&{name} }}") + constructors = ",".join(constructors) + var_member_reg = r"\n\s*Var\b(.*);" + var_member_match = re.findall(var_member_reg, src) + var_member_match = " ".join(var_member_match) + for c in "*,": var_member_match = var_member_match.replace(c, " ") + var_member = var_member_match.split() + LOG.vv("var_member_match "+var_member_match) + LOG.vv("var_member "+str(var_member)) + var_member_src = [ f"VAR_MEMBER_NAME_AND_OFFSET({name}, {name2})" for name in var_member ] + var_member_src = ",".join(var_member_src) + initer.append(f'\n op_registe({{ "{func_name}", R"({cc_name})", extra_flags, {{{constructors}}}, {{{var_member_src}}} }});') + for hid, h_def in enumerate(res): + h_def = list(h_def) + # // @attrs(...) + attrs = {} + if h_def[4] != "": + attrs = pyjt_compiler.parse_attrs(h_def[5]) + del h_def[4:6] + # /* doc_string */ + # // @pybind(bind_name) + # XxxXxxOp(args_def) + doc_string = h_def[1].strip() + h_def = h_def[2:] + args_def = h_def[2][len(name2)+1:-1] + bind_name = h_def[1] + if bind_name == "": + bind_name = func_name + if args_def=="": + args = [] + else: + args = list(map(lambda s: s.split()[-1].split('=')[0], args_def.split(','))) + # py_args: "arg"_a=default + py_args = [] + new_args_def = [] + new_args = [] + # source of convert VarHolder* to Var* + vh2v_src = [] + more_src = [] + for arg, arg_def in zip(args, args_def.split(',')): + py_arg = f'"{arg}"_a' + if '=' in arg_def: + py_arg += "=" + arg_def.split('=')[-1] + arg_def = arg_def.split('=')[0] + py_args.append(py_arg) + arg_type = arg_def[:-(len(arg)+1)].strip() + if arg_type == "Var*": + new_args_def.append("VarHolder* " + arg) + vh2v_src.append(arg + "->var") + new_args.append(arg + "->var") + elif arg_type.startswith("vector"): + new_args_def.append( + arg_type.replace("Var", "VarHolder")+' '+arg) + new_args.append(arg) + more_src.append(f"_op->add_inputs({arg});") + elif arg_type.startswith("VarSlices"): + new_args_def.append(arg_def) + new_args.append(arg) + more_src.append(f""" + vector svars; + for (int i=0; i<_op->vs.n; i++) + if (_op->vs.slices[i].is_var()) + svars.push_back(_op->vs.slices[i].var); + _op->add_inputs(svars);""") + else: + new_args_def.append(arg_def) + new_args.append(arg) + vh2v_src = "_op->set_inputs({" + ", ".join(vh2v_src) + "});" + \ + "".join(more_src) + LOG.vvvv(f"Find op: {name2} args: {new_args}") + # if header.startswith("src/"): + # jit_headers += f"#include \"{header[4:]}\"\n" + # else: + jit_headers += f"#include \"{header}\"\n" + add_src( + func_name+'_'*hid, + new_args_def, + name2, + new_args, + vh2v_src, + bind_name, + py_args, + jit_cc_src, + doc_string, + attrs + ) + if func_name in ["binary", "unary", "reduce"]: + # generate binary op alias + with open(os.path.join(jittor_path, f"src/ops/{func_name}_op.cc"), encoding="utf-8") as f: + src = f.read() + src = src.split(f"unordered_set {func_name}_ops = ""{")[1].split("};")[0] + match_result = re.findall(pybind_reg + "\"([a-z_A-Z0-9]*)\"", src, re.S) + # remove /* doc_string */ pattern + res2 = [ (_[3], _[4]) for _ in match_result ] + LOG.vvvv(f"All supported {func_name} ops: {res2}") + # remove op args + if func_name == "reduce": + args_def = new_args_def[:1] + new_args_def[2:] + py_args_s = py_args[:1] + py_args[2:] + else: + args_def = new_args_def[:-1] + py_args_s = py_args[:-1] + # find the last type id(float64) + # add "_" suffix for all function + if func_name == "unary": + last_tid = res2.index(("","float64")) + # for each functor + for tid, (bind_name, func_name2) in enumerate(res2): + # get certain op doc_string + doc_string2 = match_result[tid][1].strip() + if len(doc_string2) == 0: + doc_string2 = doc_string + # add _ for types + if func_name == "unary" and tid <= last_tid: + func_name3 = func_name2 + "_" + elif func_name == "reduce": + func_name4 = func_name2 + func_name2 = "reduce_" + func_name2 + func_name3 = func_name2 + else: + func_name3 = func_name2 + if len(bind_name) == 0: + bind_name = func_name2 + if func_name == "reduce": + args = new_args[:1] + [f'ns_{func_name4}'] + new_args[2:] + else: + args = new_args[:-1] + [f'ns_{func_name2}'] + add_src( + func_name3+'_'*hid, + args_def, + name2, + args, + vh2v_src, + bind_name, + py_args_s, + jit_cc_src, + doc_string2, + attrs + ) + + jit_src = f""" + #pragma once + #include "pyjt/py_obj_holder.h" + #include "var.h" + #include "var_holder.h" + #include "ops/op_register.h" + {jit_headers} + + namespace jittor {{ + // fix make_array(py::array) undefine reference + #pragma GCC visibility push(default) + #define JIT_NAMESPACE {export+"_maker" if export else "jit_op_maker"} + // @pyjt(ops) + // @attrs(submodule{",core_name="+export if export else ""}) + namespace JIT_NAMESPACE {{ + {"".join(jit_cc_src)} + + void initer() {{ + string extra_flags = R"({extra_flags})"; + {"".join(initer)} + }} + int caller = (initer(), 0); + + }} // JIT_NAMESPACE + }} // jittor + {f''' + namespace jittor {{ + extern void pyjt_def_{export}(PyObject*); + }} + + static void init_module(PyModuleDef* mdef, PyObject* m) {{ + mdef->m_doc = "User defined custom ops"; + jittor::pyjt_def_{export}(m); + }} + PYJT_MODULE_INIT({export}); + + ''' if export else ""} + """ + return jit_src + +@lock.lock_scope() +def compile_custom_op(header, source, op_name, warp=True): + """Compile a single custom op + header: code of op header, not path + source: code of op source, not path + op_name: op_name of this op, it will used for + generation of header and source files, if the + type name of op is XxxXxxOp, op_name should be + xxx_xxx + warp: if true, warp a snippet for header and source + """ + if warp: + header = f""" + #pragma once + #include "op.h" + #include "var.h" + namespace jittor {{ + {header} + }} + """ + source = f""" + #include "{op_name}_op.h" + namespace jittor {{ + {source} + }} + """ + cops_dir = os.path.join(cache_path, "custom_ops") + make_cache_dir(cops_dir) + hname = os.path.join(cops_dir, op_name+"_op.h") + ccname = os.path.join(cops_dir, op_name+"_op.cc") + with open(hname, 'w', encoding='utf8') as f: + f.write(header) + with open(ccname, 'w', encoding='utf8') as f: + f.write(source) + m = compile_custom_ops([hname, ccname]) + return getattr(m, op_name) + +@lock.lock_scope() +def compile_custom_ops( + filenames, + extra_flags="", + return_module=False, + dlopen_flags=None, + gen_name_ = ""): + """Compile custom ops + filenames: path of op source files, filenames must be + pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the + type name of op must be XxxXxxOp. + extra_flags: extra compile flags + return_module: return module rather than ops(default: False) + return: compiled ops + """ + if dlopen_flags is None: + dlopen_flags = os.RTLD_GLOBAL | os.RTLD_NOW + if platform.system() == 'Linux': + dlopen_flags |= os.RTLD_DEEPBIND + + srcs = {} + headers = {} + builds = [] + includes = [] + pyjt_includes = [] + for name in filenames: + name = os.path.realpath(name) + if name.endswith(".cc") or name.endswith(".cpp") or name.endswith(".cu"): + builds.append(name) + if name.endswith(".h"): + dirname = os.path.dirname(name) + if dirname.endswith("inc"): + includes.append(dirname) + with open(name, "r", encoding='utf8') as f: + if "@pyjt" in f.read(): + pyjt_includes.append(name) + bname = os.path.basename(name) + bname = os.path.splitext(bname)[0] + if bname.endswith("_op"): + bname = bname[:-3] + if name.endswith(".cc"): + srcs[bname] = name + elif name.endswith(".h"): + includes.append(os.path.dirname(name)) + headers[bname] = name + assert len(srcs) == len(headers), "Source and header names not match" + for name in srcs: + assert name in headers, f"Header of op {name} not found" + gen_name = "gen_ops_" + "_".join(headers.keys()) + if gen_name_ != "": + gen_name = gen_name_ + if len(gen_name) > 50: + gen_name = gen_name[:50] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()[:6] + + includes = sorted(list(set(includes))) + includes = "".join(map(lambda x: f" -I\"{x}\" ", includes)) + LOG.vvvv(f"Include flags:{includes}") + + op_extra_flags = includes + extra_flags + + lib_path = os.path.join(cache_path, "custom_ops") + make_cache_dir(lib_path) + gen_src_fname = os.path.join(lib_path, gen_name+".cc") + gen_head_fname = os.path.join(lib_path, gen_name+".h") + gen_lib = os.path.join(lib_path, gen_name+extension_suffix) + libname = gen_name + lib_suffix + op_extra_flags += f" -L\"{lib_path}\" -l\"{libname}\" " + + gen_src = gen_jit_op_maker(headers.values(), export=gen_name, extra_flags=op_extra_flags) + pyjt_compiler.compile_single(gen_head_fname, gen_src_fname, src=gen_src) + # gen src initialize first + builds.insert(0, gen_src_fname) + + def insert_anchor(gen_src, anchor_str, insert_str): + # insert insert_str after anchor_str into gen_src + return gen_src.replace(anchor_str, anchor_str+insert_str, 1) + + for name in pyjt_includes: + LOG.v("handle pyjt_include ", name) + bname = os.path.basename(name).split(".")[0] + gen_src_fname = os.path.join(cache_path, "custom_ops", gen_name+"_"+bname+".cc") + pyjt_compiler.compile_single(name, gen_src_fname) + builds.insert(1, gen_src_fname) + gen_src = insert_anchor(gen_src, + "namespace jittor {", + f"extern void pyjt_def_{bname}(PyObject* m);") + gen_src = insert_anchor(gen_src, + "init_module(PyModuleDef* mdef, PyObject* m) {", + f"jittor::pyjt_def_{bname}(m);") + + with open(gen_head_fname, "w", encoding='utf8') as f: + f.write(gen_src) + + LOG.vvv(f"Build custum ops lib:{gen_lib}") + LOG.vvvv(f"Build sources:{builds}") + compile(cc_path, extra_flags+cc_flags+opt_flags+includes, builds, gen_lib) + + # add python path and import + LOG.vvv(f"Import custum ops lib:{gen_lib}") + lib_path = os.path.join(cache_path, "custom_ops") + if lib_path not in os.sys.path: + os.sys.path.append(lib_path) + # unlock scope when initialize + with lock.unlock_scope(): + with jit_utils.import_scope(dlopen_flags): + exec(f"import {gen_name}") + mod = locals()[gen_name] + if return_module: + return mod + return mod.ops + + +def get_full_path_of_executable(name): + full_path = os.path.abspath(name) + while os.path.islink(full_path): + full_path = os.path.realpath(full_path) + if os.path.isfile(full_path) and os.access(full_path, os.X_OK): + return full_path + return get_full_path_of_executable(find_exe(name)) + +def compile_extern(): + # compile llvm passes + if cc_type != "clang" or platform.system() != 'Linux': + return + global kernel_opt_flags + cache_path_llvm = os.path.join(cache_path, "llvm") + jittor_path_llvm = os.path.join(jittor_path, "extern", "llvm") + clang_dir = os.path.dirname(get_full_path_of_executable(cc_path)) + assert clang_dir.endswith("bin") and "llvm" in clang_dir, f"Wrong clang_dir: {clang_dir}" + llvm_include = os.path.abspath(os.path.join(clang_dir, "..", "include")) + assert os.path.isdir(llvm_include), "LLVM include path not found" + make_cache_dir(cache_path_llvm) + files = os.listdir(jittor_path_llvm) + # test_pass.cc is used for test link problem of llvm pass plugin + test_pass_path = os.path.join(cache_path_llvm, "test_pass.cc") + with open(test_pass_path, 'w', encoding='utf8') as f: + f.write("int main() {return 0;}") + + # -fno-rtti fix link error + + # -Wl,-znodelete fix segfault + # https://github.com/sampsyo/llvm-pass-skeleton/issues/7#issuecomment-401834287 + + # -D_GLIBCXX_USE_CXX11_ABI=0 fix undefined symbol: createPrinterPass + # https://stackoverflow.com/questions/37366291/undefined-symbol-for-self-built-llvm-opt + + # try different flags + try_flags = [ + " -Wl,-znodelete -D_GLIBCXX_USE_CXX11_ABI=0 ", + " -Wl,-znodelete ", + ] + found_flags_id = -1 + for fname in files: + for i, flag in enumerate(try_flags): + if found_flags_id != -1 and found_flags_id != i: + continue + so_name = os.path.join(cache_path_llvm, os.path.splitext(fname)[0]+f".{i}.so") + compile( + cc_path, + f"{cc_flags} {opt_flags} {flag} -I'{llvm_include}'", + [os.path.join(jittor_path_llvm, fname)], + so_name + ) + # if not found available flags, we test it. + if found_flags_id == -1: + try: + s = run_cmd( + f"{cc_path} {cc_flags} -Xclang -load -Xclang '{so_name}' {test_pass_path}", + cache_path_llvm, + print_error=False + ) + except Exception as e: + LOG.v(f"Try flag {flag} failed: {e}") + continue + found_flags_id = i + kernel_opt_flags += f" -Xclang -load -Xclang '{so_name}' " + break + else: + LOG.w("Clang is used, but LLVM pass plugin is unable to link.") + break + LOG.vv(f"Compile extern llvm passes: {str(files)}") + +def check_cuda(): + if not nvcc_path: + return + global cc_flags, has_cuda, is_cuda, core_link_flags, cuda_dir, cuda_lib, cuda_include, cuda_home, cuda_bin + cuda_dir = os.path.dirname(get_full_path_of_executable(nvcc_path)) + cuda_bin = cuda_dir + cuda_home = os.path.abspath(os.path.join(cuda_dir, "..")) + # try default nvidia-cuda-toolkit in Ubuntu 20.04 + # assert cuda_dir.endswith("bin") and "cuda" in cuda_dir.lower(), f"Wrong cuda_dir: {cuda_dir}" + cuda_include = os.path.abspath(os.path.join(cuda_dir, "..", "include")) + cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib64")) + if nvcc_path == "/usr/bin/nvcc": + # this nvcc is install by package manager + cuda_lib = "/usr/lib/x86_64-linux-gnu" + cuda_include2 = os.path.join(jittor_path, "extern","cuda","inc") + cc_flags += f" -DHAS_CUDA -DIS_CUDA -I\"{cuda_include}\" -I\"{cuda_include2}\" " + if os.name == 'nt': + cuda_lib = os.path.abspath(os.path.join(cuda_dir, "..", "lib", "x64")) + # cc_flags += f" \"{cuda_lib}\\cudart.lib\" " + cuda_lib_path = glob.glob(cuda_bin+"/cudart64*")[0] + cc_flags += f" -lcudart -L\"{cuda_lib}\" -L\"{cuda_bin}\" " + dll = ctypes.CDLL(cuda_lib_path, dlopen_flags) + ret = dll.cudaDeviceSynchronize() + assert ret == 0 + else: + cc_flags += f" -lcudart -L\"{cuda_lib}\" " + # ctypes.CDLL(cuda_lib+"/libcudart.so", import_flags) + ctypes.CDLL(cuda_lib+"/libcudart.so", dlopen_flags) + is_cuda = has_cuda = 1 + +def check_cache_compile(): + files = [ + "src/utils/cache_compile.cc", + "src/utils/log.cc", + "src/utils/tracer.cc", + "src/utils/jit_utils.cc", + "src/utils/str_utils.cc", + ] + if os.name == 'nt': + files = [ x.replace('/', '\\') for x in files ] + global jit_utils_core_files + jit_utils_core_files = files + recompile = compile(cc_path, cc_flags+f" {opt_flags} ", files, jit_utils.cache_path+'/jit_utils_core'+extension_suffix, True) + if recompile and jit_utils.cc: + LOG.e("jit_utils updated, please rerun your command.") + sys.exit(0) + if not jit_utils.cc: + with jit_utils.import_scope(import_flags): + jit_utils.try_import_jit_utils_core() + assert jit_utils.cc + # recompile, generate cache key + compile(cc_path, cc_flags+f" {opt_flags} ", files, jit_utils.cache_path+'/jit_utils_core'+extension_suffix, True) + +def env_or_try_find(name, bname): + if name in os.environ: + path = os.environ[name] + if path != "": + version = jit_utils.get_version(path) + LOG.i(f"Found {bname}{version} at {path}") + return path + return try_find_exe(bname) + +def try_find_exe(*args): + try: + return find_exe(*args) + except: + LOG.v(f"{args[0]} not found.") + return "" + +def check_pybt(gdb_path, python_path): + if gdb_path=='' or python_path=='': + return False + return True + # TODO: prev we use below code to check has py-bt or nor + # but it is too slow, so we comment it, + # find a better way to check py-bt exist + + # ret = sp.getoutput(f"{gdb_path} --batch {python_path} -ex 'help py-bt'") + # if 'python frame' in ret: + # LOG.v("py-bt found in gdb.") + # return True + # return False + +def check_debug_flags(): + global is_debug + is_debug = 0 + if os.environ.get("debug")=="1": + is_debug = 1 + global cc_flags + cc_flags += " -g -DNODE_MEMCHECK " + +cc_flags = " " +# os.RTLD_NOW | os.RTLD_GLOBAL cause segfault when import torch first +import_flags = os.RTLD_NOW | os.RTLD_GLOBAL +if platform.system() == 'Linux': + import_flags |= os.RTLD_DEEPBIND +# if cc_type=="icc": +# # weird link problem, icc omp library may conflict and cause segfault +# import_flags = os.RTLD_NOW | os.RTLD_GLOBAL +dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL +if platform.system() == 'Linux': + import_flags |= os.RTLD_DEEPBIND + +with jit_utils.import_scope(import_flags): + jit_utils.try_import_jit_utils_core() + +jittor_path = find_jittor_path() +if os.name == 'nt': + # prevent windows recompile + jittor_path = jittor_path.lower() +check_debug_flags() + +sys.path.append(cache_path) +LOG.i(f"Jittor({__version__}) src: {jittor_path}") +LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}{jit_utils.get_version(jit_utils.cc_path)}") +LOG.i(f"cache_path: {cache_path}") + +with jit_utils.import_scope(import_flags): + jit_utils.try_import_jit_utils_core() + +python_path = sys.executable +# sometime python do not return the correct sys executable +# this will happend when multiple python version installed +ex_python_path = python_path + '.' + str(sys.version_info.minor) +if os.path.isfile(ex_python_path): + python_path = ex_python_path + +# if jtcuda is already installed +nvcc_path = None +if install_cuda.has_installation() or os.name == 'nt': + nvcc_path = install_cuda.install_cuda() + if nvcc_path: + nvcc_path = try_find_exe(nvcc_path) +# check system installed cuda +if not nvcc_path: + nvcc_path = env_or_try_find('nvcc_path', 'nvcc') or \ + try_find_exe('/usr/local/cuda/bin/nvcc') or \ + try_find_exe('/usr/bin/nvcc') or \ + try_find_exe('/opt/cuda/bin/nvcc') +# if system has no cuda, install jtcuda +if not nvcc_path: + nvcc_path = install_cuda.install_cuda() + if nvcc_path: + nvcc_path = try_find_exe(nvcc_path) +if nvcc_path is None: + nvcc_path = "" +if "nvcc_path" in os.environ: + nvcc_path = os.environ["nvcc_path"] +gdb_path = env_or_try_find('gdb_path', 'gdb') +addr2line_path = try_find_exe('addr2line') +has_pybt = check_pybt(gdb_path, python_path) + +if nvcc_path: + # gen cuda key for cache_path + cu = "cu" + v = jit_utils.get_version(nvcc_path)[1:-1] + nvcc_version = list(map(int,v.split('.'))) + cu += v + try: + r, s = sp.getstatusoutput(f"log_v=0 {sys.executable} -m jittor_utils.query_cuda_cc") + if r==0: + s = sorted(list(set(s.strip().split()))) + if len(s)==0: + LOG.e("No GPU Device Found!") + cu += "_sm_" + "_".join(s) + if "cuda_arch" not in os.environ: + os.environ["cuda_arch"] = " ".join(cu) + cu = cu.replace(":", "").replace(" ", "") + except: + pass + LOG.i("cuda key:", cu) + cache_path = os.path.join(cache_path, cu) + sys.path.append(cache_path) + + +def check_clang_latest_supported_cpu(): + output = run_cmd('clang --print-supported-cpus') + def find_latest_chip_version(pattern_prefix): + apple_cpus = [l.strip() for l in output.split('\n') if pattern_prefix in l] + apple_cpu_id = max([int(cpu[7:]) for cpu in apple_cpus]) + return pattern_prefix + str(apple_cpu_id) + if 'apple-m' in output: + return find_latest_chip_version('apple-m') + else: + return find_latest_chip_version('apple-a') + +# cc_flags += " -Wall -Werror -Wno-unknown-pragmas -std=c++14 -fPIC " +cc_flags += " -Wall -Wno-unknown-pragmas -std=c++14 -fPIC " +# 1. Arch/CPU specific optimization +if platform.machine() in ["x86_64", "AMD64"]: + cc_flags += " -march=native " +elif platform.machine() == 'arm64' and platform.system() == "Darwin": + cc_flags += f" -mcpu={check_clang_latest_supported_cpu()} " +cc_flags += " -fdiagnostics-color=always " +# 2. Non standard include path +if platform.system() == 'Darwin': + # TODO: if not using apple clang, there is no need to add -lomp + cc_flags += " -undefined dynamic_lookup -lomp " + if os.environ.get("CONDA_PREFIX", None): + cc_flags += f" -L{os.path.join(os.environ['CONDA_PREFIX'], 'lib')} " + # if platform.machine() == "arm64": + # cc_flags += " -I/opt/homebrew/include -L/opt/homebrew/lib " + # Homebrew does not symlink the openmp library (libomp >= 15.0.6) into /opt/homebrew/lib + homebrew_openmp_paths = [ + "/opt/homebrew/opt/libomp", + "/usr/local/opt/libomp" + ] + for openmp_path in homebrew_openmp_paths: + if os.path.exists(openmp_path): + cc_flags += f" -I{openmp_path}/include -L{openmp_path}/lib" + +# 3. User specified flags +if "cc_flags" in os.environ: + cc_flags += os.environ["cc_flags"] + ' ' + +cc_flags += " -lstdc++ -ldl -shared " + +opt_flags = "" + +py_include = jit_utils.get_py3_include_path() +LOG.v(f"py_include: {py_include}") +extension_suffix = jit_utils.get_py3_extension_suffix() +lib_suffix = extension_suffix.rsplit(".", 1)[0] +LOG.v(f"extension_suffix: {extension_suffix}") +so = ".so" if os.name != 'nt' else ".dll" + + +kernel_opt_flags = os.environ.get("kernel_flags", "") + opt_flags +if platform.system() == 'Darwin': + # TODO: if not using apple clang, cannot add -Xpreprocessor + kernel_opt_flags += " -Xpreprocessor -fopenmp " +elif cc_type != 'cl': + kernel_opt_flags += " -fopenmp " +def fix_cl_flags(cmd): + output = shsplit(cmd) + output2 = [] + libpaths = [] + for s in output: + if s.startswith("-l") and ("cpython" in s or "lib" in s): + if platform.system() == 'Darwin': + fname = s[2:] + ".so" + for path in reversed(libpaths): + full = os.path.join(path, fname).replace("\"", "") + if os.path.isfile(full): + output2.append(full) + break + else: + output2.append(s) + else: + output2.append(f"-l:{s[2:]}.so") + elif s.startswith("-L"): + libpaths.append(s[2:]) + output2.append(f"{s} -Wl,-rpath,{s[2:]}") + else: + output2.append(s) + return " ".join(output2) + +if os.name == 'nt': + if cc_type == 'g++': + pass + elif cc_type == 'cl': + py3_link_path = jit_utils.get_py3_link_path() + cc_flags = remove_flags(cc_flags, ["-f", "-m"]) + cc_flags = cc_flags.replace("-std=c++14", "-std=c++17") + cc_flags = cc_flags.replace("-lstdc++", "") + cc_flags = cc_flags.replace("-ldl", "") + cc_flags += f" -L\"{py3_link_path}\" -lpython3{sys.version_info.minor} " + cc_flags += " -EHa -MD -utf-8 " + import jittor_utils + if jittor_utils.msvc_path: + mp = jittor_utils.msvc_path + cc_flags += f' -nologo -I"{mp}\\VC\\include" -I"{mp}\\win10_kits\\include\\ucrt" -I"{mp}\\win10_kits\\include\\shared" -I"{mp}\\win10_kits\\include\\um" -DNOMINMAX ' + cc_flags += f' -L"{mp}\\VC\\lib" -L"{mp}\\win10_kits\\lib\\um\\x64" -L"{mp}\\win10_kits\\lib\\ucrt\\x64" ' + win_libpaths = {} + def fix_cl_flags(cmd): + cmd = cmd.replace(".o ", ".obj ") + cmd = cmd.replace(".o\"", ".obj\"") + if cmd.endswith(".o"): cmd += "bj" + if " -o " in cmd: + if " -shared " in cmd: + cmd = cmd.replace(" -o ", " -Fe: ") + output = shsplit(cmd.split("-Fe:")[1].strip())[0] + base_output = os.path.basename(output).split('.')[0] + cmd += f" -DEF:{output}.def -IGNORE:4102 -IGNORE:4197 -IGNORE:4217 " + + elif " -c -o " in cmd: + cmd = cmd.replace(" -c -o ", " -c -Fo: ") + flags = shsplit(cmd) + output = [] + output2 = [] + for f in flags: + if f.startswith("-link"): + pass + elif f.startswith("-l"): + output2.append(f[2:]+".lib") + elif f.startswith("-LIB"): + output2.append(f) + elif f.startswith("-LD"): + output.append(f) + elif f.startswith("-L"): + path = f[2:].replace("\"", "") + if path not in win_libpaths: + win_libpaths[path] = 1 + os.add_dll_directory(path) + os.environ["PATH"] = f";{path};" + os.environ["PATH"] + output2.append("-LIBPATH:"+f[2:]) + elif ".lib" in f: + output2.append(f) + elif f.startswith("-DEF:"): + output2.append(f) + elif f.startswith("-W") or f.startswith("-f"): + pass + elif f.startswith("-std="): + output.append(f.replace("=", ":")) + else: + output.append(f) + cmd = " ".join(output) + if len(output2): + cmd += " -link " + " ".join(output2) + cmd = cmd.replace("-include", "-FI") + cmd = cmd.replace("-shared", "-LD") + return cmd + +if ' -O' not in cc_flags: + if os.environ.get("debug", "0") == "1": + opt_flags += " -O0 " + else: + opt_flags += " -O2 " + kernel_opt_flags += " -Ofast " +lto_flags = "" +if os.environ.get("enable_lto") == "1": + if cc_type == "icc": + lto_flags = " -flto -ipo -ipo-c " + elif cc_type == "g++": + lto_flags = " -flto -fuse-linker-plugin " + else: + lto_flags = " -flto " + +make_cache_dir(cache_path) +make_cache_dir(os.path.join(cache_path, "jit")) +make_cache_dir(os.path.join(cache_path, "obj_files")) +make_cache_dir(os.path.join(cache_path, "gen")) +make_cache_dir(os.path.join(cache_path, "tmp")) +ck_path = os.path.join(cache_path, "checkpoints") +make_cache_dir(ck_path) + + +ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME') + +# build cache_compile +cc_flags += f" -I\"{os.path.join(jittor_path, 'src')}\" " +cc_flags += f" -I\"{os.path.join(jittor_path, 'extern')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/acl')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnn')}\" " +cc_flags += f" -I\"{os.path.join(ascend_toolkit_home, 'include/aclnnop')}\" " +cc_flags += f" -L\"{os.path.join(ascend_toolkit_home, 'lib64')}\" " +cc_flags += " -llibascendcl " +cc_flags += " -llibnnopbase " +cc_flags += " -llibopapi " + +cc_flags += py_include + +check_cache_compile() +LOG.v(f"Get cache_compile: {jit_utils.cc}") + +# check cuda +is_cuda = has_cuda = 0 +check_cuda() +nvcc_flags = os.environ.get("nvcc_flags", "") +if has_cuda: + nvcc_flags += cc_flags + def convert_nvcc_flags(nvcc_flags): + # nvcc don't support -Wall option + if os.name == 'nt': + nvcc_flags = nvcc_flags.replace("-fp:", "-Xcompiler -fp:") + nvcc_flags = nvcc_flags.replace("-EH", "-Xcompiler -EH") + nvcc_flags = nvcc_flags.replace("-M", "-Xcompiler -M") + nvcc_flags = nvcc_flags.replace("-utf", "-Xcompiler -utf") + nvcc_flags = nvcc_flags.replace("-nologo", "") + nvcc_flags = nvcc_flags.replace("-std:", "-std=") + nvcc_flags = nvcc_flags.replace("-Fo:", "-o") + nvcc_flags = nvcc_flags.replace("-LD", "-shared") + nvcc_flags = nvcc_flags.replace("-LIBPATH:", "-L") + nvcc_flags = nvcc_flags.replace("-link", "") + def func(x): + if ".lib" not in x: return x + x = x.replace("\"", "") + a = os.path.dirname(x) + b = os.path.basename(x) + if not b.endswith(".lib"): + return x + return f"-L\"{a}\" -l{b[:-4]}" + nvcc_flags = map_flags(nvcc_flags, func) + if nvcc_version >= [11,4]: + nvcc_flags = nvcc_flags.replace("-std=c++17", "-std=c++14 -Xcompiler -std:c++14") + else: + nvcc_flags = nvcc_flags.replace("-std=c++17", "") + nvcc_flags = nvcc_flags.replace("-Wall", "") + nvcc_flags = nvcc_flags.replace("-Wno-unknown-pragmas", "") + nvcc_flags = nvcc_flags.replace("-fopenmp", "") + nvcc_flags = nvcc_flags.replace("-march", "-Xcompiler -march") + nvcc_flags = nvcc_flags.replace("-Werror", "") + nvcc_flags = nvcc_flags.replace("-fPIC", "-Xcompiler -fPIC") + nvcc_flags = nvcc_flags.replace("-fdiagnostics", "-Xcompiler -fdiagnostics") + nvcc_flags += f" -x cu --cudart=shared -ccbin=\"{cc_path}\" --use_fast_math " + # nvcc warning is noise + nvcc_flags += " -w " + nvcc_flags += f" -I\"{os.path.join(jittor_path, 'extern/cuda/inc')}\" " + if os.environ.get("cuda_debug", "0") == "1": + nvcc_flags += " -G " + return nvcc_flags + nvcc_flags = convert_nvcc_flags(nvcc_flags) + +extra_core_files = [] +setup_fake_cuda_lib = False +# from .acl_compiler import check_acl +from .extern.acl import acl_compiler +jit_utils.add_backend(acl_compiler) +from .extern.rocm import rocm_compiler +jit_utils.add_backend(rocm_compiler) +from .extern.corex import corex_compiler +jit_utils.add_backend(corex_compiler) + +for mod in jit_utils.backends: + if mod.check(): + break + +if not os.name == 'nt': + is_cuda = os.path.basename(nvcc_path) == "nvcc" +else: + is_cuda = os.path.basename(nvcc_path) == "nvcc.exe" + +# build core +gen_jit_flags() +gen_jit_tests() +op_headers = glob.glob(jittor_path+"/src/ops/**/*op.h", recursive=True) +jit_src = gen_jit_op_maker(op_headers) +LOG.vvvv(jit_src) +with open(os.path.join(cache_path, "gen", "jit_op_maker.h"), 'w', encoding='utf8') as f: + f.write(jit_src) +cc_flags += f' -I\"{cache_path}\" -L\"{cache_path}\" -L\"{jit_utils.cache_path}\" ' + +# gen pyjt +pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path) + +# initialize order: +# 1. registers +# 2. generate source +# 3. op_utils +# 4. other +files2 = pyjt_gen_src +ext_args = 'c[cu]' if has_cuda or has_rocm else 'cc' +files4 = glob.glob(jittor_path+"/src/**/*."+ext_args, recursive=True) +files4 = [ f[len(jittor_path)+1:] for f in files4 ] +# files4 = run_cmd('find -L src | grep '+grep_args, jittor_path).splitlines() +at_beginning = [ + "src/ops/op_utils.cc", + "src/ops/op_register.cc", + "src/init.cc", + "src/event_queue.cc", + "src/mem/allocator/sfrl_allocator.cc", + "src/mem/allocator.cc", + "src/misc/nano_string.cc", +] +at_last = [ + "src/profiler/profiler.cc", + "src/executor.cc", +] +if os.name == 'nt': + at_beginning = [ x.replace('/','\\') for x in at_beginning ] + at_last = [ x.replace('/','\\') for x in at_last ] +for i in range(len(at_beginning)): + files4.remove(at_beginning[i]) + files4.insert(i, at_beginning[i]) +for v in at_last: + files4.remove(v) + files4.append(v) +registers = [ name for name in files4 if "register" in name ] +for name in registers: files4.remove(name) +files = registers + files2 + files4 + + +#print(extra_core_files) +#extra_core_files.append("/home/ma-user/work/jittor/python/jittor/extern/acl/aclnn/aclnn.cc") +files += extra_core_files +for file in jit_utils_core_files: + files.remove(file) +LOG.vv("compile order:", files) + +if platform.system() == 'Linux': + libname = {"clang":"omp", "icc":"iomp5", "g++":"gomp"}[cc_type] + libname = ctypes.util.find_library(libname) + assert libname is not None, "openmp library not found" + ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL) + +if platform.machine()=='sw_64': + import ssl + ssl._create_default_https_context = ssl._create_unverified_context + +data_gz_path = os.path.join(jittor_path, "utils", "data.gz") +use_data_gz = os.path.isfile(data_gz_path) +if os.environ.get("use_data_gz", "1") == "0": + use_data_gz = False +if use_data_gz: + import gzip + with gzip.open(data_gz_path, 'rb') as f: + data = f.read() + md5 = hashlib.md5(data).hexdigest() + target_md5 = None + data_gz_md5_path = os.path.join(cache_path, "data.md5") + if os.path.isfile(data_gz_md5_path): + with open(data_gz_md5_path, 'r') as f: + target_md5 = f.read() + data_o_path = os.path.join(cache_path, "data.o") + if target_md5 != md5: + data_s_path = os.path.join(cache_path, "data.cc") + with open(data_s_path, "w") as f: + f.write(data.decode("utf8")) + dflags = (cc_flags+opt_flags)\ + .replace("-Wall", "") \ + .replace("-Werror", "") \ + .replace("-shared", "") + vdp = os.path.join(jittor_path, "src", "utils", "vdp") + run_cmd(fix_cl_flags(f"\"{cc_path}\" {dflags} -include \"{vdp}\" \"{data_s_path}\" -c -o \"{data_o_path}\"")) + os.remove(data_s_path) + with open(data_gz_md5_path, 'w') as f: + f.write(md5) + files.append(data_o_path) + files = [f for f in files if "__data__" not in f] +else: + files = [f for f in files + if "__data__" not in f or "src" in f.split("__data__")[1]] + + +#print(jittor_path) +#print(cc_flags) +#print(files) +cc_flags += f" -l\"jit_utils_core{lib_suffix}\" " +compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix) +cc_flags += f" -l\"jittor_core{lib_suffix}\" " + +# TODO: move to compile_extern.py +# compile_extern() + +with jit_utils.import_scope(import_flags): + import jittor_core as core + +flags = core.Flags() + +if has_cuda and is_cuda: + nvcc_flags = " " + os.environ.get("nvcc_flags", "") + " " + nvcc_flags += convert_nvcc_flags(cc_flags) + nvcc_version = list(jit_utils.get_int_version(nvcc_path)) + max_arch = 89 + if nvcc_version < [11,]: + max_arch = 75 + elif nvcc_version < [11,1]: + max_arch = 80 + elif nvcc_version < [11,8]: + max_arch = 86 + if len(flags.cuda_archs): + min_arch = 30 + archs = [] + for arch in flags.cuda_archs: + if archmax_arch: + LOG.w(f"CUDA arch({arch})>{max_arch} will be backward-compatible") + arch = max_arch + archs.append(arch) + flags.cuda_archs = archs + nvcc_flags += f" -arch=compute_{min(archs)} " + nvcc_flags += ''.join(map(lambda x:f' -code=sm_{x} ', archs)) + +flags.cc_path = cc_path +flags.cc_type = cc_type +flags.cc_flags = cc_flags + kernel_opt_flags +flags.nvcc_path = nvcc_path +flags.nvcc_flags = nvcc_flags +flags.python_path = python_path +flags.cache_path = cache_path +flags.jittor_path = jittor_path +flags.gdb_path = gdb_path +flags.addr2line_path = addr2line_path +flags.has_pybt = has_pybt + +core.set_lock_path(lock.lock_path) diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py new file mode 100644 index 00000000..c4026eee --- /dev/null +++ b/python/jittor/contrib.py @@ -0,0 +1,274 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np +from jittor import pool +from collections.abc import Sequence + + +def argmax_pool(x, size, stride, padding=0): + if stride<=0: + raise RuntimeError(f"stride must be > 0, but got {stride}") + return pool.pool(x, size, 'maximum', padding, stride) + +def concat(arr, dim): + '''Concat Operator can concat a list of jt Var at a specfic dimension. + + * [in] x: input var list for concat + + * [in] dim: concat which dim + + * [out] out: concat result + +Example:: + + >>> jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1) + jt.Var([[1 2] + [2 2]], dtype=int32) + ''' + # TODO: low performance when concat lots of vars + total_dim = 0 + if dim < 0: dim += len(arr[0].shape) + for a in arr: + total_dim += a.shape[dim] + cdim = 0 + s = None + indexes = [ f"i{i}" for i in range(len(a.shape)) ] + for a in arr: + shape = list(a.shape) + shape[dim] = total_dim + indexes[dim] = f"i{dim}-{cdim}" + b = a.reindex(shape, indexes) + # ugly fix for preventing large fused op + if len(arr)>=100: + b.stop_fuse() + if s is None: + s = b + else: + s += b + cdim += a.shape[dim] + return s + +def check(bc): + bc = np.array(bc) + if ((bc != 1) * (bc != bc.max(0))).sum() > 0: + raise Exception(f"Shape not match.") + else: + return bc.max(0) + +def slice_var_index(x, slices): + if not isinstance(slices, tuple): + slices = (slices,) + if isinstance(slices[0], jt.Var): + if len(slices) == 1 and slices[0].dtype == "bool": + return slice_var_index(x, tuple(slices[0].where())) + bc = [] + ml = -1 + for idx, s in enumerate(slices): + if isinstance(s, jt.Var): + shape = s.shape + elif isinstance(s, np.ndarray): + shape = list(s.shape) + elif isinstance(s, list): + shape = list(np.array(s).shape) + else: + continue + if len(shape) >= ml: + ml = len(shape) + bc.append(shape) + for idx, shape in enumerate(bc): + if len(shape) < ml: + shape = (ml - len(shape)) * [1] + shape + bc[idx] = shape + if len(bc) >= 1: + bc_shape = check(bc) + ss = [] + for idx, s in enumerate(slices): + if isinstance(s, np.ndarray) or isinstance(s, list): + ss.append(jt.array(s).broadcast(bc_shape.tolist())) + elif isinstance(s, jt.Var): + ss.append(s.broadcast(bc_shape.tolist())) + else: + ss.append(s) + slices = ss + out_shape = [] + out_index = [] + shape = x.shape + cnt_list = 0 + extras_idx = [] + extras = [] + has_ellipse = 0 + ellipse_index = 0 + for s,i in zip(slices,range(len(slices))): + if isinstance(s,type(...)): + has_ellipse+=1 + ellipse_index = i + if has_ellipse>1: + raise Exception(f"There are more than one ...") + elif has_ellipse==1: + slices = list(slices) + del slices[ellipse_index] + while len(slices)=len(slices): + s = slice(None) + else: + s = slices[i] + sp = shape[i] + j = len(out_shape) + if isinstance(s, int): + if s<0: s += sp + out_index.append(str(s)) + elif isinstance(s, slice): + if s == slice(None): + out_shape.append(sp) + out_index.append(f"i{j}") + continue + start = 0 if s.start is None else s.start + stop = sp if s.stop is None else s.stop + step = 1 if s.step is None else s.step + if start<0: start += sp + if stop<0: stop += sp + if stop>sp+1: stop = sp + out_shape.append(1+int(max(0, (stop-start-1)//step))) + out_index.append(f"{start}+i{j}*{step}") + elif isinstance(s, jt.Var): + if cnt_list == 0: + for idx in range(len(bc_shape)): + extras_idx.append(f"i{len(out_shape) + idx}") + out_shape += bc_shape.tolist() + out_index.append(f"@e{cnt_list}("+ ",".join(extras_idx) + ")") + cnt_list += 1 + extras.append(s) + else: + raise Exception(f"Not support slice {s}") + if len(out_shape)==0: + out_shape = [1] + # Stop fuse both input and output, prevent recompile + x.stop_fuse() + return (out_shape, out_index, 0, [], extras) + +def _slice_var_old(x, slices): + reindex_args = slice_var_index(x, slices) + x.stop_fuse() + return x.reindex(*reindex_args).stop_fuse() + +def _setitem_old(x, slices, value): + reindex_args = slice_var_index(x, slices) + reindex_reduce_args = (x.shape, reindex_args[1]) + reindex_args[3:] + xslice = x.stop_fuse().reindex(*reindex_args).stop_fuse() + value = jt.broadcast(value, xslice) + value = value.cast(x.dtype) + one = jt.broadcast(1, xslice) + if not isinstance(reindex_args[0][0], jt.Var): + reindex_args = (x.shape,) + reindex_args[1:] + mask = one.reindex_reduce("add", *reindex_reduce_args) + data = value.reindex_reduce("add", *reindex_reduce_args) + # Stop fuse both input and output, prevent recompile + out = mask.ternary(data, x).stop_fuse() + x.assign(out) + return x + +# PATCH +def getitem(x, slices): + if isinstance(slices, jt.Var) and slices.dtype == "bool": + return getitem(x, slices.where()) + if isinstance(slices, tuple): + ss = [] + for s in slices: + if isinstance(s, jt.Var) and s.dtype == "bool": + ss.extend(s.where()) + else: + ss.append(s) + slices = tuple(ss) + return x.getitem(slices) + +def setitem(x, slices, value): + if isinstance(slices, jt.Var) and slices.dtype == "bool": + if slices.shape == x.shape: + if isinstance(value, (int, float)): + value = jt.array(value).broadcast(x.shape) + return x.assign(slices.ternary(value, x)) + elif isinstance(value, jt.Var) and value.shape == [1,]: + value = jt.broadcast(value, x.shape) + return x.assign(slices.ternary(value, x)) + slices = slices.where() + elif isinstance(slices, tuple): + ss = [] + for s in slices: + if isinstance(s, jt.Var) and s.dtype == "bool": + ss.extend(s.where()) + else: + ss.append(s) + slices = tuple(ss) + return x.check_cascade_setitem(x.setitem(slices, value)) + +jt.Var.__getitem__ = jt.Var.slice_var = getitem +jt.Var.__setitem__ = setitem + + +def _merge_dtypes(dtypes): + dtype = dtypes[0] + for i in range(1, len(dtypes)): + dtype = jt.binary_dtype_infer("add", dtype, dtypes[i]) + return dtype + +@jt.flag_scope(amp_reg=4) # _custom_flag +def concat(arr, dim=0): + '''Concat Operator can concat a list of jt Var at a specfic dimension. + + * [in] x: input var list for concat + + * [in] dim: concat which dim + + * return: concat result + +Example:: + + jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1) + # return jt.Var([[1,2],[2,2]],dtype=int32) + ''' + if not isinstance(arr, Sequence): + raise TypeError("concat arr needs to be a tuple or list") + if len(arr) == 0: + raise ValueError("need at least one array to concat") + total_dim = 0 + base_dim = len(arr[0].shape) + if dim < 0: dim += base_dim + if dim < 0 or dim >= base_dim: + raise IndexError(f"Dimension out of range (expected to be in range of [{-base_dim}, {base_dim-1}], but got {dim})") + dtypes = [] + for a in arr: + if len(a.shape) != base_dim: + raise RuntimeError(f"get different number of dimensions of {base_dim} and {len(a.shape)}") + for i in range(base_dim): + if i != dim and a.shape[i] != arr[0].shape[i]: + raise RuntimeError(f"Sizes of vars must match except in dimension {dim}. Expected size {arr[0].shape[i]} but got size {a.shape[i]} for dimension number {i} in the list.") + total_dim += a.shape[dim] + dtypes.append(str(a.dtype)) + cdim = 0 + shape = list(a.shape) + shape[dim] = total_dim + s = jt.empty(shape, dtype = _merge_dtypes(dtypes)) + slices = [slice(None)]*len(a.shape) + for a in arr: + if a.shape[dim] == 0: + continue + slices[dim] = slice(cdim, cdim+a.shape[dim]) + # print(slices, type(a)) + s = s.setitem(tuple(slices), a) + # s = jt.setitem(s, tuple(slices), a) + cdim += a.shape[dim] + return s + +cat = concat diff --git a/python/jittor/dataset/__init__.py b/python/jittor/dataset/__init__.py new file mode 100644 index 00000000..4537691a --- /dev/null +++ b/python/jittor/dataset/__init__.py @@ -0,0 +1,6 @@ + +from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset, VarDataset, DataLoader +from .mnist import MNIST +from .cifar import CIFAR10, CIFAR100 +from .voc import VOC +from .sampler import * \ No newline at end of file diff --git a/python/jittor/dataset/cifar.py b/python/jittor/dataset/cifar.py new file mode 100644 index 00000000..31d6aadc --- /dev/null +++ b/python/jittor/dataset/cifar.py @@ -0,0 +1,189 @@ + +import os +from jittor_utils.misc import download_and_extract_archive, check_integrity +from PIL import Image +import sys, pickle +import numpy as np +from jittor.dataset import Dataset, dataset_root + +class CIFAR10(Dataset): + """`CIFAR10 `_ Dataset. + + Args: + root (string): Root directory of dataset where directory + ``cifar-10-batches-py`` exists or will be saved to if download is set to True. + train (bool, optional): If True, creates dataset from training set, otherwise + creates from test set. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + + Example:: + + + from jittor.dataset.cifar import CIFAR10 + a = CIFAR10() + a.set_attrs(batch_size=16) + for imgs, labels in a: + print(imgs.shape, labels.shape) + break + + """ + base_folder = 'cifar-10-batches-py' + url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + filename = "cifar-10-python.tar.gz" + tgz_md5 = 'c58f30108f718f92721af3b95e74349a' + train_list = [ + ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], + ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], + ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], + ['data_batch_4', '634d18415352ddfa80567beed471001a'], + ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], + ] + + test_list = [ + ['test_batch', '40351d587109b95175f43aff81a1287e'], + ] + meta = { + 'filename': 'batches.meta', + 'key': 'label_names', + 'md5': '5ff9c542aee3614f3951f8cda6e48888', + } + + def __init__(self, root=dataset_root+"/cifar_data/", train=True, transform=None, target_transform=None, + download=True): + + super(CIFAR10, self).__init__() + self.root = root + self.transform=transform + self.target_transform=target_transform + + self.train = train # training set or test set + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + if self.train: + downloaded_list = self.train_list + else: + downloaded_list = self.test_list + + self.data = [] + self.targets = [] + + # now load the picked numpy arrays + for file_name, checksum in downloaded_list: + file_path = os.path.join(self.root, self.base_folder, file_name) + with open(file_path, 'rb') as f: + if sys.version_info[0] == 2: + entry = pickle.load(f) + else: + entry = pickle.load(f, encoding='latin1') + self.data.append(entry['data']) + if 'labels' in entry: + self.targets.extend(entry['labels']) + else: + self.targets.extend(entry['fine_labels']) + + self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) + self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC + + self._load_meta() + + def _load_meta(self): + path = os.path.join(self.root, self.base_folder, self.meta['filename']) + if not check_integrity(path, self.meta['md5']): + raise RuntimeError('Dataset metadata file not found or corrupted.' + + ' You can use download=True to download it') + with open(path, 'rb') as infile: + if sys.version_info[0] == 2: + data = pickle.load(infile) + else: + data = pickle.load(infile, encoding='latin1') + self.classes = data[self.meta['key']] + self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is index of the target class. + """ + img, target = self.data[index], self.targets[index] + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img) + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return len(self.data) + + def _check_integrity(self): + root = self.root + for fentry in (self.train_list + self.test_list): + filename, md5 = fentry[0], fentry[1] + fpath = os.path.join(root, self.base_folder, filename) + if not check_integrity(fpath, md5): + return False + return True + + def download(self): + if self._check_integrity(): + print('Files already downloaded and verified') + return + download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) + + def extra_repr(self): + return "Split: {}".format("Train" if self.train is True else "Test") + + +class CIFAR100(CIFAR10): + """`CIFAR100 `_ Dataset. + + This is a subclass of the `CIFAR10` Dataset. + + + Example:: + + + from jittor.dataset.cifar import CIFAR100 + a = CIFAR100() + a.set_attrs(batch_size=16) + for imgs, labels in a: + print(imgs.shape, labels.shape) + break + """ + base_folder = 'cifar-100-python' + url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" + filename = "cifar-100-python.tar.gz" + tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' + train_list = [ + ['train', '16019d7e3df5f24257cddd939b257f8d'], + ] + + test_list = [ + ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], + ] + meta = { + 'filename': 'meta', + 'key': 'fine_label_names', + 'md5': '7973b15100ade9c7d40fb424638fde48', + } diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py new file mode 100644 index 00000000..ad79675c --- /dev/null +++ b/python/jittor/dataset/dataset.py @@ -0,0 +1,728 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Meng-Hao Guo +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import numpy as np +from urllib import request +import gzip +import pickle +import os +from jittor.dataset.utils import get_random_list, get_order_list, collate_batch, HookTimer +from collections.abc import Sequence, Mapping +import pathlib +from PIL import Image +import multiprocessing as mp +import signal +from jittor_utils import LOG +import jittor as jt +import time +import jittor_utils as jit_utils + +dataset_root = os.path.join(jit_utils.home(), ".cache", "jittor", "dataset") +mp_log_v = os.environ.get("mp_log_v", 0) +mpi = jt.mpi +img_open_hook = HookTimer(Image, "open") +CHECK_MEMORY = int(os.environ.get("CHECK_MEMORY", "0")) + +if os.name == "nt": + from multiprocessing import shared_memory + class RingBuffer: + def __init__(self, size, shm=None): + for i in range(100): + if (1<= size: break + size = 1<= 0 + assert self.batch_size > 0 + if self.drop_last: + return self.total_len // self.batch_size + return (self.total_len-1) // self.batch_size + 1 + + def __len__(self): + return self.__batch_len__() + + def set_attrs(self, **kw): + ''' + You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. + + Example:: + + dataset = YourDataset().set_attrs(batch_size=256, shuffle=True) + + Attrs: + + * batch_size(int): batch size, default 16. + * total_len(int): total lenght. + * shuffle(bool): shuffle at each epoch, default False. + * drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True. + * num_workers: number of workers for loading data + * buffer_size: buffer size for each worker in bytes, default(512MB). + * stop_grad: stop grad for data, default(True). + ''' + for k,v in kw.items(): + assert hasattr(self, k), k + setattr(self, k, v) + self.reset() + return self + + def to_jittor(self, batch): + ''' + Change batch data to jittor array, such as np.ndarray, int, and float. + ''' + if self.keep_numpy_array: return batch + if isinstance(batch, jt.Var): return batch + to_jt = lambda x: jt.array(x).stop_grad() \ + if self.stop_grad else jt.array(x) + if isinstance(batch, np.ndarray): + return to_jt(batch) + if isinstance(batch, dict): + new_batch = {} + for k,v in batch.items(): + new_batch[k] = self.to_jittor(v) + return new_batch + if not isinstance(batch, (list, tuple)): + return batch + new_batch = [] + for a in batch: + if isinstance(a, np.ndarray): + new_batch.append(to_jt(a)) + else: + new_batch.append(self.to_jittor(a)) + return new_batch + + def collate_batch(self, batch): + ''' + Puts each data field into a tensor with outer dimension batch size. + + Args:: + + [in] batch(list): A list of variables, such as jt.var, Image.Image, np.ndarray, int, float, str and so on. + + ''' + return collate_batch(batch) + + def terminate(self): + ''' + Terminate is used to terminate multi-process worker reading data. + ''' + if hasattr(self, "workers"): + for w in self.workers: + w.p.terminate() + + def _worker_main(self, worker_id, buffer, status): + import jittor_utils + jt.flags.use_cuda_host_allocator = 0 + + jittor_utils.cc.init_subprocess() + jt.jt_init_subprocess() + seed = jt.get_seed() + wseed = (seed ^ (worker_id*1167)) ^ 1234 + jt.set_global_seed(wseed) + # parallel_op_compiler still problematic, + # it is not work on ubuntu 16.04. but worked on ubuntu 20.04 + # it seems like the static value of parallel compiler + # is not correctly init. + jt.flags.use_parallel_op_compiler = 0 + import time + try: + gid_obj = self.gid.get_obj() + gid_lock = self.gid.get_lock() + start = time.time() + while True: + # get id + with gid_lock: + while buffer.is_stop() or self.idqueue.is_stop() or \ + gid_obj.value >= self.batch_len: + self.num_idle.value += 1 + self.num_idle_c.notify() + self.gidc.wait() + self.num_idle.value -= 1 + cid = gid_obj.value + batch_index_list = self.index_list_numpy[ + cid*self.real_batch_size: + min(self.real_len, (cid+1)*self.real_batch_size) + ].copy() + gid_obj.value += 1 + with self.idqueue_lock: + self.idqueue.push(worker_id) + now = time.time() + other_time = now - start + start = now + + # load and transform data + batch = [] + if mp_log_v: + print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)) + for i in batch_index_list: + batch.append(self[i]) + batch = self.collate_batch(batch) + now = time.time() + data_time = now - start + start = now + + # send data to main process + if mp_log_v: + print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer) + try: + buffer.send(batch) + except: + if buffer.is_stop(): + continue + raise + now = time.time() + send_time = now - start + start = now + status[0], status[1], status[2], status[3], status[4] = \ + other_time, data_time, send_time, \ + other_time + data_time + send_time, \ + img_open_hook.duration + img_open_hook.duration = 0.0 + except: + import traceback + line = traceback.format_exc() + print(line) + os.kill(os.getppid(), signal.SIGINT) + exit(0) + + def display_worker_status(self): + ''' Display dataset worker status, when dataset.num_workers > 0, it will display infomation blow: + +.. code-block:: console + + progress:479/5005 + batch(s): 0.302 wait(s):0.000 + recv(s): 0.069 to_jittor(s):0.021 + recv_raw_call: 6720.0 + last 10 workers: [6, 7, 3, 0, 2, 4, 7, 5, 6, 1] + ID wait(s) load(s) send(s) total + #0 0.000 1.340 2.026 3.366 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #1 0.000 1.451 3.607 5.058 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #2 0.000 1.278 1.235 2.513 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #3 0.000 1.426 1.927 3.353 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #4 0.000 1.452 1.074 2.526 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #5 0.000 1.422 3.204 4.625 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #6 0.000 1.445 1.953 3.398 Buffer(free=0.000% l=462425368 r=462425368 size=536870912) + #7 0.000 1.582 0.507 2.090 Buffer(free=0.000% l=308283552 r=308283552 size=536870912) + +Meaning of the outputs: + +* progress: dataset loading progress (current/total) +* batch: batch time, exclude data loading time +* wait: time of main proc wait worker proc +* recv: time of recv batch data +* to_jittor: time of batch data to jittor variable +* recv_raw_call: total number of underlying recv_raw called +* last 10 workers: id of last 10 workers which main proc load from. +* table meaning + * ID: worker id + * wait: worker wait time + * open: worker image open time + * load: worker load time + * buffer: ring buffer status, such as how many free space, left index, right index, total size(bytes). + +Example:: + + from jittor.dataset import Dataset + class YourDataset(Dataset): + pass + dataset = YourDataset().set_attrs(num_workers=8) + for x, y in dataset: + dataset.display_worker_status() + ''' + if not hasattr(self, "workers"): + return + msg = [""] + msg.append(f"progress:{self.batch_id}/{self.batch_len}") + msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}") + msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}") + msg.append(f"last 10 workers: {self.last_ids}") + msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)") + for i in range(self.num_workers): + w = self.workers[i] + s = w.status + msg.append(f"#{i}\t{s[0]:.3f}\t{s[4]:.3f}\t{s[1]:.3f}\t{s[2]:.3f}\t{s[3]:.3f}\t{w.buffer}") + LOG.i('\n'.join(msg)) + + def _stop_all_workers(self): + # stop workers + for w in self.workers: + w.buffer.stop() + self.idqueue.stop() + # wait until all workers idle + if self.num_idle.value < self.num_workers: + with self.gid.get_lock(): + self.gid.get_obj().value = self.batch_len + if mp_log_v: + print("idle num", self.num_idle.value) + while self.num_idle.value < self.num_workers: + self.num_idle_c.wait() + if mp_log_v: + print("idle num", self.num_idle.value) + # clean workers' buffer + for w in self.workers: + w.buffer.clear() + self.idqueue.clear() + self.gid.value = 0 + + def _init_workers(self, index_list): + jt.migrate_all_to_cpu() + jt.clean() + jt.gc() + self.index_list = mp.Array('i', self.real_len, lock=False) + workers = [] + # get worker id + self.idqueue = jt.RingBuffer(2048) + self.idqueue_lock = mp.Lock() + # global token index + self.gid = mp.Value('i', self.batch_len) + self.gid.value = 0 + # global token index condition + self.gidc = mp.Condition(self.gid.get_lock()) + # number of idle workers + self.num_idle = mp.Value('i', 0, lock=False) + # number of idle workers condition + self.num_idle_c = mp.Condition(self.gid.get_lock()) + self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list) + self.index_list_numpy[:] = index_list + for i in range(self.num_workers): + w = Worker(target=self._worker_main, args=(i,), + buffer_size=self.buffer_size, + keep_numpy_array=self.keep_numpy_array) + workers.append(w) + self.workers = workers + + def reset(self): + if not hasattr(self, "workers"): + return + self._stop_all_workers() + self.terminate() + del self.index_list + del self.idqueue + del self.idqueue_lock + del self.gid + del self.gidc + del self.num_idle + del self.num_idle_c + del self.workers + del self.index_list_numpy + + def __del__(self): + if mp_log_v: + print("dataset deleted") + try: + self.terminate() + except: + pass + + def __deepcopy__(self, memo=None, _nil=[]): + from copy import deepcopy + if memo is None: + memo = {} + d = id(self) + y = memo.get(d, _nil) + if y is not _nil: + return y + + obj = self.__class__.__new__(self.__class__) + memo[d] = id(obj) + exclude_key = {"index_list", "idqueue", "idqueue_lock", "gid", "gidc", "num_idle", "num_idle_c", "workers", "index_list_numpy", "dataset", "idqueue", "idqueue_lock"} + for k,v in self.__dict__.items(): + if k in exclude_key: continue + obj.__setattr__(k, deepcopy(v)) + obj.dataset = obj + return obj + + def __real_len__(self): + if self.total_len is None: + self.total_len = len(self) + return self.total_len + + def _get_index_list(self): + if self.total_len is None: + self.total_len = len(self) + # maybe rewrite by sampler + total_len = self.total_len + if self.sampler: + index_list = list(self.sampler.__iter__()) + total_len = len(index_list) + # check is not batch sampler + if len(index_list): + assert not isinstance(index_list[0], (list,tuple)), "Batch sampler not support yet." + elif self.shuffle == False: + index_list = get_order_list(self.total_len) + else: + # using _shuffle_rng to generate multiprocess + # consist shuffle list + # index_list = get_random_list(self.total_len) + index_list = self._shuffle_rng.permutation(range(self.total_len)) + + # scatter index_list for all mpi process + # scatter rule: + # batch 1 batch 2 + # [........] [........] ... + # 00011122 00011122 + # if last batch is smaller than world_size + # pad to world_size + # last batch + # [.] -> [012] + if jt.in_mpi: + world_size = mpi.world_size() + world_rank = mpi.world_rank() + index_list = np.int32(index_list) + # TODO: mpi broadcast in subprocess has bug, fix it + # mpi.broadcast(index_list, 0) + + assert self.batch_size >= world_size, \ + f"Batch size({self.batch_size}) is smaller than MPI world_size({world_size})" + real_batch_size = (self.batch_size-1) // world_size + 1 + if real_batch_size * world_size != self.batch_size: + LOG.w("Batch size is not divisible by MPI world size, " + "The distributed version may be different from " + "the single-process version.") + fix_batch = total_len // self.batch_size + last_batch = total_len - fix_batch * self.batch_size + fix_batch_l = index_list[0:fix_batch*self.batch_size] \ + .reshape(-1,self.batch_size) + fix_batch_l = fix_batch_l[ + :,real_batch_size*world_rank:real_batch_size*(world_rank+1)] + real_batch_size = fix_batch_l.shape[1] + fix_batch_l = fix_batch_l.flatten() + if not self.drop_last and last_batch > 0: + last_batch_l = index_list[-last_batch:] + real_last_batch = (last_batch-1)//world_size+1 + l = real_last_batch * world_rank + r = l + real_last_batch + if r > last_batch: + r = last_batch + l = r-real_last_batch + index_list = np.concatenate([fix_batch_l, last_batch_l[l:r]]) + else: + index_list = fix_batch_l + + self.real_len = len(index_list) + self.real_batch_size = real_batch_size + # assert total_len // self.batch_size == \ + # self.real_len // self.real_batch_size, f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}" + + # print(f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}") + # print("mpi dataset init ") + else: + self.real_len = len(index_list) + self.real_batch_size = self.batch_size + + if self.drop_last: + self.batch_len = self.real_len // self.real_batch_size + else: + self.batch_len = (self.real_len-1) // self.real_batch_size + 1 + + return index_list + + def _epochs(self): + if self.endless: + while True: + yield + self.epoch_id += 1 + else: + yield + + def __iter__(self): + if self._disable_workers: + self.num_workers = 0 + index_list = self._get_index_list() + + if not hasattr(self, "workers") and self.num_workers: + self._init_workers(index_list) + self.last_ids = [-1] * 10 + + if self.num_workers: + start = time.time() + self.batch_time = 0 + gid_obj = self.gid.get_obj() + gid_lock = self.gid.get_lock() + + for _ in self._epochs(): + with gid_lock: + if self.num_idle.value: + self.gidc.notify_all() + + for i in range(self.batch_len): + if self.num_idle.value: + with gid_lock: + if self.num_idle.value and \ + gid_obj.value >= self.batch_len: + index_list = self._get_index_list() + self.index_list_numpy[:] = index_list + gid_obj.value = 0 + self.gidc.notify_all() + + # get which worker has this batch + worker_id = self.idqueue.pop() + + now = time.time() + self.wait_time = now - start + start = now + + self.last_ids[i%10] = worker_id + self.batch_id = i + w = self.workers[worker_id] + if mp_log_v: + print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer) + batch = w.buffer.recv() + + now = time.time() + self.recv_time = now - start + start = now + + if mp_log_v: + print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ]) + batch = self.to_jittor(batch) + + now = time.time() + self.to_jittor_time = now - start + start = now + + yield batch + + now = time.time() + self.batch_time = now - start + start = now + + if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0: + jt.display_memory_info() + else: + for _ in self._epochs(): + self.batch_id = 0 + batch_data = [] + for idx in index_list: + batch_data.append(self[int(idx)]) + if len(batch_data) == self.real_batch_size: + batch_data = self.collate_batch(batch_data) + tmp = batch_data + batch_data = self.to_jittor(batch_data) + # breakpoint() + yield batch_data + self.batch_id += 1 + if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0: + jt.display_memory_info() + batch_data = [] + + # depend on drop_last + if not self.drop_last and len(batch_data) > 0: + batch_data = self.collate_batch(batch_data) + batch_data = self.to_jittor(batch_data) + self.batch_id += 1 + yield batch_data + +def DataLoader(dataset: Dataset, *args, **kargs): + """ Simple dataloader. + + Example:: + + train_dir = './data/celebA_train' + train_dataset = ImageFolder(train_dir) + dataloader = jt.dataset.DataLoader(train_dataset, batch_size=8) + + """ + return dataset.set_attrs(*args, **kargs) + +class ImageFolder(Dataset): + """ + A image classify dataset, load image and label from directory:: + + * root/label1/img1.png + * root/label1/img2.png + * ... + * root/label2/img1.png + * root/label2/img2.png + * ... + + Args:: + + [in] root(string): Root directory path. + + Attributes:: + + * classes(list): List of the class names. + * class_to_idx(dict): map from class_name to class_index. + * imgs(list): List of (image_path, class_index) tuples + + Example:: + + train_dir = './data/celebA_train' + train_loader = ImageFolder(train_dir).set_attrs(batch_size=batch_size, shuffle=True) + for batch_idx, (x_, target) in enumerate(train_loader): + ... + + """ + def __init__(self, root, transform=None): + super().__init__() + self.root = root + self.transform = transform + self.classes = sorted([d.name for d in os.scandir(root) if d.is_dir()]) + self.class_to_idx = {v:k for k,v in enumerate(self.classes)} + self.imgs = [] + image_exts = set(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')) + + for i, class_name in enumerate(self.classes): + class_dir = os.path.join(root, class_name) + for dname, _, fnames in sorted(os.walk(class_dir, followlinks=True)): + for fname in sorted(fnames): + if os.path.splitext(fname)[-1].lower() in image_exts: + path = os.path.join(class_dir, fname) + self.imgs.append((path, i)) + LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.") + self.set_attrs(total_len=len(self.imgs)) + + def __getitem__(self, k): + with open(self.imgs[k][0], 'rb') as f: + img = Image.open(f).convert('RGB') + if self.transform: + img = self.transform(img) + return img, self.imgs[k][1] + +class VarDataset(Dataset): + """ Dataset using Var directly, TensorDataset is alias of VarDataset, Example:: + + import jittor as jt + from jittor.dataset import VarDataset + + x = jt.array([1,2,3]) + y = jt.array([4,5,6]) + z = jt.array([7,8,9]) + dataset = VarDataset(x, y, z) + dataset.set_attrs(batch_size=1) + + for a,b,c in dataset: + print(a,b,c) + # will print + # 1,4,7 + # 2,5,8 + # 3,6,9 + + """ + def __init__(self, *args): + super().__init__() + self.args = args + self._disable_workers = True + assert len(args), "At lease one args" + l = len(args[0]) + for a in args: + assert l == len(a), "Len should be the same" + self.set_attrs(total_len=l) + + def __getitem__(self, idx): + return [ a[idx] for a in self.args ] + + + def collate_batch(self, batch): + b = collate_batch(batch) + for i in range(len(self.args)): + x = b[i] + if jt.is_var(self.args[i]) and self.args[i].ndim == 1: + x.assign(x.squeeze(-1)) + return b + +TensorDataset = VarDataset \ No newline at end of file diff --git a/python/jittor/dataset/mnist.py b/python/jittor/dataset/mnist.py new file mode 100644 index 00000000..f0945f94 --- /dev/null +++ b/python/jittor/dataset/mnist.py @@ -0,0 +1,200 @@ +# *************************************************************** +# Copyright(c) 2019 +# Meng-Hao Guo +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import os +import string +import numpy as np +import gzip +from PIL import Image +# our lib jittor import +from jittor.dataset.dataset import Dataset, dataset_root +from jittor_utils.misc import ensure_dir, download_url_to_local +import jittor as jt +import jittor.transform as trans + +class MNIST(Dataset): + ''' + Jittor's own class for loading MNIST dataset. + + Args:: + + [in] data_root(str): your data root. + [in] train(bool): choose model train or val. + [in] download(bool): Download data automatically if download is True. + [in] batch_size(int): Data batch size. + [in] shuffle(bool): Shuffle data if true. + [in] transform(jittor.transform): transform data. + + Example:: + + from jittor.dataset.mnist import MNIST + train_loader = MNIST(train=True).set_attrs(batch_size=16, shuffle=True) + for i, (imgs, target) in enumerate(train_loader): + ... + ''' + def __init__(self, data_root=dataset_root+"/mnist_data/", + train=True, + download=True, + batch_size = 16, + shuffle = False, + transform=None): + # if you want to test resnet etc you should set input_channel = 3, because the net set 3 as the input dimensions + super().__init__() + self.data_root = data_root + self.is_train = train + self.transform = transform + self.batch_size = batch_size + self.shuffle = shuffle + if download == True: + self.download_url() + + filesname = [ + "train-images-idx3-ubyte.gz", + "t10k-images-idx3-ubyte.gz", + "train-labels-idx1-ubyte.gz", + "t10k-labels-idx1-ubyte.gz" + ] + self.mnist = {} + if self.is_train: + with gzip.open(data_root + filesname[0], 'rb') as f: + self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28) + with gzip.open(data_root + filesname[2], 'rb') as f: + self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8) + else: + with gzip.open(data_root + filesname[1], 'rb') as f: + self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28) + with gzip.open(data_root + filesname[3], 'rb') as f: + self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8) + assert(self.mnist["images"].shape[0] == self.mnist["labels"].shape[0]) + self.total_len = self.mnist["images"].shape[0] + # this function must be called + self.set_attrs(total_len = self.total_len) + + def __getitem__(self, index): + img = Image.fromarray(self.mnist['images'][index]).convert('RGB') + if self.transform: + img = self.transform(img) + return trans.to_tensor(img), self.mnist['labels'][index] + + def download_url(self): + ''' + Download mnist data set function, this function will be called when download is True. + ''' + resources = [ + ("https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + ("https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), + ("https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), + ("https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c") + ] + + for url, md5 in resources: + filename = url.rpartition('/')[2] + download_url_to_local(url, filename, self.data_root, md5) + +class EMNIST(Dataset): + ''' + Jittor's own class for loading EMNIST dataset. + + Args:: + + [in] data_root(str): your data root. + [in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'. + [in] train(bool): choose model train or val. + [in] download(bool): Download data automatically if download is True. + [in] batch_size(int): Data batch size. + [in] shuffle(bool): Shuffle data if true. + [in] transform(jittor.transform): transform data. + + Example:: + + from jittor.dataset.mnist import EMNIST + train_loader = EMNIST(train=True).set_attrs(batch_size=16, shuffle=True) + for i, (imgs, target) in enumerate(train_loader): + ... + ''' + + _merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'} + _all_classes = set(string.digits + string.ascii_letters) + classes_split_dict = { + 'byclass': sorted(list(_all_classes)), + 'bymerge': sorted(list(_all_classes - _merged_classes)), + 'balanced': sorted(list(_all_classes - _merged_classes)), + 'letters': ['N/A'] + list(string.ascii_lowercase), + 'digits': list(string.digits), + 'mnist': list(string.digits), + } + + def __init__(self, data_root=dataset_root+"/emnist_data/", + split='byclass', + train=True, + download=True, + batch_size = 16, + shuffle = False, + transform=None): + # if you want to test resnet etc you should set input_channel = 3, because the net set 3 as the input dimensions + super().__init__() + self.data_root = data_root + self.is_train = train + self.transform = transform + self.batch_size = batch_size + self.shuffle = shuffle + if download == True: + self.download_url() + data_root = os.path.join(data_root, "gzip") + + filesname = [ + f"emnist-{split}-train-images-idx3-ubyte.gz", + f"emnist-{split}-t10k-images-idx3-ubyte.gz", + f"emnist-{split}-train-labels-idx1-ubyte.gz", + f"emnist-{split}-t10k-labels-idx1-ubyte.gz" + ] + for i in range(4): + filesname[i] = os.path.join(data_root, filesname[i]) + self.mnist = {} + if self.is_train: + with gzip.open(filesname[0], 'rb') as f: + self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28).transpose(0,2,1) + with gzip.open(filesname[2], 'rb') as f: + self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8) + else: + with gzip.open(filesname[1], 'rb') as f: + self.mnist["images"] = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1,28, 28).transpose(0,2,1) + with gzip.open(filesname[3], 'rb') as f: + self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8) + assert(self.mnist["images"].shape[0] == self.mnist["labels"].shape[0]) + self.total_len = self.mnist["images"].shape[0] + # this function must be called + self.set_attrs(total_len = self.total_len) + + def __getitem__(self, index): + img = Image.fromarray(self.mnist['images'][index]).convert('RGB') + if self.transform: + img = self.transform(img) + return trans.to_tensor(img), self.mnist['labels'][index] + + def download_url(self): + ''' + Download mnist data set function, this function will be called when download is True. + ''' + resources = [ + ("https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip", "58c8d27c78d21e728a6bc7b3cc06412e"), + ] + + for url, md5 in resources: + filename = "emnist.zip" + download_url_to_local(url, filename, self.data_root, md5) + import zipfile + zf = zipfile.ZipFile(os.path.join(self.data_root, filename)) + try: + zf.extractall(path=self.data_root) + except RuntimeError as e: + print(e) + raise + zf.close() + diff --git a/python/jittor/dataset/sampler.py b/python/jittor/dataset/sampler.py new file mode 100644 index 00000000..a6cf1269 --- /dev/null +++ b/python/jittor/dataset/sampler.py @@ -0,0 +1,126 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Hao-Yang Peng +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from .dataset import Dataset +import numpy as np +from PIL import Image + + +class Sampler(): + def __init__(self, dataset): + self.dataset = dataset + # MUST set sampler here + dataset.sampler = self + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +class SequentialSampler(Sampler): + def __init__(self, dataset): + # MUST set sampler here + dataset.sampler = self + self.dataset = dataset + + def __iter__(self): + return iter(range(self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__())) + + def __len__(self): + return self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__() + + +class RandomSampler(Sampler): + def __init__(self, dataset, replacement=False, num_samples=None): + # MUST set sampler here + dataset.sampler = self + self.dataset = dataset + self.rep = replacement + self._num_samples = num_samples + self._shuffle_rng = np.random.default_rng(1) + + @property + def num_samples(self): + if self._num_samples is None: + return self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__() + return self._num_samples + + def __len__(self): + return self.num_samples + + def __iter__(self): + n = self.dataset.__real_len__() if hasattr(self.dataset,"__real_len__") else self.dataset.__len__() + if self.rep: + return iter(self._shuffle_rng.integers(low=0, high=n, size=(self.num_samples,), dtype=np.int64).tolist()) + return iter(self._shuffle_rng.permutation(n).tolist()) + + +class SkipFirstBatchesSampler(Sampler): + def __init__(self, sampler, num_skip_batches): + # MUST set sampler here + sampler.dataset.sampler = self + self.sampler = sampler + self.num_skip_batches = num_skip_batches + + def __len__(self): + return len(self.sampler) - self.num_skip_batches + + def __iter__(self): + return iter(list(iter(self.sampler))[self.num_skip_batches:]) + + +class SubsetRandomSampler(Sampler): + def __init__(self, dataset, indice): + ''' + testdataset = TestSamplerDataset() + subsetsampler = SubsetRandomSampler(testdataset, (20, 30)) + + for i, data in enumerate(testdataset): + # data between 20 ~ 29 + ...... + + ''' + # MUST set sampler here + dataset.sampler = self + self.dataset = dataset + self.indices = indice + dlen = dataset.__real_len__() if hasattr(dataset,"__real_len__") else dataset.__len__() + assert indice[0] >= 0 and indice[1] < dlen and indice[0] < indice[1] + + def __iter__(self): + return (int(i) + self.indices[0] for i in np.random.permutation(self.indices[1] - self.indices[0])) + + def __len__(self): + return self.indices[1] - self.indices[0] + + +class BatchSampler(Sampler): + def __init__(self, sampler, batch_size, drop_last): + self.sampler = sampler + self.batch_size = batch_size + self.drop_last = drop_last + + def __iter__(self): + batch = [] + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + if len(batch) > 0 and not self.drop_last: + yield batch + + def __len__(self): + if self.drop_last: + return len(self.sampler) // self.batch_size + else: + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/python/jittor/dataset/utils.py b/python/jittor/dataset/utils.py new file mode 100644 index 00000000..730bc277 --- /dev/null +++ b/python/jittor/dataset/utils.py @@ -0,0 +1,68 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Meng-Hao Guo +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import jittor as jt +import numpy as np +from collections.abc import Sequence, Mapping +from PIL import Image +import time + +def get_random_list(n): + return list(np.random.permutation(range(n))) + +def get_order_list(n): + return [i for i in range(n)] + + +def collate_batch(batch): + r"""Puts each data field into a tensor with outer dimension batch size""" + real_size = len(batch) + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, jt.Var): + temp_data = jt.stack([data for data in batch], 0) + return temp_data + if elem_type is np.ndarray: + temp_data = np.stack([data for data in batch], 0) + return temp_data + elif np.issubdtype(elem_type, np.integer): + return np.int32(batch) + elif isinstance(elem, int): + return np.int32(batch) + elif isinstance(elem, float): + return np.float32(batch) + elif isinstance(elem, str): + return batch + elif isinstance(elem, Mapping): + return {key: collate_batch([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple): + transposed = zip(*batch) + return tuple(collate_batch(samples) for samples in transposed) + elif isinstance(elem, Sequence): + transposed = zip(*batch) + return [collate_batch(samples) for samples in transposed] + elif isinstance(elem, Image.Image): + temp_data = np.stack([np.array(data) for data in batch], 0) + return temp_data + else: + raise TypeError(f"Not support type <{elem_type.__name__}>") + +class HookTimer: + def __init__(self, obj, attr): + self.origin = getattr(obj, attr) + self.duration = 0.0 + setattr(obj, attr, self) + + def __call__(self, *args, **kw): + start = time.time() + rt = self.origin(*args, **kw) + self.duration += time.time() - start + return rt + diff --git a/python/jittor/dataset/voc.py b/python/jittor/dataset/voc.py new file mode 100644 index 00000000..5d05bb14 --- /dev/null +++ b/python/jittor/dataset/voc.py @@ -0,0 +1,70 @@ +# *************************************************************** +# Copyright(c) 2019 +# Meng-Hao Guo +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import numpy as np +import os +from PIL import Image +from .dataset import Dataset, dataset_root + +class VOC(Dataset): + ''' + Jittor's own class for loading VOC dataset. + + Args:: + + [in] data_root(str): your data root. + [in] split(str): which split you want to use, train or val. + + Attribute:: + + NUM_CLASSES: Number of total categories, default is 21. + + Example:: + + from jittor.dataset.voc import VOC + train_loader = VOC(data_root='...').set_attrs(batch_size=16, shuffle=True) + for i, (imgs, target) in enumerate(train_loader): + ... + ''' + NUM_CLASSES = 21 + def __init__(self, data_root=dataset_root+'/voc/', split='train'): + super().__init__() + ''' total_len , batch_size, shuffle must be set ''' + self.data_root = data_root + self.split = split + + self.image_root = os.path.join(data_root, 'JPEGImages') + self.label_root = os.path.join(data_root, 'SegmentationClass') + + self.data_list_path = os.path.join(self.data_root, 'ImageSets', 'Segmentation', self.split + '.txt') + self.image_path = [] + self.label_path = [] + + with open(self.data_list_path, "r") as f: + lines = f.read().splitlines() + + for idx, line in enumerate(lines): + _img_path = os.path.join(self.image_root, line + '.jpg') + _label_path = os.path.join(self.label_root, line + '.png') + assert os.path.isfile(_img_path) + assert os.path.isfile(_label_path) + self.image_path.append(_img_path) + self.label_path.append(_label_path) + self.set_attrs(total_len = len(self.image_path)) + + def __getitem__(self, index): + _img = Image.open(self.image_path[index]) + _label = Image.open(self.label_path[index]) + _img = _img.resize((513, 513)) + _label = _label.resize((513, 513)) + _img = np.array(_img) + _label = np.array(_label) + _img = _img.transpose(2,0,1) + return _img, _label + diff --git a/python/jittor/demo/simple_cgan.py b/python/jittor/demo/simple_cgan.py new file mode 100644 index 00000000..971e2978 --- /dev/null +++ b/python/jittor/demo/simple_cgan.py @@ -0,0 +1,107 @@ +import jittor as jt +from jittor import nn +import numpy as np +# import pylab as pl + +# 隐空间向量长度 +latent_dim = 100 +# 类别数量 +n_classes = 10 +# 图片大小 +img_size = 32 +# 图片通道数量 +channels = 1 +# 图片张量的形状 +img_shape = (channels, img_size, img_size) + +class Generator(nn.Module): + def __init__(self): + super(Generator, self).__init__() + self.label_emb = nn.Embedding(n_classes, n_classes) + + def block(in_feat, out_feat, normalize=True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2)) + return layers + self.model = nn.Sequential( + *block((latent_dim + n_classes), 128, normalize=False), + *block(128, 256), + *block(256, 512), + *block(512, 1024), + nn.Linear(1024, int(np.prod(img_shape))), + nn.Tanh()) + + def execute(self, noise, labels): + gen_input = jt.concat((self.label_emb(labels), noise), dim=1) + img = self.model(gen_input) + img = img.view((img.shape[0], *img_shape)) + return img + +class Discriminator(nn.Module): + def __init__(self): + super(Discriminator, self).__init__() + self.label_embedding = nn.Embedding(n_classes, n_classes) + self.model = nn.Sequential( + nn.Linear((n_classes + int(np.prod(img_shape))), 512), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + nn.Dropout(0.4), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + nn.Dropout(0.4), + nn.LeakyReLU(0.2), + nn.Linear(512, 1)) + + def execute(self, img, labels): + d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1) + validity = self.model(d_in) + return validity + + +# 定义模型 +generator = Generator() +discriminator = Discriminator() +generator.eval() +discriminator.eval() + +# 加载参数 +generator.load('https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl') +discriminator.load('https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl') + + + +def gen_img(number): + print(number, type(number)) + n_row = len(number) + z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad() + labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad() + gen_imgs = generator(z,labels) + gen_imgs = gen_imgs.transpose((1,2,0,3)).reshape(gen_imgs.shape[2], -1) + gen_imgs = gen_imgs[:,:,None].broadcast(gen_imgs.shape+(3,)) # .uint8() + gen_imgs = (gen_imgs - gen_imgs.min()) / (gen_imgs.max() - gen_imgs.min()) * 255 + gen_imgs = gen_imgs.uint8() + # print(gen_imgs.shape, gen_imgs.max(), gen_imgs.min()) + return gen_imgs.numpy() + # gen_imgs = gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1)) + # print(gen_imgs.shape) + return gen_imgs[:,:,None] + +from PIL import Image +import pywebio as pw +# 定义一串数字 +number = "201962517" +# gen_img(number) +Image.fromarray(gen_img(number)) +# pl.imshow() +# pl.show() +# print("done") + + +def web_server(): + pw.pin.put_input("number", label="输入用于生成的数字(由计图框架支持):") + pw.output.put_buttons(['Gen image'], + lambda _: pw.output.put_image(Image.fromarray(gen_img(pw.pin.pin.number)))) + +pw.start_server(web_server, port=8123) \ No newline at end of file diff --git a/python/jittor/depthwise_conv.py b/python/jittor/depthwise_conv.py new file mode 100644 index 00000000..ed9ded14 --- /dev/null +++ b/python/jittor/depthwise_conv.py @@ -0,0 +1,325 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import init +from jittor import nn +from jittor import Function + +class DepthwiseConv(Function): + def __init__(self, stride=1, padding=0, dilation=1): + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, padding) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) + + def execute(self, x, weight): + if not jt.flags.use_cuda or not jt.compiler.is_cuda: + return nn.conv2d(x, weight, None, self.stride, self.padding, self.dilation, x.shape[1]) + self.save_vars = x, weight + N,C,H,W = x.shape + o,i,Kh,Kw = weight.shape + assert(o == C) + oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 + ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 + filter_height, filter_width = Kh, Kw + self.Khw = Kh, Kw + assert oh>0 and ow>0 + output = jt.code( + [N, C, oh, ow], + x.dtype, + [x, weight], + cuda_header = """ + template + __global__ void KernelDepthwiseConv( + const T *const input_data, const T *const filter_data, const int batch_size, + const int output_channels, const int output_height, + const int output_width, const int input_channels, + const int input_height, const int input_width, + const int padding_height, const int padding_width, + const int dilate_height, const int dilate_width, T *const output_data) { + const int kWeghtSize = filter_height * filter_width; + T r_weight[kWeghtSize]; + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + const T* weight = filter_data + c_out * filter_height * filter_width; + for (int i = 0; i < filter_height * filter_width; i++) r_weight[i] = weight[i]; + + for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) { + for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) { + const int batch = blockIdx.y; + const int c_out = blockIdx.x; + + const int c_in = c_out; + T value = 0; + const int h_in_start = -padding_height + h_out * stride_height; + const int w_in_start = -padding_width + w_out * stride_width; + const int h_in_end = h_in_start + filter_height * dilate_height; + const int w_in_end = w_in_start + filter_width * dilate_width; + + const int in_offset = + ((batch * input_channels + c_in) * input_height) * input_width; + + const int h_end = h_in_end < input_height ? h_in_end : input_height; + const int w_end = w_in_end < input_width ? w_in_end : input_width; + const int h_start = h_in_start > 0 ? h_in_start : 0; + const int w_start = w_in_start > 0 ? w_in_start : 0; + + for (int h_in = h_in_start, h_f = 0; h_f < filter_height; + h_in += dilate_height, h_f++) { + for (int w_in = w_in_start, w_f = 0; w_f < filter_width; + w_in += dilate_width, w_f++) { + if (h_in >= 0 && h_in < input_height && w_in >= 0 && + w_in < input_width) { + const int offset = in_offset + h_in * input_width + w_in; + value += r_weight[h_f * filter_width + w_f] * input_data[offset]; + } + } + } + int index = + ((batch * gridDim.x + c_out) * output_height + h_out) * output_width + + w_out; + output_data[index] = value; + } + } + } + """, + cuda_src=f""" + @alias(input, in0) + @alias(filter, in1) + @alias(output, out) + + const int batch_size = input_shape0; + const int input_channels = input_shape1; + const int input_height = input_shape2; + const int input_width = input_shape3; + const int output_channels = output_shape1; + const int output_height = output_shape2; + const int output_width = output_shape3; + const int ksize_height = {Kh}; + const int ksize_width = {Kw}; + const int stride_height = {self.stride[0]}; + const int stride_width = {self.stride[1]}; + const int padding_height = {self.padding[0]}; + const int padding_width = {self.padding[1]}; + const int dilate_height = {self.dilation[0]}; + const int dilate_width = {self.dilation[1]}; + + int thread = 512; + if (output_width > 1024 && output_width <= 2048) + thread = (output_width - 1) / 2 + 1; + else if (output_width > 512 && output_width <= 1024) + thread = output_width; + int blocks = std::min(std::max(thread / output_width, 1), output_height); + dim3 threads(std::min(output_width, thread), blocks, 1); + dim3 grid(output_channels, batch_size, 1); + KernelDepthwiseConv< + input_type, ksize_height, ksize_width, + stride_height, stride_width> + <<>>( + input_p, filter_p, batch_size, output_channels, output_height, + output_width, input_channels, input_height, input_width, + padding_height, padding_width, dilate_height, + dilate_width, output_p); + """ + ) + return output + + def grad(self, grad): + x, weight = self.save_vars + Kh, Kw = self.Khw + return jt.code([x.shape, weight.shape], [x.dtype, weight.dtype], [x, weight, grad], + cuda_header = f"#include <{jt.compile_extern.cub_home}cub/cub.cuh>"+""" + template + __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { + typedef cub::WarpReduce WarpReduce; + typename WarpReduce::TempStorage temp_storage; + value = WarpReduce(temp_storage).Sum(value); + if (cub::LaneId() == 0) + atomicAdd(sum, value); + } + + // CUDA kernel to compute the depthwise convolution backprop w.r.t input. + template + __global__ void KernelDepthwiseConvInputGradCFilter( + const T *const input_data, const T *const output_grad_data, + const T *const filter_data, const int batch_size, + const int output_channels, const int output_height, + const int output_width, const int input_channels, + const int input_height, const int input_width, + const int padding_height, const int padding_width, + const int dilate_height, const int dilate_width, + T *const input_grad_data) { + const int kWeghtSize = filter_height * filter_width + 1; + T r_weight[kWeghtSize]; + const int batch = blockIdx.y; + const int c_in = blockIdx.x; + + const T* weight = filter_data + c_in * filter_height * filter_width; + for (int i = 0; i < filter_height * filter_width; i++) + r_weight[i] = + weight[filter_height * filter_width - i - 1]; + + for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { + for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { + const int batch = blockIdx.y; + const int c_in = blockIdx.x; + + int h_out_start = h_in - (filter_height - 1) * dilate_height + padding_height; + + int w_out_start = w_in - (filter_width - 1) * dilate_width + padding_width; + + T value = 0; + int index = + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + + w_in; + + for (int h_out = h_out_start, h_f = 0; h_f < filter_height; + h_out += dilate_height, h_f++) { + for (int w_out = w_out_start, w_f = 0; w_f < filter_width; + w_out += dilate_width, w_f++) { + int s_h_out = h_out / stride_height; + int s_w_out = w_out / stride_width; + if (h_out % stride_height == 0 && w_out % stride_width == 0 && + s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && + s_w_out < output_width) { + const int output_grad_offset = + ((batch * output_channels + c_in) * output_height + + s_h_out) * + output_width + + s_w_out; + value += + output_grad_data[output_grad_offset] * + r_weight[h_f * filter_width + w_f]; + } + } + } + input_grad_data[index] = value; + } + } + } + + // Cuda kernel to compute the depthwise convolution backprop w.r.t. filter. + template + __global__ void KernelDepthwiseConvFilterGrad( + const T* output_grad_data, const T* input_data, const int num, + const int output_channels, const int output_height, const int output_width, + const int input_channels, const int input_height, const int input_width, + const int filter_height, + const int filter_width, const int stride_height, const int stride_width, + const int padding_height, const int padding_width, const int dilate_height, + const int dilate_width, T* filter_grad_data) { + T s = 0; + + int gbid = (((blockIdx.z * blockDim.z + threadIdx.z) * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; + + for (int image_w = threadIdx.x; image_w < output_width; + image_w += blockDim.x) { + for (int bid = 0; bid < num; bid++) { + //for (int bid = threadIdx.z; bid < num; bid+=blockDim.z) { + for (int image_h = threadIdx.y; image_h < output_height; + image_h += blockDim.y) { + int kernel_id = blockIdx.z; + int kernel_h = blockIdx.y * dilate_height - padding_height; + int kernel_w = blockIdx.x * dilate_width - padding_width; + + int image_hk = image_h * stride_height + kernel_h; + int image_wk = image_w * stride_width + kernel_w; + if (image_hk < 0 || image_hk >= input_height) continue; + if (image_wk < 0 || image_wk >= input_width) continue; + #define gaid(N, C, H, W) \ + ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W)) + int input_id = ((bid * gridDim.z + + kernel_id) * + input_height + + image_hk) * + input_width + + image_wk; + s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * + input_data[input_id]; + + #undef gaid + } + } + } + CudaAtomicAddWithWarp(&filter_grad_data[gbid], s); + } + """, + cuda_src=f""" + // source for backward to data + @alias(input, in0) + @alias(filter, in1) + @alias(output_grad, in2) + @alias(input_grad, out0) + @alias(filter_grad, out1) + + const int batch_size = input_shape0; + const int input_channels = input_shape1; + const int input_height = input_shape2; + const int input_width = input_shape3; + const int output_channels = output_grad_shape1; + const int output_height = output_grad_shape2; + const int output_width = output_grad_shape3; + const int ksize_height = {Kh}; + const int ksize_width = {Kw}; + const int stride_height = {self.stride[0]}; + const int stride_width = {self.stride[1]}; + const int padding_height = {self.padding[0]}; + const int padding_width = {self.padding[1]}; + const int dilate_height = {self.dilation[0]}; + const int dilate_width = {self.dilation[1]}; + + int thread = 512; + if (input_width > 1024 && input_width <= 2048) + thread = (input_width - 1) / 2 + 1; + else if (input_width > 512 && input_width <= 1024) + thread = input_width; + int blocks = std::min(std::max(thread / input_width, 1), input_height); + dim3 threads(std::min(input_width, thread), blocks, 1); + dim3 grid(input_channels, batch_size, 1); + KernelDepthwiseConvInputGradCFilter< + input_type, ksize_height, ksize_width + , stride_height, stride_width> + <<>>( + input_p, output_grad_p, filter_p, batch_size, + output_channels, output_height, output_width, input_channels, + input_height, input_width, padding_height, + padding_width, dilate_height, dilate_width, input_grad_p); + + // source for backward to filter + + int block_size = 512; + if (output_width > 1024 && output_width <= 2048) + block_size = (output_width - 1) / 2 + 1; + else if (output_width > 512 && output_width <= 1024) + block_size = output_width; + int crop_output_height = + std::min(std::max(block_size / output_width, 1), output_height); + + grid = dim3(ksize_width, ksize_height, output_channels); + threads = dim3(std::min(output_width, block_size), crop_output_height, 1); + cudaMemsetAsync(filter_grad_p, 0, filter_grad->size); + + KernelDepthwiseConvFilterGrad< + input_type><<>>( + output_grad_p, input_p, batch_size, output_channels, + output_height, output_width, input_channels, input_height, + input_width, ksize_height, ksize_width, + stride_height, stride_width, padding_height, padding_width, + dilate_height, dilate_width, filter_grad_p); + """ + ) \ No newline at end of file diff --git a/python/jittor/distributions.py b/python/jittor/distributions.py new file mode 100644 index 00000000..2a5a878e --- /dev/null +++ b/python/jittor/distributions.py @@ -0,0 +1,190 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import math +import os +import numpy as np +import jittor as jt +from jittor import nn +from jittor.nn import binary_cross_entropy_with_logits +from jittor import lgamma, igamma +from jittor.math_util.gamma import gamma_grad, sample_gamma + +def simple_presum(x): + src = ''' +__inline_static__ +@python.jittor.auto_parallel(1) +void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) { + out[i0*(nl+1)] = 0; + for (int i=0; inum/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]); + ''' + return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x], + cpu_src=src, cuda_src=src) + + +class OneHotCategorical: + def __init__(self, probs=None, logits=None): + Categorical.__init__(self, probs, logits) + + def sample(self, sample_shape=[]): + shape = sample_shape + self.probs.shape[:-1] + (1,) + rand = jt.rand(shape) + one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float() + return one_hot + + def log_prob(self, x): + x = jt.argmax(x, dim=-1)[0] + return Categorical.log_prob(self, x) + + def entropy(self): + p_log_p = self.logits * self.probs + return -p_log_p.sum(-1) + + +class Categorical: + def __init__(self, probs=None, logits=None): + assert not (probs is None and logits is None) + if probs is None: + # cannot align to pytorch + probs = jt.sigmoid(logits) + probs = probs / probs.sum(-1, True) + if logits is None: + logits = jt.safe_log(probs) + with jt.no_grad(): + self.probs = probs + self.logits = logits + self.cum_probs = simple_presum(self.probs) + self.cum_probs_l = self.cum_probs[..., :-1] + self.cum_probs_r = self.cum_probs[..., 1:] + + def sample(self, sample_shape=()): + shape = sample_shape + self.probs.shape[:-1] + (1,) + rand = jt.rand(shape) + one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r) + index = one_hot.index(one_hot.ndim - 1) + return (one_hot * index).sum(-1) + + def log_prob(self, x): + a = self.probs.ndim + b = x.ndim + indexes = tuple( f'i{i}' for i in range(b-a+1, b) ) + indexes = indexes + (x,) + return jt.safe_log(self.probs).getitem(indexes) + + def entropy(self): + p_log_p = self.logits * self.probs + return -p_log_p.sum(-1) + + +class Normal: + def __init__(self, mu, sigma): + self.mu = mu + self.sigma = sigma + + def sample(self, sample_shape=None): + return jt.normal(jt.array(self.mu), jt.array(self.sigma),size=sample_shape) + + def log_prob(self, x): + var = self.sigma**2 + log_scale = jt.safe_log(self.sigma) + return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi)) + + def entropy(self): + return 0.5+0.5*np.log(2*np.pi)+jt.safe_log(self.sigma) + + +class Uniform: + def __init__(self,low,high): + self.low = low + self.high = high + assert high > low + + def sample(self,sample_shape): + return jt.uniform(self.low,self.high,sample_shape) + + def log_prob(self,x): + if x < self.low or x >= self.high: + return math.inf + return -jt.safe_log(self.high - self.low) + + def entropy(self): + return jt.safe_log(self.high - self.low) + + +class Geometric: + def __init__(self,p=None,logits=None): + assert (p is not None) or (logits is not None) + assert 0 < p and p < 1 + if p is None: + self.prob = jt.sigmoid(logits) + self.logits = logits + elif logits is None: + self.prob = p + self.logits = -jt.safe_log(1. / p - 1) + + def sample(self, sample_shape): + u = jt.rand(sample_shape) + return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor_int() + + def log_prob(self, x): + return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob) + + def entropy(self): + return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob + + +class GammaDistribution: + ''' + For now only support gamma distribution. + ''' + def __init__(self, concentration, rate): + self.concentration = concentration + self.rate = rate + self.lgamma_alpha = lgamma.apply(jt.array([concentration,])) + + def sample(self, shape): + return sample_gamma(self.concentration, shape) + + def cdf(self, value): + return igamma(self.concentration, value) + + def log_prob(self, value): + return (self.concentration * jt.log(self.rate) + + (self.concentration - 1) * jt.log(value) - + self.rate * value - self.lgamma_alpha) + + def mean(self): + return self.concentration / self.rate + + def mode(self): + return np.minimum((self.concentration - 1) / self.rate, 1) + + def variance(self): + return self.concentration / (self.rate * self.rate) + + +def kl_divergence(cur_dist, old_dist): + assert isinstance(cur_dist, type(old_dist)) + if isinstance(cur_dist, Normal): + vr = (cur_dist.sigma / old_dist.sigma)**2 + t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2 + return 0.5*(vr+t1-1-jt.safe_log(vr)) + if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical): + t = cur_dist.probs * (cur_dist.logits-old_dist.logits) + return t.sum(-1) + if isinstance(cur_dist, Uniform): + res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low)) + if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high: + res = math.inf + return res + if isinstance(cur_dist, Geometric): + return -cur_dist.entropy() - jt.safe_log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits diff --git a/python/jittor/einops/__init__.py b/python/jittor/einops/__init__.py new file mode 100644 index 00000000..503dff42 --- /dev/null +++ b/python/jittor/einops/__init__.py @@ -0,0 +1,8 @@ +class EinopsError(RuntimeError): + """ Runtime error thrown by einops """ + pass + + +__all__ = ['rearrange', 'reduce', 'repeat', 'parse_shape', 'asnumpy', 'EinopsError'] + +from jittor.einops.einops import rearrange, reduce, repeat, parse_shape, asnumpy diff --git a/python/jittor/einops/_backends.py b/python/jittor/einops/_backends.py new file mode 100644 index 00000000..2ef16a16 --- /dev/null +++ b/python/jittor/einops/_backends.py @@ -0,0 +1,264 @@ +""" +Backends in `einops` are organized to meet the following requirements +- backends are not imported unless those are actually needed, because + - backends may not be installed + - importing all available backends will drive to significant memory footprint + - backends may by present but installed with errors (but never used), + importing may drive to crashes +- backend should be either symbolic or imperative (tensorflow is for both, but that causes problems) + - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined +- if backend can't (temporarily) provide symbols for shape dimensions, UnknownSize objects are used +""" + +import sys +import warnings + +__author__ = 'Alex Rogozhnikov, RuiYang Liu' + +_backends = {} +_debug_importing = False + + +def get_backend(tensor) -> 'AbstractBackend': + """ + Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor. + If needed, imports package and creates backend + """ + for framework_name, backend in _backends.items(): + if backend.is_appropriate_type(tensor): + return backend + + # Find backend subclasses recursively + backend_subclasses = [] + backends = AbstractBackend.__subclasses__() + while backends: + backend = backends.pop() + backends += backend.__subclasses__() + backend_subclasses.append(backend) + + for BackendSubclass in backend_subclasses: + if _debug_importing: + print('Testing for subclass of ', BackendSubclass) + if BackendSubclass.framework_name not in _backends: + # check that module was already imported. Otherwise it can't be imported + if BackendSubclass.framework_name in sys.modules: + if _debug_importing: + print('Imported backend for ', BackendSubclass.framework_name) + backend = BackendSubclass() + _backends[backend.framework_name] = backend + if backend.is_appropriate_type(tensor): + return backend + + raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor))) + + +class AbstractBackend: + """ Base backend class, major part of methods are only for debugging purposes. """ + framework_name = None + + def is_appropriate_type(self, tensor): + """ helper method should recognize tensors it can handle """ + raise NotImplementedError() + + def from_numpy(self, x): + raise NotImplementedError("framework doesn't support imperative execution") + + def to_numpy(self, x): + raise NotImplementedError("framework doesn't support imperative execution") + + def create_symbol(self, shape): + raise NotImplementedError("framework doesn't support symbolic computations") + + def eval_symbol(self, symbol, input_dict): + raise NotImplementedError("framework doesn't support symbolic computations") + + def arange(self, start, stop): + # supplementary method used only in testing, so should implement CPU version + raise NotImplementedError("framework doesn't implement arange") + + def shape(self, x): + """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)""" + return x.shape + + def reshape(self, x, shape): + return x.reshape(shape) + + def transpose(self, x, axes): + return x.transpose(axes) + + def reduce(self, x, operation, axes): + return getattr(x, operation)(axis=axes) + + def stack_on_zeroth_dimension(self, tensors: list): + raise NotImplementedError() + + def add_axis(self, x, new_position): + raise NotImplementedError() + + def add_axes(self, x, n_axes, pos2len): + repeats = [1] * n_axes + for axis_position, axis_length in pos2len.items(): + x = self.add_axis(x, axis_position) + repeats[axis_position] = axis_length + return self.tile(x, tuple(repeats)) + + def tile(self, x, repeats): + """repeats is a number of """ + raise NotImplementedError() + + def is_float_type(self, x): + # Decided to drop average for all backends if type is not floating + raise NotImplementedError() + + def layers(self): + raise NotImplementedError("backend does not provide layers") + + def __repr__(self): + return "".format(self.framework_name) + + def einsum(self, pattern, *x): + raise NotImplementedError("backend does not support einsum") + + +class UnknownSize: + """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """ + + def __floordiv__(self, other): + return self + + def __eq__(self, other): + return True # we don't know actual size + + def __mul__(self, other): + return self + + def __rmul__(self, other): + return self + + def __hash__(self): + return None.__hash__() + + +class NumpyBackend(AbstractBackend): + framework_name = 'numpy' + + def __init__(self): + import numpy + self.np = numpy + + def is_appropriate_type(self, tensor): + return isinstance(tensor, self.np.ndarray) + + def from_numpy(self, x): + return x + + def to_numpy(self, x): + return x + + def arange(self, start, stop): + return self.np.arange(start, stop) + + def stack_on_zeroth_dimension(self, tensors: list): + return self.np.stack(tensors) + + def tile(self, x, repeats): + return self.np.tile(x, repeats) + + def is_float_type(self, x): + return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16') + + def add_axis(self, x, new_position): + return self.np.expand_dims(x, new_position) + + def einsum(self, pattern, *x): + return self.np.einsum(pattern, *x) + + +class HashableTuple: + """Overcomes non-hashability of symbolic elements""" + + def __init__(self, elements: tuple): + self.elements = elements + + def __iter__(self): + for x in self.elements: + yield x + + def __len__(self): + return len(self.elements) + + def __getitem__(self, item): + return self.elements[item] + +class JittorBackend(AbstractBackend): + framework_name = 'jittor' + + def __init__(self): + import jittor + self.jittor = jittor + + def is_appropriate_type(self, tensor): + return isinstance(tensor, self.jittor.Var) + + def from_numpy(self, x): + variable = self.jittor.array(x) + return variable + + def to_numpy(self, x): + return x.detach().numpy() + + def arange(self, start, stop): + return self.jittor.arange(start, stop, dtype='int64') + + def shape(self, x): + return tuple(x.shape) + + def reshape(self, x, shape): + if len(shape) == 0: + return x + return self.jittor.reshape(x, shape) + + def reduce(self, x, operation, reduced_axes): + + if operation == 'prod': + #avoid overflow + return x.prod(reduced_axes) + for axis in sorted(reduced_axes, reverse=True): + if operation == 'min': + x = x.min(dim=axis) + elif operation == 'max': + x = x.max(dim=axis) + elif operation in ['sum', 'mean']: + x = getattr(x, operation)(dim=axis) + else: + raise NotImplementedError('Unknown reduction ', operation) + return x + + def transpose(self, x, axes): + return x.permute(axes) + + def stack_on_zeroth_dimension(self, tensors: list): + return self.jittor.stack(tensors) + + def add_axes(self, x, n_axes, pos2len): + repeats = [-1] * n_axes + for axis_position, axis_length in pos2len.items(): + x = self.add_axis(x, axis_position) + repeats[axis_position] = axis_length + return x.expand(repeats) + + def tile(self, x, repeats): + return x.repeat(repeats) + + def add_axis(self, x, new_position): + return self.jittor.unsqueeze(x, new_position) + + def is_float_type(self, x): + return x.dtype in ["float16", "bfloat16", "float32", "float64"] + + def layers(self): + from jittor.einops.layers import jittor + return jittor + + def einsum(self, pattern, *x): + return self.jittor.linalg.einsum(pattern, *x) \ No newline at end of file diff --git a/python/jittor/einops/einops.py b/python/jittor/einops/einops.py new file mode 100644 index 00000000..da931fd3 --- /dev/null +++ b/python/jittor/einops/einops.py @@ -0,0 +1,782 @@ +import functools +import itertools +import string +import typing +from collections import OrderedDict +from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar + +if typing.TYPE_CHECKING: + import numpy as np + +from jittor.einops import EinopsError +from jittor.einops._backends import get_backend +from jittor.einops.parsing import ParsedExpression, _ellipsis, AnonymousAxis + +Tensor = TypeVar('Tensor') +ReductionCallable = Callable[[Tensor, List[int]], Tensor] +Reduction = Union[str, ReductionCallable] + +_reductions = ('min', 'max', 'sum', 'mean', 'prod') +_ellipsis_not_in_parenthesis: List[int] = [-999] +_unknown_axis_length = -999999 + + +def is_ellipsis_not_in_parenthesis(group: List[int]) -> bool: + if len(group) != 1: + return False + return group[0] == -999 + + +def _product(sequence: List[int]) -> int: + """ minimalistic product that works both with numbers and symbols. Supports empty lists """ + result = 1 + for element in sequence: + result *= element + return result + + +def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend): + reduced_axes = tuple(reduced_axes) + if callable(reduction_type): + # custom callable + return reduction_type(tensor, reduced_axes) + else: + # one of built-in operations + if len(reduced_axes) == 0: + return tensor + assert reduction_type in _reductions + if reduction_type == 'mean': + if not backend.is_float_type(tensor): + raise NotImplementedError('reduce_mean is not available for non-floating tensors') + return backend.reduce(tensor, reduction_type, reduced_axes) + + +def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): + # 'collapses' neighboring axes if those participate in the result pattern in the same order + # TODO add support for added_axes + assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) + # joining consecutive axes that will be reduced + # possibly we can skip this if all backends can optimize this (not sure) + reduced_axes = tuple(sorted(reduced_axes)) + for i in range(len(reduced_axes) - 1)[::-1]: + if reduced_axes[i] + 1 == reduced_axes[i + 1]: + removed_axis = reduced_axes[i + 1] + removed_length = init_shapes[removed_axis] + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:]) + + # removing axes that are moved together during reshape + def build_mapping(): + init_to_final = {} + for axis in range(len(init_shapes)): + if axis in reduced_axes: + init_to_final[axis] = None + else: + after_reduction = sum(x is not None for x in init_to_final.values()) + init_to_final[axis] = list(axes_reordering).index(after_reduction) + return init_to_final + + init_axis_to_final_axis = build_mapping() + + for init_axis in range(len(init_shapes) - 1)[::-1]: + if init_axis_to_final_axis[init_axis] is None: + continue + if init_axis_to_final_axis[init_axis + 1] is None: + continue + if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: + removed_axis = init_axis + 1 + removed_length = init_shapes[removed_axis] + removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) + + reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + old_reordering = axes_reordering + axes_reordering = [] + for axis in old_reordering: + if axis == removed_axis_after_reduction: + pass + elif axis < removed_axis_after_reduction: + axes_reordering.append(axis) + else: + axes_reordering.append(axis - 1) + init_axis_to_final_axis = build_mapping() + + return init_shapes, reduced_axes, axes_reordering, final_shapes + + +CookedRecipe = Tuple[List[int], List[int], List[int], Dict[int, int], List[int]] + + +class TransformRecipe: + """ + Recipe describes actual computation pathway. + Recipe can be applied to a tensor or variable. + """ + + # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+) + + def __init__(self, + # list of expressions (or just sizes) for elementary axes as they appear in left expression. + # this is what (after computing unknown parts) will be a shape after first transposition. + # If ellipsis is present, it forms one dimension here (in the right position). + elementary_axes_lengths: List[int], + # each dimension in input can help to reconstruct length of one elementary axis + # or verify one of dimensions. Each element points to element of elementary_axes_lengths + input_composite_axes: List[Tuple[List[int], List[int]]], + # indices of axes to be squashed + reduced_elementary_axes: List[int], + # in which order should axes be reshuffled after reduction + axes_permutation: List[int], + # at which positions which of elementary axes should appear + added_axes: Dict[int, int], + # ids of axes as they appear in result, again pointers to elementary_axes_lengths, + # only used to infer result dimensions + output_composite_axes: List[List[int]], + # positions of ellipsis in lhs and rhs of expression + ellipsis_position_in_lhs: Optional[int] = None, + ): + self.elementary_axes_lengths: List[int] = elementary_axes_lengths + self.input_composite_axes: List[Tuple[List[int], List[int]]] = input_composite_axes + self.output_composite_axes: List[List[int]] = output_composite_axes + self.axes_permutation: List[int] = axes_permutation + self.added_axes: Dict[int, int] = added_axes + # This is redundant information, but more convenient to use + self.reduced_elementary_axes: List[int] = reduced_elementary_axes + # setting to a large number to avoid handling Nones in reconstruct_from_shape + self.ellipsis_position_in_lhs: int = ellipsis_position_in_lhs if ellipsis_position_in_lhs is not None else 10000 + + +def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> CookedRecipe: + """ + Reconstruct all actual parameters using shape. + Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet) + known axes can be integers or symbols, but not Nones. + """ + axes_lengths: List[int] = list(self.elementary_axes_lengths) + if self.ellipsis_position_in_lhs != 10000: + if len(shape) < len(self.input_composite_axes) - 1: + raise EinopsError('Expected at least {} dimensions, got {}'.format( + len(self.input_composite_axes) - 1, len(shape))) + else: + if len(shape) != len(self.input_composite_axes): + raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape))) + + ellipsis_shape: List[int] = [] + for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes): + before_ellipsis = input_axis + after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes) + if input_axis == self.ellipsis_position_in_lhs: + assert len(known_axes) == 0 and len(unknown_axes) == 1 + unknown_axis, = unknown_axes + ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1] + for d in ellipsis_shape: + if d is None: + raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis") + total_dim_size: int = _product(ellipsis_shape) + axes_lengths[unknown_axis] = total_dim_size + else: + if input_axis < self.ellipsis_position_in_lhs: + length = shape[before_ellipsis] + else: + length = shape[after_ellipsis] + known_product = 1 + for axis in known_axes: + known_product *= axes_lengths[axis] + + if len(unknown_axes) == 0: + if isinstance(length, int) and isinstance(known_product, int) and length != known_product: + raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product)) + # this is enforced when recipe is created + # elif len(unknown_axes) > 1: + # raise EinopsError( + # "Lengths of two or more axes in parenthesis not provided (dim={}), can't infer dimensions". + # format(known_product) + # ) + else: + if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: + raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format( + length, known_product)) + + unknown_axis: int = unknown_axes[0] + inferred_length: int = length // known_product + axes_lengths[unknown_axis] = inferred_length + + # at this point all axes_lengths are computed (either have values or variables, but not Nones) + + # TODO more readable expression + init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)] + final_shapes: List[int] = [] + for output_axis, grouping in enumerate(self.output_composite_axes): + if is_ellipsis_not_in_parenthesis(grouping): + final_shapes.extend(ellipsis_shape) + else: + lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] + final_shapes.append(_product(lengths)) + reduced_axes = self.reduced_elementary_axes + axes_reordering = self.axes_permutation + added_axes: Dict[int, int] = { + pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()} + # if optimize: + # assert len(self.added_axes) == 0 + # return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes) + return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes + + +_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached) + + +def _apply_recipe(recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction) -> Tensor: + # this method works for all backends but not compilable with + backend = get_backend(tensor) + init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \ + _reconstruct_from_shape(recipe, backend.shape(tensor)) + tensor = backend.reshape(tensor, init_shapes) + tensor = _reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend) + tensor = backend.transpose(tensor, axes_reordering) + if len(added_axes) > 0: + tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes) + return backend.reshape(tensor, final_shapes) + + +@functools.lru_cache(256) +def _prepare_transformation_recipe(pattern: str, + operation: Reduction, + axes_lengths: Tuple[Tuple, ...]) -> TransformRecipe: + """ Perform initial parsing of pattern and provided supplementary info + axes_lengths is a tuple of tuples (axis_name, axis_length) + """ + left, rght = pattern.split('->') + left = ParsedExpression(left) + rght = ParsedExpression(rght) + + # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction + if not left.has_ellipsis and rght.has_ellipsis: + raise EinopsError('Ellipsis found in right side, but not left side of a pattern {}'.format(pattern)) + if left.has_ellipsis and left.has_ellipsis_parenthesized: + raise EinopsError('Ellipsis is parenthesis in the left side is not allowed: {}'.format(pattern)) + if operation == 'rearrange': + difference = set.symmetric_difference(left.identifiers, rght.identifiers) + if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: + raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)') + if len(difference) > 0: + raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference)) + elif operation == 'repeat': + difference = set.difference(left.identifiers, rght.identifiers) + if len(difference) > 0: + raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference)) + axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, + {*left.identifiers, *(ax for ax, _ in axes_lengths)}) + if len(axes_without_size) > 0: + raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size)) + elif operation in _reductions or callable(operation): + difference = set.difference(rght.identifiers, left.identifiers) + if len(difference) > 0: + raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference)) + else: + raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions)) + + # parsing all dimensions to find out lengths + axis_name2known_length = OrderedDict() + for composite_axis in left.composition: + for axis_name in composite_axis: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + + # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point + + repeat_axes_names = [] + for axis_name in rght.identifiers: + if axis_name not in axis_name2known_length: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + repeat_axes_names.append(axis_name) + + axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} + reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if + axis not in rght.identifiers] + reduced_axes: List[int] = list(sorted(reduced_axes)) + + for elementary_axis, axis_length in axes_lengths: + if not ParsedExpression.check_axis_name(elementary_axis): + raise EinopsError('Invalid name for an axis', elementary_axis) + if elementary_axis not in axis_name2known_length: + raise EinopsError('Axis {} is not used in transform'.format(elementary_axis)) + axis_name2known_length[elementary_axis] = axis_length + + input_axes_known_unknown = [] + # some of shapes will be inferred later - all information is prepared for faster inference + for composite_axis in left.composition: + known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} + unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} + if len(unknown) > 1: + raise EinopsError('Could not infer sizes for {}'.format(unknown)) + assert len(unknown) + len(known) == len(composite_axis) + input_axes_known_unknown.append( + ([axis_name2position[axis] for axis in known], + [axis_name2position[axis] for axis in unknown]) + ) + + axis_position_after_reduction = {} + for axis_name in itertools.chain(*left.composition): + if axis_name in rght.identifiers: + axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) + + result_axes_grouping: List[List[int]] = [] + for composite_axis in rght.composition: + if composite_axis == _ellipsis: + result_axes_grouping.append(_ellipsis_not_in_parenthesis) + else: + result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis]) + + ordered_axis_right = list(itertools.chain(*rght.composition)) + axes_permutation = [ + axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers] + added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right) + if axis_name not in left.identifiers} + + ellipsis_left = None if _ellipsis not in left.composition else left.composition.index(_ellipsis) + + return TransformRecipe( + elementary_axes_lengths=list(axis_name2known_length.values()), + input_composite_axes=input_axes_known_unknown, + reduced_elementary_axes=reduced_axes, + axes_permutation=axes_permutation, + added_axes=added_axes, + output_composite_axes=result_axes_grouping, + ellipsis_position_in_lhs=ellipsis_left, + ) + + +def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor: + """ + einops.reduce provides combination of reordering and reduction using reader-friendly notation. + + Examples for reduce operation: + + ```python + >>> x = np.random.randn(100, 32, 64) + + # perform max-reduction on the first axis + >>> y = reduce(x, 't b c -> b c', 'max') + + # same as previous, but with clearer axes meaning + >>> y = reduce(x, 'time batch channel -> batch channel', 'max') + + >>> x = np.random.randn(10, 20, 30, 40) + + # 2d max-pooling with kernel size = 2 * 2 for image processing + >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2) + + # if one wants to go back to the original height and width, depth-to-space trick can be applied + >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2) + >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w') + + # Adaptive 2d max-pooling to 3 * 4 grid + >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape + (10, 20, 3, 4) + + # Global average pooling + >>> reduce(x, 'b c h w -> b c', 'mean').shape + (10, 20) + + # Subtracting mean over batch for each channel + >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean') + + # Subtracting per-image mean for each channel + >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean') + + ``` + + Parameters: + tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, reduction pattern + reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive + alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. + axes_lengths: any additional specifications for dimensions + + Returns: + tensor of the same type as input + """ + try: + hashable_axes_lengths = tuple(sorted(axes_lengths.items())) + recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths) + return _apply_recipe(recipe, tensor, reduction_type=reduction) + except EinopsError as e: + message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern) + if not isinstance(tensor, list): + message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor)) + else: + message += '\n Input is list. ' + message += 'Additional info: {}.'.format(axes_lengths) + raise EinopsError(message + '\n {}'.format(e)) + + + +def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: + """ + einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors. + This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, + stack, concatenate and other operations. + + Examples for rearrange operation: + + ```python + # suppose we have a set of 32 images in "h w c" format (height-width-channel) + >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] + + # stack along first (batch) axis, output is a single array + >>> rearrange(images, 'b h w c -> b h w c').shape + (32, 30, 40, 3) + + # concatenate images along height (vertical axis), 960 = 32 * 30 + >>> rearrange(images, 'b h w c -> (b h) w c').shape + (960, 40, 3) + + # concatenated images along horizontal axis, 1280 = 32 * 40 + >>> rearrange(images, 'b h w c -> h (b w) c').shape + (30, 1280, 3) + + # reordered axes to "b c h w" format for deep learning + >>> rearrange(images, 'b h w c -> b c h w').shape + (32, 3, 30, 40) + + # flattened each image into a vector, 3600 = 30 * 40 * 3 + >>> rearrange(images, 'b h w c -> b (c h w)').shape + (32, 3600) + + # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 + >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape + (128, 15, 20, 3) + + # space-to-depth operation + >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape + (32, 15, 20, 12) + + ``` + + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. + + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions + + Returns: + tensor of the same type as input. If possible, a view to the original tensor is returned. + + """ + if isinstance(tensor, list): + if len(tensor) == 0: + raise TypeError("Rearrange can't be applied to an empty list") + tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor) + return reduce(tensor, pattern, reduction='rearrange', **axes_lengths) + + +def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: + """ + einops.repeat allows reordering elements and repeating them in arbitrary combinations. + This operation includes functionality of repeat, tile, broadcast functions. + + Examples for repeat operation: + + ```python + # a grayscale image (of shape height x width) + >>> image = np.random.randn(30, 40) + + # change it to RGB format by repeating in each channel + >>> repeat(image, 'h w -> h w c', c=3).shape + (30, 40, 3) + + # repeat image 2 times along height (vertical axis) + >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape + (60, 40) + + # repeat image 2 time along height and 3 times along width + >>> repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape + (60, 120) + + # convert each pixel to a small square 2x2. Upsample image by 2x + >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (60, 80) + + # pixelate image first by downsampling by 2x, then upsampling + >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) + >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (30, 40) + + ``` + + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. + + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions + + Returns: + Tensor of the same type as input. If possible, a view to the original tensor is returned. + + """ + return reduce(tensor, pattern, reduction='repeat', **axes_lengths) + + +def parse_shape(x, pattern: str) -> dict: + """ + Parse a tensor shape to dictionary mapping axes names to their lengths. + + ```python + # Use underscore to skip the dimension in parsing. + >>> x = np.zeros([2, 3, 5, 7]) + >>> parse_shape(x, 'batch _ h w') + {'batch': 2, 'h': 5, 'w': 7} + + # `parse_shape` output can be used to specify axes_lengths for other operations: + >>> y = np.zeros([700]) + >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape + (2, 10, 5, 7) + + ``` + + For symbolic frameworks may return symbols, not integers. + + Parameters: + x: tensor of any of supported frameworks + pattern: str, space separated names for axes, underscore means skip axis + + Returns: + dict, maps axes names to their lengths + """ + exp = ParsedExpression(pattern, allow_underscore=True) + shape = get_backend(x).shape(x) + if exp.has_composed_axes(): + raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format( + pattern=pattern, shape=shape)) + if len(shape) != len(exp.composition): + if exp.has_ellipsis: + if len(shape) < len(exp.composition) - 1: + raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format( + pattern=pattern, shape=shape)) + else: + raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format( + pattern=pattern, shape=shape)) + if exp.has_ellipsis: + ellipsis_idx = exp.composition.index(_ellipsis) + composition = (exp.composition[:ellipsis_idx] + + ['_'] * (len(shape) - len(exp.composition) + 1) + + exp.composition[ellipsis_idx + 1:]) + else: + composition = exp.composition + result = {} + for (axis_name,), axis_length in zip(composition, shape): + if axis_name != '_': + result[axis_name] = axis_length + return result + + +# this one is probably not needed in the public API +def _enumerate_directions(x): + """ + For an n-dimensional tensor, returns tensors to enumerate each axis. + ```python + x = np.zeros([2, 3, 4]) # or any other tensor + i, j, k = _enumerate_directions(x) + result = i + 2*j + 3*k + ``` + + `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result + Works very similarly to numpy.ogrid (open indexing grid) + """ + backend = get_backend(x) + shape = backend.shape(x) + result = [] + for axis_id, axis_length in enumerate(shape): + shape = [1] * len(shape) + shape[axis_id] = axis_length + result.append(backend.reshape(backend.arange(0, axis_length), shape)) + return result + + +def asnumpy(tensor) -> 'numpy.ndarray': + """ + Convert a tensor of an imperative framework (i.e. numpy/jittor.) to `numpy.ndarray` + + Parameters: + tensor: tensor of any of known imperative framework + + Returns: + `numpy.ndarray`, converted to numpy + """ + return get_backend(tensor).to_numpy(tensor) + +def _validate_einsum_axis_name(axis_name): + if len(axis_name) == 0: + raise NotImplementedError("Singleton () axes are not yet supported in einsum.") + if len(axis_name) > 1: + raise NotImplementedError("Shape rearrangement is not yet supported in einsum.") + + axis_name = axis_name[0] + + if isinstance(axis_name, AnonymousAxis): + raise NotImplementedError("Anonymous axes are not yet supported in einsum.") + if len(axis_name) == 0: + raise RuntimeError("Encountered empty axis name in einsum.") + if not isinstance(axis_name, str): + raise RuntimeError("Axis name in einsum must be a string.") + + +@functools.lru_cache(256) +def _compactify_pattern_for_einsum(pattern: str) -> str: + if "->" not in pattern: + # numpy allows this, so make sure users + # don't accidentally do something like this. + raise ValueError("Einsum pattern must contain '->'.") + lefts, right = pattern.split('->') + lefts = lefts.split(',') + + lefts = [ + ParsedExpression(left, allow_underscore=True, allow_duplicates=True) + for left in lefts + ] + + right = ParsedExpression(right, allow_underscore=True) + + # Start from 'a' and go up to 'Z' + output_axis_names = string.ascii_letters + i = 0 + axis_name_mapping = {} + + left_patterns = [] + for left in lefts: + left_pattern = "" + for raw_axis_name in left.composition: + + if raw_axis_name == _ellipsis: + left_pattern += '...' + continue + + _validate_einsum_axis_name(raw_axis_name) + axis_name = raw_axis_name[0] + if axis_name not in axis_name_mapping: + if i >= len(output_axis_names): + raise RuntimeError("Too many axes in einsum.") + axis_name_mapping[axis_name] = output_axis_names[i] + i += 1 + + left_pattern += axis_name_mapping[axis_name] + left_patterns.append(left_pattern) + + compact_pattern = ",".join(left_patterns) + "->" + + for raw_axis_name in right.composition: + if raw_axis_name == _ellipsis: + compact_pattern += '...' + continue + + _validate_einsum_axis_name(raw_axis_name) + axis_name = raw_axis_name[0] + + if axis_name not in axis_name_mapping: + raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.") + + compact_pattern += axis_name_mapping[axis_name] + + return compact_pattern + + +@typing.overload +def einsum(tensor: Tensor, pattern: str) -> Tensor: ... +@typing.overload +def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ... +@typing.overload +def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ... +@typing.overload +def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ... + + +def einsum(*tensors_and_pattern: List[Union[Tensor, str]]) -> Tensor: + """ + einops.einsum calls einsum operations with einops-style named + axes indexing, computing tensor products with an arbitrary + number of tensors. Unlike typical einsum syntax, here you must + pass tensors first, and then the pattern. + + Also, note that rearrange operations such as `"(batch chan) out"`, + or singleton axes `()`, are not currently supported. + + Examples: + + For a given pattern such as: + ```python + >>> x, y, z = np.random.randn(3, 20, 20, 20) + >>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k") + + ``` + the following formula is computed: + ```tex + output[a, b, k] = + \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k] + ``` + where the summation over `c`, `d`, and `g` is performed + because those axes names do not appear on the right-hand side. + + Let's see some additional examples: + ```python + # Filter a set of images: + >>> batched_images = np.random.randn(128, 16, 16) + >>> filters = np.random.randn(16, 16, 30) + >>> result = einsum(batched_images, filters, + ... "batch h w, h w channel -> batch channel") + >>> result.shape + (128, 30) + + # Matrix multiplication, with an unknown input shape: + >>> batch_shape = (50, 30) + >>> data = np.random.randn(*batch_shape, 20) + >>> weights = np.random.randn(10, 20) + >>> result = einsum(weights, data, + ... "out_dim in_dim, ... in_dim -> ... out_dim") + >>> result.shape + (50, 30, 10) + + # Matrix trace on a single tensor: + >>> matrix = np.random.randn(10, 10) + >>> result = einsum(matrix, "i i ->") + >>> result.shape + () + + ``` + + Parameters: + tensors: tensors of any supported library (numpy, jittor). + pattern: string, einsum pattern, with commas + separating specifications for each tensor. + + Returns: + Tensor of the same type as input, after processing with einsum. + + """ + if len(tensors_and_pattern) <= 1: + raise ValueError( + "`einops.einsum` takes at minimum two arguments: the tensors (at least one)," + " followed by the pattern." + ) + pattern = tensors_and_pattern[-1] + if not isinstance(pattern, str): + raise ValueError( + "The last argument passed to `einops.einsum` must be a string," + " representing the einsum pattern." + ) + tensors = tensors_and_pattern[:-1] + pattern = _compactify_pattern_for_einsum(pattern) + return get_backend(tensors[0]).einsum(pattern, *tensors) \ No newline at end of file diff --git a/python/jittor/einops/experimental/__init__.py b/python/jittor/einops/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/jittor/einops/experimental/indexing.py b/python/jittor/einops/experimental/indexing.py new file mode 100644 index 00000000..4ba9e9de --- /dev/null +++ b/python/jittor/einops/experimental/indexing.py @@ -0,0 +1,393 @@ +""" + +Indexing one array with the other(s). + +Concept for discussion. + +Notation targets hard cases, not simple ones, like indexing of 1d-array with another 1d-array +(notation supports that, but you can't simplify arr[ind], and there is no reason to) + +Examples + +1. query for every token in sequence a token in the image. Images and sequences are paired + einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, [h_indices_bt, w_indices_bt]) + + this is equivalent, so you can pass indexers idependently or together + einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, np.asarray([h_indices_bt, w_indices_bt])) + + after some thinking I decided that having first axis for indexing variable is not too restrictive, + but should simplify mapping of such cases. + For this reason [...] part should always go first in indexer. + + This makes the largest difference with einindex https://github.com/malmaud/einindex, + which has almost identical grammar, but puts special dimension last, while we put it first. + This trick allows naturally decomposing multiindex into individual dimensions or visa versa. + + +2. query for every token in the video the most suitable word in a (matching) sentence + einindex('b t h w <- seq b, [seq] t b h w', arr_tbc, [t_indices_bhw]) + + note, that only one indexer is used, but still it has to be enclosed in the list. + That's a price for being generic. Alternatively leading singleton dimension can be added. + + +3. (not supported now, future planning) + for every timeframe in a video, find the token with the highest norm (across h and w), and compose a new stack of them + indices_2bt = argmax(x_bthwc.norm(dim=-1), 'b t h w -> [h, w] b t') + selected_embeddings_btc = einindex('b t c <- b t h w c, [h, w] b t', x_bthwc, indices_2bt) + + while currently question is around 'how do we index', + it is important to pre-align that with a question 'what are natural ways to get indices'. + Most common are min/max. less common options: topk (works here), random sampling. + + + +Some important properties of this notation: +- support for multiple indexers, including using a single tensor to keep multiple indexers +- 'batch' indexing, when some axes of indexer and array should be matched +- universal (one-indexing-to-rule-them-all) +- extensible for (named) ellipses, including variadic number of indexers +- extensible for einops-style compositions and decompositions +- extensible for outer indexing when indexers are not aligned + +Current implementation based on python array api and uses loops, +because no appropriate indexing available in the standard. + +""" + +from typing import List, Union, TypeVar, Tuple + +from jittor.einops import EinopsError + +T = TypeVar('T') + + +class CompositionDecomposition: + def __init__( + self, + decomposed_shape: List[str], + composed_shape: List[List[str]], + ): + flat_shape = [] + for x in composed_shape: + flat_shape.extend(x) + + self.compose_transposition: Tuple[int] = tuple([decomposed_shape.index(x) for x in flat_shape]) + self.decompose_transposition: Tuple[int] = tuple([flat_shape.index(x) for x in decomposed_shape]) + self.composed_shape = composed_shape + self.decomposed_shape = decomposed_shape + + def decompose(self, x, known_axes_lengths: dict[str, int]): + xp = x.__array_namespace__() + shape = x.shape + + flat_shape = [] + + for i, axis_group in enumerate(self.composed_shape): + unknown_axis_name = None + known_sizes_prod = 1 + for axis_name in axis_group: + if axis_name in known_axes_lengths: + known_sizes_prod *= known_axes_lengths[axis_name] + else: + if unknown_axis_name is None: + unknown_axis_name = axis_name + else: + raise EinopsError("Can't infer the size") + + if unknown_axis_name is None: + assert shape[i] == known_sizes_prod + else: + known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod + + for axis in axis_group: + flat_shape.append(known_axes_lengths[axis]) + + x = xp.reshape(x, flat_shape) + return xp.permute_dims(x, self.decompose_transposition) + + def compose(self, x, known_axes_lengths: dict[str, int]): + xp = x.__array_namespace__() + + for axis_len, axis_name in zip(x.shape, self.decomposed_shape): + if axis_name in known_axes_lengths: + assert known_axes_lengths[axis_name] == axis_len + else: + known_axes_lengths[axis_name] = axis_len + + x = xp.permute_dims(x, self.compose_transposition) + new_shape = [] + for axis_group in self.composed_shape: + composed_axis_size = 1 + for axis_name in axis_group: + composed_axis_size *= known_axes_lengths[axis_name] + new_shape.append(composed_axis_size) + + return xp.reshape(x, tuple(new_shape)) + + +def arange_at_position(xp, n_axes, axis, axis_len, device=None): + x = xp.arange(axis_len, dtype=xp.int64, device=device) + shape = [1] * n_axes + shape[axis] = axis_len + x = xp.reshape(x, shape) + return x + + +class IndexingFormula: + + def __init__(self, pattern: str): + """ + :param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t' + """ + self.pattern = pattern + left, right = pattern.split('<-') + arg_split = right.index(',') + arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:] + ind_pattern = ind_pattern.strip() + # print( + # arr_pattern, '\n', + # ind_pattern, + # ) + assert ind_pattern.startswith('['), 'composition axis should go first in indexer (second argument) [h w] i j k' + composition_start = ind_pattern.index('[') + composition_end = ind_pattern.index(']') + composition = ind_pattern[composition_start + 1: composition_end] + ind_other_axes = ind_pattern[composition_end + 1:] + + self.result_axes_names = left.split() + self.array_axes_names = arr_pattern.split() + self.indexing_axes_names = [x.strip() for x in composition.split(',')] + self.indexer_other_axes_names = ind_other_axes.split() + + for group_name, group in [ + ('result', self.result_axes_names), + ('array', self.array_axes_names), + ('indexer', self.indexing_axes_names + self.indexer_other_axes_names), + ]: + if len(set(group)) != len(group): + # need more verbosity, which axis, raise + raise EinopsError(f'{group_name} pattern ({group}) contains a duplicated axis') + + axis_groups = [ + self.result_axes_names, + self.array_axes_names, + self.indexing_axes_names, + self.indexer_other_axes_names, + ] + + all_axes = set() + for group in axis_groups: + all_axes.update(group) + + self.indexer_axes = [] + self.batch_axes = [] + self.result_and_index_axes = [] + self.result_and_array_axes = [] + + for axis in all_axes: + presence = tuple(axis in g for g in axis_groups) + # want match-case here. sweet dreams + if presence == (False, True, True, False): + self.indexer_axes.append(axis) + elif presence[2]: + raise EinopsError(f'Wrong usage of indexer variable {axis}') + elif presence == (True, True, False, True): + self.batch_axes.append(axis) + elif presence == (True, False, False, True): + self.result_and_index_axes.append(axis) + elif presence == (True, True, False, False): + self.result_and_array_axes.append(axis) + else: + # TODO better categorization of wrong usage patterns + raise EinopsError(f'{axis} is used incorrectly in {pattern}') + + assert set(self.indexer_axes) == set(self.indexing_axes_names) + # order of these variables matters, since we can't lose mapping here + self.indexer_axes = self.indexing_axes_names + + self.array_composition = CompositionDecomposition( + decomposed_shape=self.array_axes_names, + composed_shape=[self.batch_axes + self.indexer_axes, self.result_and_array_axes], + ) + + self.index_composition = CompositionDecomposition( + decomposed_shape=self.indexer_other_axes_names, + # single axis after composition + composed_shape=[self.batch_axes + self.result_and_index_axes], + ) + + self.result_composition = CompositionDecomposition( + decomposed_shape=self.result_axes_names, + composed_shape=[self.batch_axes + self.result_and_index_axes, self.result_and_array_axes], + ) + + def apply_to_array_api(self, arr: T, ind: Union[T, List[T]]): + known_axes_sizes: dict[str, int] = {} + xp = arr.__array_namespace__() + + if not isinstance(ind, list): + ind = [ind[i, ...] for i in range(ind.shape[0])] + + for indexer in ind: + assert len(indexer.shape) == len(self.indexer_other_axes_names) + + # step 1. transpose, reshapes of arr; learn its dimensions + arr_2d = self.array_composition.compose(arr, known_axes_sizes) + + # step 2. compute shifts and create an actual indexing array + shift = 1 + full_index = xp.zeros([1] * len(ind[0].shape), dtype=xp.int64, device=arr.device) + + # original order: [*batch-like axes, *indexing_axes,] + # now we need to traverse them in the opposite direction + + for axis_name, indexer in list(zip(self.indexing_axes_names, ind))[::-1]: + full_index = full_index + shift * (indexer % known_axes_sizes[axis_name]) + shift *= known_axes_sizes[axis_name] + + for axis_name in self.batch_axes[::-1]: + axis_id = self.indexer_other_axes_names.index(axis_name) + full_index = full_index + arange_at_position( + xp, len(self.indexer_other_axes_names), axis=axis_id, axis_len=known_axes_sizes[axis_name], + device=arr.device, + ) * shift + shift *= known_axes_sizes[axis_name] + + assert shift == arr_2d.shape[0] + + # step 3. Flatten index + full_index = self.index_composition.compose(full_index, known_axes_sizes) + + # step 4. indexing + # python array api lacks any integer indexing, so... I use loops. + # did you know that there is conceptual programming ... just like art? + # result_2d = arr_2d[full_index] + result_2d = xp.stack([arr_2d[full_index[i], :] for i in range(full_index.shape[0])]) + + # step 5. doing resulting + result = self.result_composition.decompose(result_2d, known_axes_sizes) + return result + + +def einindex(pattern: str, arr: T, /, ind: Union[T, List[T]]): + """ + Demonstrates how einindex should work. + Supports data-api compliant arrays. + """ + formula = IndexingFormula(pattern) + return formula.apply_to_array_api(arr, ind) + + +def test_composition_and_decomposition(): + import numpy.array_api as np + x = np.arange(2 * 3 * 5 * 7) + x = np.reshape(x, (2, 3, 5, 7)) + comp = CompositionDecomposition( + decomposed_shape=['a', 'b', 'c', 'd'], + composed_shape=[['a', 'b'], ['c', 'd']], + ) + assert comp.compose(x, known_axes_lengths={}).shape == (2 * 3, 5 * 7) + + y = CompositionDecomposition( + decomposed_shape=['a', 'b', 'c', 'd'], + composed_shape=[['a', 'b'], [], ['c', 'd']], + ).compose(x, {}) + assert y.shape == (2 * 3, 1, 5 * 7) + assert np.all(np.reshape(x, (-1,)) == np.reshape(y, (-1,))) + + comp = CompositionDecomposition( + decomposed_shape=['a', 'b', 'e', 'c', 'd'], + composed_shape=[['e', 'c'], ['b'], ['a', 'd']], + ) + x = np.arange(2 * 3 * 5 * 7 * 3) + x = np.reshape(x, (2, 3, 5, 7, 3)) + + axes = {} + y = comp.compose(x, axes) + x2 = comp.decompose(y, axes) + assert np.all(x == x2) + + +def test_simple_indexing(): + import numpy.array_api as np + + # simple 2d test + arr = np.reshape(np.arange(5 * 7), (5, 7)) + ind = np.arange(7) % 5 + x = einindex('j <- i j, [i] j', arr, [ind]) + for j, i in enumerate(ind): + assert arr[i, j] == x[j] + + y = einindex('j <- j i, [i] j', np.permute_dims(arr, (1, 0)), [ind]) + for j, i in enumerate(ind): + assert arr[i, j] == y[j] + + +def test_multidimensional_indexing(): + import numpy.array_api as np + + embedding_bhwc = ( + + arange_at_position(np, 4, 0, 2) * 1000 + + arange_at_position(np, 4, 1, 3) * 100 + + arange_at_position(np, 4, 2, 5) * 10 + + arange_at_position(np, 4, 3, 7) * 1 + ) + + hindices_bt = np.reshape(np.arange(6), (2, 3)) % 3 + windices_bt = np.reshape(np.arange(6), (2, 3)) % 5 + + # imagine that you have pairs of image <> sentence + # your goal is to get most suitable token from image for every token in sentence + # thus for every token in sentence you compute best k and v + + result = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, [hindices_bt, windices_bt]) + # example of using a single array for indexing multiple axes + hw_indices_bt = np.stack([hindices_bt, windices_bt]) + result2 = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, hw_indices_bt) + assert np.all(result == result2) + + # check vs manual element computation + result_manual = result * 0 + for b in range(2): + for t in range(3): + for c in range(7): + h = hindices_bt[b, t] + w = windices_bt[b, t] + result_manual[c, t, b] = embedding_bhwc[b, h, w, c] + + assert np.all(result == result_manual) + + +def test_reverse_indexing(): + import numpy.array_api as np + + C, T, B = 2, 3, 5 + # G = GPU, batch-like varaible + G = 4 + H = 7 + W = 9 + + arr_gtbc = ( + + arange_at_position(np, 4, 0, G) * 1000 + + arange_at_position(np, 4, 1, T) * 100 + + arange_at_position(np, 4, 2, B) * 10 + + arange_at_position(np, 4, 3, C) * 1 + ) + + t_indices_gbhw = np.reshape(np.arange(G * B * H * W), (G, B, H, W)) % T + + result = einindex('g b c h w <- g t b c, [t] g b h w', arr_gtbc, [t_indices_gbhw]) + + result_manual = result * 0 + for g in range(G): + for b in range(B): + for c in range(C): + for h in range(H): + for w in range(W): + t = t_indices_gbhw[g, b, h, w] + result_manual[g, b, c, h, w] = arr_gtbc[g, t, b, c] + + assert np.all(result == result_manual) + + diff --git a/python/jittor/einops/layers/__init__.py b/python/jittor/einops/layers/__init__.py new file mode 100644 index 00000000..7b7f43d9 --- /dev/null +++ b/python/jittor/einops/layers/__init__.py @@ -0,0 +1,79 @@ +__author__ = 'Alex Rogozhnikov' + +import functools + +from jittor.einops.einops import _apply_recipe + +from jittor.einops.einops import TransformRecipe, _prepare_transformation_recipe +from jittor.einops import EinopsError + + +class RearrangeMixin: + """ + Rearrange layer behaves identically to einops.rearrange operation. + + :param pattern: str, rearrangement pattern + :param axes_lengths: any additional specification of dimensions + + See einops.rearrange for source_examples. + """ + + def __init__(self, pattern, **axes_lengths): + super().__init__() + self.pattern = pattern + self.axes_lengths = axes_lengths + self._recipe = self.recipe() # checking parameters + + def __repr__(self): + params = repr(self.pattern) + for axis, length in self.axes_lengths.items(): + params += ', {}={}'.format(axis, length) + return '{}({})'.format(self.__class__.__name__, params) + + @functools.lru_cache(maxsize=1024) + def recipe(self) -> TransformRecipe: + try: + hashable_lengths = tuple(sorted(self.axes_lengths.items())) + return _prepare_transformation_recipe(self.pattern, operation='rearrange', axes_lengths=hashable_lengths) + except EinopsError as e: + raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e)) + + def _apply_recipe(self, x): + return _apply_recipe(self._recipe, x, reduction_type='rearrange') + + +class ReduceMixin: + """ + Reduce layer behaves identically to einops.reduce operation. + + :param pattern: str, rearrangement pattern + :param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive + :param axes_lengths: any additional specification of dimensions + + See einops.reduce for source_examples. + """ + + def __init__(self, pattern, reduction, **axes_lengths): + super().__init__() + self.pattern = pattern + self.reduction = reduction + self.axes_lengths = axes_lengths + self._recipe = self.recipe() # checking parameters + + def __repr__(self): + params = '{!r}, {!r}'.format(self.pattern, self.reduction) + for axis, length in self.axes_lengths.items(): + params += ', {}={}'.format(axis, length) + return '{}({})'.format(self.__class__.__name__, params) + + @functools.lru_cache(maxsize=1024) + def recipe(self) -> TransformRecipe: + try: + hashable_lengths = tuple(sorted(self.axes_lengths.items())) + return _prepare_transformation_recipe( + self.pattern, operation=self.reduction, axes_lengths=hashable_lengths) + except EinopsError as e: + raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e)) + + def _apply_recipe(self, x): + return _apply_recipe(self._recipe, x, reduction_type=self.reduction) diff --git a/python/jittor/einops/layers/_einmix.py b/python/jittor/einops/layers/_einmix.py new file mode 100644 index 00000000..7f5c5c68 --- /dev/null +++ b/python/jittor/einops/layers/_einmix.py @@ -0,0 +1,176 @@ +from typing import Optional, Dict + +from jittor.einops import EinopsError +from jittor.einops.parsing import ParsedExpression +import warnings +import string +from jittor.einops.einops import _product + + +def _report_axes(axes: set, report_message: str): + if len(axes) > 0: + raise EinopsError(report_message.format(axes)) + + +class _EinmixMixin: + def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths): + """ + EinMix - Einstein summation with automated tensor management and axis packing/unpacking. + + EinMix is an advanced tool, helpful tutorial: + https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb + + Imagine taking einsum with two arguments, one of each input, and one - tensor with weights + >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight) + + This layer manages weights for you, syntax highlights separate role of weight matrix + >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out') + But otherwise it is the same einsum under the hood. + + Simple linear layer with bias term (you have one like that in your framework) + >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20) + There is restriction to mix the last axis. Let's mix along height + >>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32) + Channel-wise multiplication (like one used in normalizations) + >>> EinMix('t b c -> t b c', weight_shape='c', c=128) + Separate dense layer within each head, no connection between different heads + >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...) + + ... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters. + + Use cases: + - when channel dimension is not last, use EinMix, not transposition + - patch/segment embeddings + - when need only within-group connections to reduce number of weights and computations + - perfect as a part of sequential models + - next-gen MLPs (follow tutorial to learn more) + + Uniform He initialization is applied to weight tensor and encounters for number of elements mixed. + + Parameters + :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output + :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer + :param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added. + :param axes_lengths: dimensions of weight tensor + """ + super().__init__() + self.pattern = pattern + self.weight_shape = weight_shape + self.bias_shape = bias_shape + self.axes_lengths = axes_lengths + self.initialize_einmix(pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths) + + def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths): + left_pattern, right_pattern = pattern.split('->') + left = ParsedExpression(left_pattern) + right = ParsedExpression(right_pattern) + weight = ParsedExpression(weight_shape) + _report_axes( + set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}), + 'Unrecognized identifiers on the right side of EinMix {}' + ) + + if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis: + raise EinopsError('Ellipsis is not supported in EinMix (right now)') + if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]): + raise EinopsError('Anonymous axes (numbers) are not allowed in EinMix') + if '(' in weight_shape or ')' in weight_shape: + raise EinopsError(f'Parenthesis is not allowed in weight shape: {weight_shape}') + + pre_reshape_pattern = None + pre_reshape_lengths = None + post_reshape_pattern = None + if any(len(group) != 1 for group in left.composition): + names = [] + for group in left.composition: + names += group + composition = ' '.join(names) + pre_reshape_pattern = f'{left_pattern}->{composition}' + pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names} + + if any(len(group) != 1 for group in right.composition): + names = [] + for group in right.composition: + names += group + composition = ' '.join(names) + post_reshape_pattern = f'{composition}->{right_pattern}' + + self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {}) + + for axis in weight.identifiers: + if axis not in axes_lengths: + raise EinopsError('Dimension {} of weight should be specified'.format(axis)) + _report_axes( + set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}), + 'Axes {} are not used in pattern', + ) + _report_axes( + set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}), + 'Weight axes {} are redundant' + ) + if len(weight.identifiers) == 0: + warnings.warn('EinMix: weight has no dimensions (means multiplication by a number)') + + _weight_shape = [axes_lengths[axis] for axis, in weight.composition] + # single output element is a combination of fan_in input elements + _fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers]) + if bias_shape is not None: + if not isinstance(bias_shape, str): + raise EinopsError('bias shape should be string specifying which axes bias depends on') + bias = ParsedExpression(bias_shape) + _report_axes( + set.difference(bias.identifiers, right.identifiers), + 'Bias axes {} not present in output' + ) + _report_axes( + set.difference(bias.identifiers, set(axes_lengths)), + 'Sizes not provided for bias axes {}', + ) + + _bias_shape = [] + for axes in right.composition: + for axis in axes: + if axis in bias.identifiers: + _bias_shape.append(axes_lengths[axis]) + else: + _bias_shape.append(1) + else: + _bias_shape = None + + weight_bound = (3 / _fan_in) ** 0.5 + bias_bound = (1 / _fan_in) ** 0.5 + self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound) + + # rewrite einsum expression with single-letter latin identifiers so that + # expression will be understood by any framework + mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers} + mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)} + + def write_flat(axes: list): + return ''.join(mapping2letters[axis] for axis in axes) + + self.einsum_pattern: str = '{},{}->{}'.format( + write_flat(left.flat_axes_order()), + write_flat(weight.flat_axes_order()), + write_flat(right.flat_axes_order()), + ) + + def _create_rearrange_layers(self, + pre_reshape_pattern: Optional[str], + pre_reshape_lengths: Optional[Dict], + post_reshape_pattern: Optional[str], + post_reshape_lengths: Optional[Dict]): + raise NotImplementedError('Should be defined in framework implementations') + + def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): + """ Shape and implementations """ + raise NotImplementedError('Should be defined in framework implementations') + + def __repr__(self): + params = repr(self.pattern) + params += f", '{self.weight_shape}'" + if self.bias_shape is not None: + params += f", '{self.bias_shape}'" + for axis, length in self.axes_lengths.items(): + params += ', {}={}'.format(axis, length) + return '{}({})'.format(self.__class__.__name__, params) diff --git a/python/jittor/einops/layers/jittor.py b/python/jittor/einops/layers/jittor.py new file mode 100644 index 00000000..e2696b87 --- /dev/null +++ b/python/jittor/einops/layers/jittor.py @@ -0,0 +1,55 @@ +from typing import Optional, Dict + +import jittor as jt +from jittor import nn +import numpy as np + +from jittor.einops.layers import RearrangeMixin, ReduceMixin +from jittor.einops.layers._einmix import _EinmixMixin + +__author__ = 'Ruiyang Liu' + + +class Rearrange(RearrangeMixin, jt.nn.Module): + def execute(self, input): + return self._apply_recipe(input) + + +class Reduce(ReduceMixin, jt.nn.Module): + def execute(self, input): + return self._apply_recipe(input) + + +class EinMix(_EinmixMixin, jt.nn.Module): + def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound): + self.weight = jt.zeros(weight_shape) + nn.init.uniform_(self.weight, low = -weight_bound, high = weight_bound) + if bias_shape is not None: + self.bias = jt.zeros(bias_shape) + nn.init.uniform_(self.bias, low = -bias_bound, high = bias_bound) + else: + self.bias = None + + def _create_rearrange_layers(self, + pre_reshape_pattern: Optional[str], + pre_reshape_lengths: Optional[Dict], + post_reshape_pattern: Optional[str], + post_reshape_lengths: Optional[Dict], + ): + self.pre_rearrange = None + if pre_reshape_pattern is not None: + self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths) + + self.post_rearrange = None + if post_reshape_pattern is not None: + self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths) + + def execute(self, input): + if self.pre_rearrange is not None: + input = self.pre_rearrange(input) + result = jt.linalg.einsum(self.einsum_pattern, input, self.weight) + if self.bias is not None: + result += self.bias + if self.post_rearrange is not None: + result = self.post_rearrange(result) + return result diff --git a/python/jittor/einops/parsing.py b/python/jittor/einops/parsing.py new file mode 100644 index 00000000..e298d6b3 --- /dev/null +++ b/python/jittor/einops/parsing.py @@ -0,0 +1,147 @@ +from jittor.einops import EinopsError +import keyword +import warnings +from typing import List, Optional, Set, Tuple + +_ellipsis: str = '…' # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated + + +class AnonymousAxis(object): + """Important thing: all instances of this class are not equal to each other """ + + def __init__(self, value: str): + self.value = int(value) + if self.value <= 1: + if self.value == 1: + raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue') + else: + raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value)) + + def __repr__(self): + return "{}-axis".format(str(self.value)) + + +class ParsedExpression: + """ + non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)') + and keeps some information important for downstream + """ + def __init__(self, expression, *, allow_underscore: bool = False, allow_duplicates: bool = False): + self.has_ellipsis: bool = False + self.has_ellipsis_parenthesized: Optional[bool] = None + self.identifiers: Set[str] = set() + # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition + self.has_non_unitary_anonymous_axes: bool = False + # composition keeps structure of composite axes, see how different corner cases are handled in tests + self.composition = [] + if '.' in expression: + if '...' not in expression: + raise EinopsError('Expression may contain dots only inside ellipsis (...)') + if str.count(expression, '...') != 1 or str.count(expression, '.') != 3: + raise EinopsError( + 'Expression may contain dots only inside ellipsis (...); only one ellipsis for tensor ') + expression = expression.replace('...', _ellipsis) + self.has_ellipsis = True + + bracket_group = None + + def add_axis_name(x): + if x is not None: + if x in self.identifiers: + if not (allow_underscore and x == "_") and not allow_duplicates: + raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) + if x == _ellipsis: + self.identifiers.add(_ellipsis) + if bracket_group is None: + self.composition.append(_ellipsis) + self.has_ellipsis_parenthesized = False + else: + bracket_group.append(_ellipsis) + self.has_ellipsis_parenthesized = True + else: + is_number = str.isdecimal(x) + if is_number and int(x) == 1: + # handling the case of anonymous axis of length 1 + if bracket_group is None: + self.composition.append([]) + else: + pass # no need to think about 1s inside parenthesis + return + is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore) + if not (is_number or is_axis_name): + raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason)) + if is_number: + x = AnonymousAxis(x) + self.identifiers.add(x) + if is_number: + self.has_non_unitary_anonymous_axes = True + if bracket_group is None: + self.composition.append([x]) + else: + bracket_group.append(x) + + current_identifier = None + for char in expression: + if char in '() ': + add_axis_name(current_identifier) + current_identifier = None + if char == '(': + if bracket_group is not None: + raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)") + bracket_group = [] + elif char == ')': + if bracket_group is None: + raise EinopsError('Brackets are not balanced') + self.composition.append(bracket_group) + bracket_group = None + elif str.isalnum(char) or char in ['_', _ellipsis]: + if current_identifier is None: + current_identifier = char + else: + current_identifier += char + else: + raise EinopsError("Unknown character '{}'".format(char)) + + if bracket_group is not None: + raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression)) + add_axis_name(current_identifier) + + def flat_axes_order(self) -> List: + result = [] + for composed_axis in self.composition: + assert isinstance(composed_axis, list), 'does not work with ellipsis' + for axis in composed_axis: + result.append(axis) + return result + + def has_composed_axes(self) -> bool: + # this will ignore 1 inside brackets + for axes in self.composition: + if isinstance(axes, list) and len(axes) > 1: + return True + return False + + @staticmethod + def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]: + if not str.isidentifier(name): + return False, 'not a valid python identifier' + elif name[0] == '_' or name[-1] == '_': + if name == '_' and allow_underscore: + return True, '' + return False, 'axis name should should not start or end with underscore' + else: + if keyword.iskeyword(name): + warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning) + if name in ['axis']: + warnings.warn("It is discouraged to use 'axis' as an axis name " + "and will raise an error in future", FutureWarning) + return True, '' + + @staticmethod + def check_axis_name(name: str) -> bool: + """ + Valid axes names are python identifiers except keywords, + and additionally should not start or end with underscore + """ + is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name) + return is_valid diff --git a/python/jittor/extern/acl/acl_compiler copy.py b/python/jittor/extern/acl/acl_compiler copy.py new file mode 100644 index 00000000..3e99eb58 --- /dev/null +++ b/python/jittor/extern/acl/acl_compiler copy.py @@ -0,0 +1,1155 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import pdb + +has_acl = 0 +cc_flags = "" +tikcc_path = env_or_try_find('tikcc_path', 'ccec') +dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL +compiler.has_acl = has_acl + +# export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/tools/aoe/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/opskernel:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/nnengine:/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub:/usr/local/Ascend/ascend-toolkit/latest/tools/tikicpulib/lib/Ascend910A:/usr/local/Ascend/ascend-toolkit/latest/toolkit/tools/simulator/Ascend910A/lib:/opt/AXESMI/lib64:/usr/local/Ascend/driver/lib64/driver/ +# export PYTHONPATH=/home/cjld/new_jittor/jittor/python +# export tikcc_path=g++ + +# conda activate cann +# source /usr/local/Ascend/ascend-toolkit/set_env.sh +# export PYTHONPATH=/home/cjld/new_jittor/jittor/python:/home/cjld/new_jittor/jittor/my/jtorch/python:$PYTHONPATH +# export TASK_QUEUE_ENABLE=0 +# python3 -m jittor.test.test_acl -k array +# jittor: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && PYTHONPATH=/home/cjld/new_jittor/jittor/python:/home/cjld/new_jittor/jittor/my/jtorch/python:$PYTHONPATH && cd /home/cjld/new_jittor/jittor/my/mm_benchmark +# python3 -m jittor.test.test_acl -k test_sum +# export ASCEND_SLOG_PRINT_TO_STDOUT=0 +# ASCEND_GLOBAL_LOG_LEVEL +# export DUMP_GE_GRAPH=1 +# export DUMP_GRAPH_LEVEL=1 + +# build pytorch-npu +# bash ./ci/build.sh +# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall +# pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark +# python3 ./mm_bench_pt_npu.py + + +def install(): + import jittor.compiler as compiler + global has_acl, cc_flags + acl_compiler_home = os.path.dirname(__file__) + cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc", + recursive=True)) + cc_files2 = [] + for name in cc_files: + if "acl_op_exec" in name: + compiler.extra_core_files.append(name) + else: + cc_files2.append(name) + cc_files = cc_files2 + ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME') + + #print(ascend_toolkit_home) + #print(acl_compiler_home) + cc_flags += f" -DHAS_CUDA -DIS_ACL \ + -I{ascend_toolkit_home}/include/ \ + -I{ascend_toolkit_home}/include/acl/ \ + -I{ascend_toolkit_home}/include/aclnn/ \ + -I{ascend_toolkit_home}/include/aclnnop/ \ + -I{acl_compiler_home} -lascendcl -lacl_op_compiler \ + -I{acl_compiler_home}/aclnn \ + -L{ascend_toolkit_home}/lib64/" + + cc_flags += " -llibascendcl " + cc_flags += " -llibnnopbase " + cc_flags += " -llibopapi " + + #pdb.set_trace() + ctypes.CDLL("libascendcl.so", dlopen_flags) + f''' + -ltikc_runtime + -I/usr/local/Ascend/driver/include/ \ + -L{ascend_toolkit_home}/compiler/lib64/ \ + -L{ascend_toolkit_home}/runtime/lib64/ \ + ''' + jittor_utils.LOG.i("ACL detected") + + global mod + mod = jittor_utils.compile_module( + ''' +#include "common.h" +namespace jittor { +// @pyjt(process) +string process_acl(const string& src, const string& name, const map& kargs); +// @pyjt(init_acl_ops) +void init_acl_ops(); +}''', compiler.cc_flags + " " + " ".join(cc_files) + cc_flags) + jittor_utils.process_jittor_source("acl", mod.process) + + has_acl = 1 + os.environ["use_mkl"] = "0" + compiler.setup_fake_cuda_lib = True + + +def install_extern(): + return False + + +def check(): + import jittor.compiler as compiler + global has_acl, cc_flags + if tikcc_path: + try: + install() + except Exception as e: + jittor_utils.LOG.w(f"load ACL failed, exception: {e}") + has_acl = 0 + compiler.has_acl = has_acl + compiler.tikcc_path = tikcc_path + if not has_acl: return False + compiler.cc_flags += cc_flags + compiler.nvcc_path = tikcc_path + compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "") + return True + + +def post_process(): + if has_acl: + from jittor import pool + pool.pool_use_code_op = False + import jittor as jt + jt.flags.use_cuda_host_allocator = 1 + jt.flags.use_parallel_op_compiler = 0 + jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type + mod.init_acl_ops() + + +def acl_cmd(name: str, inputs: list, output_dtypes: list, output_shapes: list, + attr: dict): + nchw_op = ['MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2'] + attr_op = [ + 'MaxPoolWithArgmaxV1', 'MaxPoolGradWithArgmaxV1', 'AvgPoolV2', + 'AdaptiveAvgPool2d', 'AdaptiveAvgPool2dGrad', 'ReverseV2' + ] + + input_code = '' + for i in range(len(inputs)): + if name in nchw_op: + input_code += f"op.add(in{i}, true, ACL_FORMAT_NCHW);\n" + else: + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(output_dtypes)): + if name in nchw_op: + output_code += f"op.add(out{i}, false, ACL_FORMAT_NCHW);\n" + else: + output_code += f"op.add(out{i}, false);\n" + + # add attr to op + attr_code = '' + if name in attr_op: + for k, v in attr.items(): + if isinstance(v, bool): + if v == True: + attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" + else: + attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" + elif isinstance(v, str): + attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" + elif k == 'divisor_override_value': + attr_code += f"op.set_attr(\"{k}\", int64_t({v}), 0);\n" + else: + v = str(v).replace('[', '{').replace(']', '}') + attr_code += f"op.set_attr(\"{k}\", vector{v});\n" + else: + for k, v in attr.items(): + if isinstance(v, bool): + if v == True: + attr_code += f"op.set_attr(\"{k}\", 1, 1);\n" + else: + attr_code += f"op.set_attr(\"{k}\", 1, 0);\n" + elif isinstance(v, str): + attr_code += f"op.set_attr(\"{k}\", \"{v}\");\n" + else: + attr_code += f"op.set_attr(\"{k}\", int({v}));\n" + + #print("input_code",input_code) + #print("attr_code",attr_code) + # read the tmp_file.cpp to the cuda_header + with open( + "/home/ma-user/work/zy/jittor/python/jittor/extern/acl/tmp_file.cpp", + "r") as f: + cuda_header = f.read() + import jittor as jt + return jt.code(output_shapes, + output_dtypes, + inputs, + cuda_header=cuda_header, + cuda_src=f""" + // aclop + AclOpRunner op("{name}"); + {input_code} + {output_code} + {attr_code} + op.run();""") + + +def change_function(): + import jittor as jt + from jittor import Function + + class IndexACL(Function): + + def __init__(self): + super(IndexACL, self).__init__() + + def execute(self, inshape: list, dim, dtype="int32"): + # zeros a tensor, shape is inshape, dtype is dtype + dim_input = dim + if dim == None: + dim = [i for i in range(len(inshape))] + elif type(dim) == int: + dim = [dim] + results = [] + for d in dim: + max_len = inshape[d] + tmp = jt.zeros(max_len, dtype=dtype) + result = acl_cmd( + "Range", [jt.Var(0), jt.Var(max_len), + jt.Var(1)], + output_dtypes=[tmp.dtype], + output_shapes=[tmp.shape], + attr={})[0] + broadcast_dim = [] + for i in range(len(inshape)): + if i != d: + broadcast_dim.append(i) + result = jt.broadcast(result, + shape=inshape, + dims=broadcast_dim) + results.append(result) + if len(results) != 1 or dim_input == None: + return tuple(results) + else: + return results[0] + + def grad(self, grad_output): + return grad_output + + class PoolACL(Function): + + def get_paddings(self): + pad_top = self.padding[0] + pad_left = self.padding[1] + H = self.input.shape[-2] + W = self.input.shape[-1] + + totalH = H + 2 * self.padding[0] - self.kernel_size[0] + totalW = W + 2 * self.padding[1] - self.kernel_size[1] + + kH = (totalH + self.stride[0] - + 1) // self.stride[0] + 1 if self.attr[ + 'ceil_mode'] else totalH // self.stride[0] + 1 + kW = (totalW + self.stride[1] - + 1) // self.stride[1] + 1 if self.attr[ + 'ceil_mode'] else totalW // self.stride[1] + 1 + + if self.attr['ceil_mode']: + if (kH - 1) * self.stride[0] >= H + self.padding[0]: + kH -= 1 + need_pad_h = (kH - + 1) * self.stride[0] + self.kernel_size[0] - H + pad_top = need_pad_h - self.padding[0] + if (kW - 1) * self.stride[1] >= W + self.padding[1]: + kW -= 1 + need_pad_w = (kW - + 1) * self.stride[1] + self.kernel_size[1] - W + pad_left = need_pad_w - self.padding[1] + + pads = [self.padding[0], pad_top, self.padding[1], pad_left] + return pads + + def __init__(self, + kernel_size, + stride=None, + padding=0, + dilation=None, + return_indices=None, + ceil_mode=False, + count_include_pad=True, + op='maximum'): + super(PoolACL, self).__init__() + # set attr + self.kernel_size = kernel_size if isinstance( + kernel_size, tuple) else (kernel_size, kernel_size) + stride = stride if stride else kernel_size + self.stride = stride if isinstance(stride, tuple) else (stride, + stride) + self.padding = padding if isinstance(padding, tuple) else (padding, + padding) + dilation = dilation if dilation else 1 + self.dilation = dilation if isinstance( + dilation, tuple) else (dilation, dilation) + attr = {} + + self.return_indices = return_indices + self.uint16 = jt.Var(1).int32().dtype + self.op = op + + if op == 'mean': + attr['exclusive'] = not count_include_pad + attr['global_pooling'] = False + attr['divisor_override_value'] = 0 + attr['ksize'] = [ + 1, 1, self.kernel_size[0], self.kernel_size[1] + ] + attr['strides'] = [1, 1, self.stride[0], self.stride[1]] + attr['ceil_mode'] = ceil_mode + attr['padding_mode'] = 'CALCULATED' + attr['data_format'] = 'NCHW' + elif op == 'maximum': + attr['ksize'] = [ + 1, self.kernel_size[0], self.kernel_size[1], 1 + ] + attr['strides'] = [1, self.stride[0], self.stride[1], 1] + attr['pads'] = [1, self.padding[0], self.padding[1], 1] + attr['dilation'] = [1, self.dilation[0], self.dilation[1], 1] + # attr['ceil_mode'] = ceil_mode + + self.attr = attr + + def execute(self, input): + + # create input + input_shape = input.shape + input_dtype = input.dtype + + self.input = input + # create output + output_shape = [ + input_shape[0], input_shape[1], + (input_shape[2] + 2 * self.padding[0] - self.dilation[0] * + (self.kernel_size[0] - 1) - 1) // self.stride[0] + 1, + (input_shape[3] + 2 * self.padding[1] - self.dilation[1] * + (self.kernel_size[1] - 1) - 1) // self.stride[1] + 1 + ] + output_dtype = input_dtype + + if self.op == 'mean': + self.attr['pads'] = self.get_paddings() + result = acl_cmd("AvgPoolV2", [input], + output_dtypes=[output_dtype], + output_shapes=[output_shape], + attr=self.attr) + elif self.op == 'maximum': + result = acl_cmd("MaxPoolWithArgmaxV1", [input], + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) + else: + raise ValueError('no this type pool') + + if self.op == 'maximum': + self.index = result[1] + + if self.return_indices: + return result[0], result[1] + else: + return result[0] + + def grad(self, grad_output): + if self.op == 'maximum': + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + elif self.op == 'mean': + grad_input = acl_cmd("AvgPoolV2", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + else: + grad_input = None + return grad_input + + class BmmACL(Function): + + def __init__(self, adj_x1=False, adj_x2=False): + super(BmmACL, self).__init__() + self.adj_x1 = adj_x1 + self.adj_x2 = adj_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + return result + + def grad(self, grad_output): + x1, x2 = self.input + grad_x1 = acl_cmd( + "BatchMatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "BatchMatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + return grad_x1, grad_x2 + + class MatmulACL(Function): + + def __init__(self, adj_x1=False, adj_x2=False): + super(MatmulACL, self).__init__() + self.adj_x1 = adj_x1 + self.adj_x2 = adj_x2 + + def execute(self, x1, x2): + self.input = [x1, x2] + if len(x1.shape) > 2 or len(x2.shape) > 2: + result = acl_cmd("BatchMatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + else: + result = acl_cmd("MatMul", [x1, x2], + output_dtypes=[x1.dtype], + output_shapes=[x1.shape[:-1] + x2.shape[-1:]], + attr={})[0] + return result + + def grad(self, grad_output): + x1, x2 = self.input + if len(x1.shape) > 2 or len(x2.shape) > 2: + grad_x1 = acl_cmd( + "BatchMatMul", + [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "BatchMatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + else: + grad_x1 = acl_cmd( + "MatMul", [grad_output, x2.transpose(-2, -1)], + output_dtypes=[x1.dtype], + output_shapes=[grad_output.shape[:-1] + x1.shape[-1:]], + attr={})[0] + grad_x2 = acl_cmd( + "MatMul", [x1.transpose(-2, -1), grad_output], + output_dtypes=[x2.dtype], + output_shapes=[x2.shape[:-1] + grad_output.shape[-1:]], + attr={})[0] + return grad_x1, grad_x2 + + class GetItem(Function): + + def __init__(self): + super(GetItem, self).__init__() + self.type_ = 'index' + + def stride(self, x, dim): + stride = 1 + for i in range(dim + 1, len(x.shape)): + stride *= x.shape[i] + return stride + + def execute(self, x, slices, return_x=None): + if isinstance(slices, jt.Var) or isinstance(slices, tuple): + if isinstance(slices, jt.Var): + slices = (slices, ) + if isinstance(slices[0], jt.Var): + slices_len = len(slices) + masks = jt.ones(slices_len, dtype=jt.int64) + output = slices[0].shape + output += x.shape[slices_len:] + input_ = [x, masks, jt.Var(list(output)).int64()] + for i in range(slices_len): + input_.append(slices[i].int32()) + result = acl_cmd("Index", + input_, + output_dtypes=[x.dtype], + output_shapes=[output], + attr={})[0] + self.shape = x.shape + self.sizes = list(output) + self.type_ = 'index' + self.slices = slices + # self.strides + return result + + # use AsStrided operator to implement the getitem function + # get the shape and stride of the input tensor + x_dim = len(x.shape) + # int type + if not isinstance(slices, tuple): + slices = (slices, ) + + if len(slices) < x_dim: + slices += (slice(None, None, None), ) * (x_dim - len(slices)) + + self.inputs = [x, slices] + + sizes = [] + strides = [] + offset = 0 + + for dim, s in enumerate(slices): + if isinstance(s, int): + if s < 0: # Handle negative indices. + s += x.shape[dim] + offset += s * self.stride(x, dim) + elif isinstance(s, slice): + # Unpack the slice + start, stop, step = s.indices(x.size(dim)) + size = (stop - start - 1) // step + 1 + stride = self.stride(x, dim) * step + offset += start * self.stride(x, dim) + sizes.append(size) + strides.append(stride) + else: + raise ValueError("Invalid slice type") + + if not sizes: + sizes = [1] + strides = [0] + # AsStrided same with as_strided of pytorch + self.sizes = sizes + self.strides = strides + self.offset = offset + self.shape = x.shape + self.type_ = 'as_strided' + result = acl_cmd( + "AsStrided", + [x, jt.Var(sizes), + jt.Var(strides), + jt.Var(offset)], + output_dtypes=[x.dtype], + output_shapes=[jt.empty(sizes).shape], + attr={})[0] + return result + + def grad(self, grad_output): + if self.type_ == 'as_strided': + result = jt.zeros(self.shape, dtype=grad_output.dtype) + sizes = list(grad_output.shape) + strides = [ + self.stride(grad_output, dim) + for dim in range(len(grad_output.shape)) + ] + result = acl_cmd("ViewCopy", [ + result, + jt.Var(self.sizes), + jt.Var(self.strides), + jt.Var(self.offset), grad_output, + jt.Var(sizes), + jt.Var(strides), + jt.Var(0) + ], + output_dtypes=[result.dtype], + output_shapes=[result.shape], + attr={})[0] + elif self.type_ == 'index': + #TODO: use IndexPutV2 to implement the grad function + assert len(self.slices) == 1 + index = self.slices[0] + input = jt.zeros(self.shape, dtype=grad_output.dtype) + input_flatten = input.reshape(input.shape[0], -1) + index_flatten = index.reshape(-1).unsqueeze(-1).repeat( + 1, input_flatten.shape[1]) + grad_output_flatten = grad_output.reshape(index.numel(), -1) + result = acl_cmd( + "ScatterElements", + [input_flatten, index_flatten, grad_output_flatten], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={ + 'axis': 0, + 'reduction': 'add' + })[0] + result = result.reshape(self.shape) + # result = jt.zeros(self.shape, dtype=grad_output.dtype) + # # masks = jt.ones(len(self.slices), dtype=jt.int64) + # masks = jt.array([1,1], dtype=jt.int64) + # expand_masks = jt.array([1,1], dtype=jt.int64) + # inputs_ = [result,grad_output,masks,expand_masks] + # slices_len = len(self.slices) + # for i in range(slices_len): + # inputs_.append(self.slices[i].int64()) + # # breakpoint() + # jt.sync_all(True) + # print(inputs_) + # result_ = acl_cmd("IndexPutV2", inputs_, + # output_dtypes=[result.dtype], + # output_shapes=[result.shape], + # attr={"accumulate":True})[0] + # result = result_ + else: + raise ValueError("Invalid slice type") + result.sync() + return result, None + + class ConcatACL(Function): + + def __init__(self): + super(ConcatACL, self).__init__() + + def execute(self, input_tensors, dim=0): + self.input = input_tensors + for i in range(len(input_tensors)): + if input_tensors[i].dtype != input_tensors[0].dtype: + raise ValueError( + "All input tensors must have the same dtype") + if input_tensors[i].shape[:dim] != input_tensors[ + 0].shape[:dim] or input_tensors[i].shape[ + dim + 1:] != input_tensors[0].shape[dim + 1:]: + raise ValueError( + "All input tensors must have the same shape") + result = acl_cmd( + "ConcatD", + input_tensors, + output_dtypes=[input_tensors[0].dtype], + output_shapes=[ + jt.empty(self.calculate_output_shape(input_tensors, + dim)).shape + ], + attr={ + "N": len(input_tensors), + "concat_dim": dim + })[0] + return result + + def grad(self, grad_output): + grad_inputs = self.split_grad(grad_output, self.input, self.axis) + return grad_inputs + + def calculate_output_shape(self, input_tensors, axis): + shape = list(input_tensors[0].shape) + for tensor in input_tensors[1:]: + shape[axis] += tensor.shape[axis] + return tuple(shape) + + def split_grad(self, grad_output, input_tensors, axis): + offset = 0 + grad_inputs = [] + for tensor in input_tensors: + grad_input = acl_cmd("Slice", [ + grad_output, [0] * axis + [offset] + [0] * + (len(tensor.shape) - axis - 1), tensor.shape + ]) + grad_inputs.append(grad_input) + offset += tensor.shape[axis] + return grad_inputs + + class SetItemACL(Function): + + def __init__(self): + super(SetItemACL, self).__init__() + + def stride(self, x, dim): + # 计算给定维度的步长 + stride = 1 + for i in range(dim + 1, len(x.shape)): + stride *= x.shape[i] + return stride + + def execute(self, x, slices, value, reduce='void'): + self.is_tensor = type(value) == jt.Var + if type(value) != jt.Var: + value = jt.array(value) + x_dim = len(x.shape) + + # 确保slices是一个元组 + if not isinstance(slices, tuple): + slices = (slices, ) + + # 补齐slices使其长度等于x的维度 + if len(slices) < x_dim: + slices += (slice(None, None, None), ) * (x_dim - len(slices)) + + self.inputs = [x, slices, value] + + target_sizes = [] + target_strides = [] + offset = 0 + + for dim, s in enumerate(slices): + if isinstance(s, int): + if s < 0: + s += x.shape[dim] + s = slice(s, s + 1, None) + if isinstance(s, slice): + # 解包切片 + start, stop, step = s.indices(x.shape[dim]) + size = (stop - start - 1) // step + 1 + stride = self.stride(x, dim) * step + offset += start * self.stride(x, dim) + target_sizes.append(size) + target_strides.append(stride) + else: + print("slices: ", s, type(s)) + raise ValueError("Invalid slice type") + + # 计算value的size、stride和offset + value_sizes = list(value.shape) + value_strides = [ + self.stride(value, dim) for dim in range(len(value.shape)) + ] + + self.target_sizes = target_sizes + self.target_strides = target_strides + self.offset = offset + self.value_sizes = value_sizes + self.value_strides = value_strides + + #import pdb; pdb.set_trace() + result = acl_cmd("ViewCopy", [ + x, + jt.Var(target_sizes), + jt.Var(target_strides), + jt.Var(offset), value, + jt.Var(value_sizes), + jt.Var(value_strides), + jt.Var(0) + ], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr={})[0] + result.sync() + return result + + def grad(self, grad_output): + result = acl_cmd("AsStrided", [ + grad_output, + jt.Var(self.target_sizes), + jt.Var(self.target_strides), + jt.Var(self.offset) + ], + output_dtypes=[grad_output.dtype], + output_shapes=[jt.empty(self.target_sizes).shape], + attr={})[0] + # copy grad_output to new_grad_output + new_grad_output = acl_cmd("Copy", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={"N": 1})[0] + new_grad_output = acl_cmd("ViewCopy", [ + new_grad_output, + jt.Var(self.target_sizes), + jt.Var(self.target_strides), + jt.Var(self.offset), + jt.zeros(self.value_sizes, dtype=grad_output.dtype), + jt.Var(self.value_sizes), + jt.Var(self.value_strides), + jt.Var(0) + ], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + new_grad_output.sync() + return new_grad_output, None, result if self.is_tensor else None + + class TriuACL(Function): + + def __init__(self): + super(TriuACL, self).__init__() + + def execute(self, input, k): + self.input = input + result = acl_cmd("Triu", [input], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={'diagonal': k})[0] + return result + + def grad(self, grad_output): + return grad_output + + class TransposeACL(Function): + + def __init__(self): + super(TransposeACL, self).__init__() + + def execute(self, input, perm): + self.input = input + + output_shape = input.shape[perm[0]:perm[0] + 1] + + for i in range(1, len(perm)): + output_shape += input.shape[perm[i]:perm[i] + 1] + result = acl_cmd("Transpose", [input, jt.Var(perm)], + output_dtypes=[input.dtype], + output_shapes=[output_shape], + attr={})[0] + return result + + def grad(self, grad_output): + return grad_output + + class AdaptiveMaxPool2dACL(Function): + + def __init__( + self, + output_size, + return_indices=False, + ): + super(AdaptiveMaxPool2dACL, self).__init__() + self.output_size = (output_size, output_size) if isinstance( + output_size, int) else output_size + + self.return_indices = return_indices + self.uint16 = jt.Var(1).int32().dtype + + attr = {} + attr['ceil_mode'] = False + attr['dilations'] = [1, 1, 1, 1] + self.attr = attr + + def execute(self, input): + input_shape = input.shape + input_dtype = input.dtype + + output_shape = [ + input_shape[0], input_shape[1], self.output_size[0], + self.output_size[1] + ] + output_dtype = input_dtype + self.input = input + + stride_h = input_shape[2] // output_shape[2] + stride_w = input_shape[3] // output_shape[3] + kernel_size_h = input_shape[2] - (output_shape[2] - 1) * stride_h + kernel_size_w = input_shape[3] - (output_shape[3] - 1) * stride_w + + stride = [0, 0] + kernel_size = [0, 0] + padding = [0, 0] + + stride[0] = stride_h + stride[1] = stride_w + kernel_size[0] = kernel_size_h + kernel_size[1] = kernel_size_w + padding[0] = padding[1] = 0 + kernel_sizes = [1, kernel_size[0], kernel_size[1], 1] + strides_size = [1, stride[0], stride[1], 1] + paddings = [1, padding[0], padding[1], 1] + + self.attr['ksize'] = kernel_sizes + self.attr['strides'] = strides_size + self.attr['pads'] = paddings + + result = acl_cmd("MaxPoolWithArgmaxV1", [input], + output_dtypes=[output_dtype, self.uint16], + output_shapes=[output_shape, output_shape], + attr=self.attr) + + self.index = result[1] + + if self.return_indices: + return result[0], result[1] + else: + return result[0] + + def grad(self, grad_output): + grad_input = acl_cmd("MaxPoolGradWithArgmaxV1", + [self.input, grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.input.shape], + attr=self.attr)[0] + return grad_input + + class AdaptiveAvgPool2dACL(Function): + + def __init__(self, output_size): + super(AdaptiveAvgPool2dACL, self).__init__() + self.output_size = (output_size, output_size) if isinstance( + output_size, int) else output_size + + attr = {} + if isinstance(output_size, tuple): + output_size = [output_size[0], output_size[1]] + attr['output_size'] = output_size + self.attr = attr + + def execute(self, input): + input_shape = input.shape + input_dtype = input.dtype + + self.original_shape = input_shape + + output_shape = [ + input_shape[0], input_shape[1], self.attr['output_size'][0], + self.attr['output_size'][1] + ] + output_dtype = input_dtype + self.input = input + + result = acl_cmd("AdaptiveAvgPool2d", [input], + output_dtypes=[output_dtype], + output_shapes=[output_shape], + attr=self.attr) + + return result[0] + + def grad(self, grad_output): + attr = {} + attr['orig_input_shape'] = list(self.original_shape) + grad_input = acl_cmd("AdaptiveAvgPool2dGrad", [grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[self.original_shape], + attr=attr)[0] + return grad_input + + class CumsumACL(Function): + + def __init__(self): + super(CumsumACL, self).__init__() + + def execute(self, input, dim=-1): + self.input = input + self.dim = dim + result = acl_cmd("Cumsum", [input, jt.Var(dim)], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + flipped_grad_output = acl_cmd( + "ReverseV2", [grad_output, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + cumulative_grad = acl_cmd( + "Cumsum", + [flipped_grad_output, jt.Var(self.dim)], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + grad_input = acl_cmd( + "ReverseV2", + [cumulative_grad, jt.Var([self.dim])], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + return grad_input + + class GatherACL(Function): + + def __init__(self): + super(GatherACL, self).__init__() + + def execute(self, input, dim, index): + self.input = input + self.dim = dim + self.index = index + + result = acl_cmd("GatherElements", [input, index], + output_dtypes=[input.dtype], + output_shapes=[index.shape], + attr={'dim': dim})[0] + return result + + def grad(self, grad_output): + tmp = jt.zeros(self.index.shape, dtype=grad_output.dtype) + grad_input = acl_cmd("ScatterElements", + [tmp, self.index, grad_output], + output_dtypes=[grad_output.dtype], + output_shapes=[tmp.shape], + attr={ + 'axis': self.dim, + 'reduction': "add" + })[0] + return grad_input + + class ScatterACL(Function): + + def __init__(self): + super(ScatterACL, self).__init__() + + def execute(self, input, dim, index, src, reduce='void'): + self.input = input + self.dim = dim + self.index = index + self.reduce = reduce + result = acl_cmd("ScatterElements", [input, self.index, src], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={ + 'axis': self.dim, + 'reduction': reduce + })[0] + return result + + def grad(self, grad_output): + grad_input = acl_cmd("GatherElements", [grad_output, self.index], + output_dtypes=[grad_output.dtype], + output_shapes=[self.index.shape], + attr={'dim': self.dim})[0] + return grad_output, None, None, grad_input + + class WhereACL(Function): + + def __init__(self): + super(WhereACL, self).__init__() + + def execute(self, condition, x, y): + self.condition = condition + + if x.dtype != y.dtype: + if x.dtype == jt.float32: + y = y.float32() + elif y.dtype == jt.float32: + x = x.float32() + else: + x = x.to(y.dtype) + + self.x = x + self.y = y + + result = acl_cmd("Select", [condition, x, y], + output_dtypes=[x.dtype], + output_shapes=[x.shape], + attr={})[0] + return result + + def grad(self, grad_output): + tmp = jt.zeros(grad_output.shape, dtype=grad_output.dtype) + grad_x = acl_cmd("Select", [self.condition, grad_output, tmp], + output_dtypes=[self.x.dtype], + output_shapes=[self.x.shape], + attr={})[0] + + grad_y = acl_cmd("Select", [self.condition, tmp, grad_output], + output_dtypes=[self.y.dtype], + output_shapes=[self.y.shape], + attr={})[0] + return grad_output, grad_x, grad_y + + class FlipACL(Function): + + def __init__(self): + super(FlipACL, self).__init__() + + def execute(self, input, dim): + self.input = input + #if isinstance(dim_vector, tuple): + dim_vector = jt.Var(list(dim)) + #print(dim_vector.dtype) + self.dim_vector = dim_vector + #print(input, dim_vector) + result = acl_cmd("ReverseV2", [input, dim_vector], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + #print(grad_output) + grad_input = acl_cmd("ReverseV2", [grad_output, self.dim_vector], + output_dtypes=[grad_output.dtype], + output_shapes=[grad_output.shape], + attr={})[0] + return grad_input + + class FloorIntACL(Function): + + def __init__(self): + super(FloorIntACL, self).__init__() + + def execute(self, input): + self.input = input + self.shape = input.shape + result = acl_cmd("Floor", [input], + output_dtypes=[jt.int], + output_shapes=[input.shape], + attr={})[0] + return result + + def grad(self, grad_output): + return jt.zeros(self.shape, dtype=grad_output.dtype) + + def warp(origin_func, new_func): + + def warpper(*args, **kwargs): + if origin_func == jt.index: + if len(args) == 2 and args[1] == None: + args = tuple(list(args[0:1])) + if jt.flags.use_acl: + if isinstance(new_func, IndexACL): + if len(args) == 1: + args = (args[0], None) + if isinstance(new_func, CumsumACL): + args = (args[0], kwargs.get('dim', -1)) + kwargs = {} + if isinstance(new_func, + ScatterACL) and kwargs.get('reduce') is not None: + args = (args[0], args[1], args[2], args[3], + kwargs.get('reduce', 'void')) + kwargs = {} + + return new_func(*args, **kwargs) + return origin_func(*args, **kwargs) + + return warpper + + # jt.index = warp(jt.index, IndexACL()) + # jt.Var.index = lambda x, dim=None: warp(jt.index, IndexACL())(x.shape, dim) + # jt.nn.Pool = warp(jt.nn.Pool, PoolACL) + # jt.nn.AdaptiveMaxPool2d = warp(jt.nn.AdaptiveMaxPool2d, + # AdaptiveMaxPool2dACL) + # jt.nn.AdaptiveAvgPool2d = warp(jt.nn.AdaptiveAvgPool2d, + # AdaptiveAvgPool2dACL) + + jt.triu = warp(jt.triu, TriuACL()) + jt.triu_ = warp(jt.triu, TriuACL()) + jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x) + jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x) + + # jt.getitem = warp(jt.getitem, GetItem()) + # jt.Var.getitem = lambda x, slices, return_x=None: warp( + # jt.getitem, GetItem())(x, slices) + + # jt.setitem = warp(jt.setitem, SetItemACL()) + # jt.Var.setitem = lambda x, slices, value, reduce='void': warp( + # jt.setitem, SetItemACL())(x, slices, value, reduce) + + # jt.misc.flip = warp(jt.misc.flip, FlipACL()) + # jt.Var.flip = lambda x, dim_vector: warp(jt.misc.flip, FlipACL())( + # x, dim_vector) + # jt.cumsum = warp(jt.cumsum, CumsumACL()) + # jt.gather = warp(jt.gather, GatherACL()) + # jt.Var.gather = lambda x, dim, index: warp(jt.gather, GatherACL())(x, dim, + # index) + # jt.scatter = warp(jt.scatter, ScatterACL()) + # jt.Var.scatter = lambda x, dim, index, src, reduce="void": warp( + # jt.scatter, ScatterACL())(x, dim, index, src, reduce) + # jt.where = warp(jt.where, WhereACL()) + # jt.floor_int = warp(jt.floor_int, FloorIntACL()) + # jt.Var.floor_int = lambda x: warp(jt.floor_int, FloorIntACL())(x) + + # jt.nn.bmm = warp(jt.nn.bmm, BmmACL()) + # jt.bmm = warp(jt.bmm, BmmACL()) + # jt.nn.matmul = warp(jt.matmul, MatmulACL()) + # jt.matmul = warp(jt.matmul, MatmulACL()) + # jt.transpose = warp(jt.transpose, TransposeACL()) + # jt.Var.transpose = lambda x, perm: warp(jt.transpose, TransposeACL())(x, perm) + # jt.concat = warp(jt.concat, ConcatACL()) diff --git a/python/jittor/extern/acl/acl_compiler.py b/python/jittor/extern/acl/acl_compiler.py new file mode 100644 index 00000000..5aa2371f --- /dev/null +++ b/python/jittor/extern/acl/acl_compiler.py @@ -0,0 +1,416 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler +import jittor as jt +import math + +from collections.abc import Sequence, Iterable + + +def _ntuple(n): + + def parse(x): + if isinstance(x, Iterable): + return x + return tuple([x] * n) + + return parse + + +_pair = _ntuple(2) + +has_acl = 0 +cc_flags = "" +tikcc_path = env_or_try_find('tikcc_path', 'ccec') +dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL +compiler.has_acl = has_acl + +# export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/tools/aoe/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/opskernel:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/plugin/nnengine:/usr/local/Ascend/ascend-toolkit/latest/runtime/lib64:/usr/local/Ascend/ascend-toolkit/latest/compiler/lib64/stub:/usr/local/Ascend/ascend-toolkit/latest/tools/tikicpulib/lib/Ascend910A:/usr/local/Ascend/ascend-toolkit/latest/toolkit/tools/simulator/Ascend910A/lib:/opt/AXESMI/lib64:/usr/local/Ascend/driver/lib64/driver/ +# export PYTHONPATH=/home/cjld/new_jittor/jittor/python +# export tikcc_path=g++ + +# conda activate cann +# source /usr/local/Ascend/ascend-toolkit/set_env.sh +# export PYTHONPATH=/home/cjld/new_jittor/jittor/python:/home/cjld/new_jittor/jittor/my/jtorch/python:$PYTHONPATH +# export TASK_QUEUE_ENABLE=0 +# python3 -m jittor.test.test_acl -k array +# jittor: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && PYTHONPATH=/home/cjld/new_jittor/jittor/python:/home/cjld/new_jittor/jittor/my/jtorch/python:$PYTHONPATH && cd /home/cjld/new_jittor/jittor/my/mm_benchmark +# python3 -m jittor.test.test_acl -k test_sum +# export ASCEND_SLOG_PRINT_TO_STDOUT=0 +# ASCEND_GLOBAL_LOG_LEVEL +# export DUMP_GE_GRAPH=1 +# export DUMP_GRAPH_LEVEL=1 + +# build pytorch-npu +# bash ./ci/build.sh +# python3 -m pip install ./dist/torch_npu-1.11.0.post1-cp37-cp37m-linux_x86_64.whl --force-reinstall +# pytorch: conda activate cann && source /usr/local/Ascend/ascend-toolkit/set_env.sh && export TASK_QUEUE_ENABLE=0 && cd /home/cjld/new_jittor/jittor/my/mm_benchmark +# python3 ./mm_bench_pt_npu.py + + +def install(): + import jittor.compiler as compiler + global has_acl, cc_flags + acl_compiler_home = os.path.dirname(__file__) + cc_files = sorted(glob.glob(acl_compiler_home + "/**/*.cc", + recursive=True)) + cc_files2 = [] + for name in cc_files: + if "acl_op_exec" in name: + compiler.extra_core_files.append(name) + else: + cc_files2.append(name) + cc_files = cc_files2 + ascend_toolkit_home = os.getenv('ASCEND_TOOLKIT_HOME') + + #print(ascend_toolkit_home) + #print(acl_compiler_home) + cc_flags += f" -DHAS_CUDA -DIS_ACL \ + -I{ascend_toolkit_home}/include/ \ + -I{ascend_toolkit_home}/include/acl/ \ + -I{ascend_toolkit_home}/include/aclnn/ \ + -I{ascend_toolkit_home}/include/aclnnop/ \ + -I{acl_compiler_home} -lascendcl -lacl_op_compiler \ + -I{acl_compiler_home}/aclnn \ + -L{ascend_toolkit_home}/lib64/" + + cc_flags += " -llibascendcl " + cc_flags += " -llibnnopbase " + cc_flags += " -llibopapi " + + #pdb.set_trace() + ctypes.CDLL("libascendcl.so", dlopen_flags) + f''' + -ltikc_runtime + -I/usr/local/Ascend/driver/include/ \ + -L{ascend_toolkit_home}/compiler/lib64/ \ + -L{ascend_toolkit_home}/runtime/lib64/ \ + ''' + jittor_utils.LOG.i("ACL detected") + + global mod + mod = jittor_utils.compile_module( + ''' +#include "common.h" +namespace jittor { +// @pyjt(process) +string process_acl(const string& src, const string& name, const map& kargs); +// @pyjt(init_acl_ops) +void init_acl_ops(); +}''', compiler.cc_flags + " " + " ".join(cc_files) + cc_flags) + jittor_utils.process_jittor_source("acl", mod.process) + + has_acl = 1 + os.environ["use_mkl"] = "0" + compiler.setup_fake_cuda_lib = True + + +def install_extern(): + return False + + +def check(): + import jittor.compiler as compiler + global has_acl, cc_flags + if tikcc_path: + try: + install() + except Exception as e: + jittor_utils.LOG.w(f"load ACL failed, exception: {e}") + has_acl = 0 + compiler.has_acl = has_acl + compiler.tikcc_path = tikcc_path + if not has_acl: return False + compiler.cc_flags += cc_flags + compiler.nvcc_path = tikcc_path + compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "") + return True + + +def post_process(): + if has_acl: + from jittor import pool + pool.pool_use_code_op = False + import jittor as jt + jt.flags.use_cuda_host_allocator = 1 + jt.flags.use_parallel_op_compiler = 0 + jt.flags.amp_reg |= 32 + 4 # 32 keep float16, 4 keep reduce type + mod.init_acl_ops() + + +def acl_cmd(name: str, + inputs: list, + output_dtypes: list, + output_shapes: list, + attr_code: str = ""): + input_code = '' + for i in range(len(inputs)): + input_code += f"op.add(in{i}, true);\n" + + output_code = '' + for i in range(len(output_dtypes)): + output_code += f"op.add(out{i}, false);\n" + + # read the tmp_file.cpp to the cuda_header + with open( + "/home/ma-user/work/zy/jittor/python/jittor/extern/acl/tmp_file.cpp", + "r") as f: + cuda_header = f.read() + import jittor as jt + return jt.code(output_shapes, + output_dtypes, + inputs, + cuda_header=cuda_header, + cuda_src=f""" + // aclop + AclOpRunner op("{name}"); + {input_code} + {output_code} + {attr_code} + op.run();""") + + +def change_function(): + import jittor as jt + from jittor import Function + + class TriuACL(Function): + + def __init__(self): + super(TriuACL, self).__init__() + + def execute(self, input, k): + self.input = input + + attr_code = f""" + op.jt_name = "triu"; + TriuAttr *attr = new TriuAttr(); + attr->diagonal = {k}; + op.op_attr.reset(attr); + """ + + result = acl_cmd("Triu", [input], + output_dtypes=[input.dtype], + output_shapes=[input.shape], + attr_code=attr_code)[0] + return result + + def grad(self, grad_output): + return grad_output + + class ConvACL(Function): + + def execute(self, + x, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1): + self.input = x + self.weight = weight + self.bias = bias + padding = _pair(padding) + stride = _pair(stride) + dilation = _pair(dilation) + out_channels = weight.shape[0] + if groups <= 0: + raise ValueError("groups must be a positive integer") + self.padding = padding + self.stride = stride + self.dilation = dilation + self.groups = groups + attr_code = f""" + op.jt_name = "conv2d"; + ConvAttr *attr = new ConvAttr(); + attr->convStrides = {{ {stride[0]}, {stride[1]} }}; + attr->convPads = {{ {padding[0]}, {padding[1]} }}; + attr->convDilations = {{ {dilation[0]}, {dilation[1]} }}; + attr->group = {groups}; + attr->convOutPads = {{ 1,1}}; + op.op_attr.reset(attr); + """ + input_height, input_width = x.shape[-2:] + kernel_height, kernel_width = weight.shape[-2:] + + output_height = (input_height + 2 * padding[0] - dilation[0] * + (kernel_height - 1) - 1) // stride[0] + 1 + output_width = (input_width + 2 * padding[1] - dilation[1] * + (kernel_width - 1) - 1) // stride[1] + 1 + + output_shape = (x.shape[0], out_channels, output_height, + output_width) + + inputs = [x, weight] + if bias is not None: + inputs.append(bias) + result = acl_cmd( + "Conv2d", + inputs, + output_dtypes=[x.dtype], + output_shapes=[output_shape], + attr_code=attr_code, + )[0] + return result + + def grad(self, grad_output): + x = self.input + weight = self.weight + bias = self.bias + inputs = [grad_output, x, weight] + + if bias is not None: + inputs.append(bias) + output_shapes = [x.shape, weight.shape] + output_dtypes = [x.dtype, weight.dtype] + if bias is not None: + output_shapes.append(bias.shape) + output_dtypes.append(bias.dtype) + padding = self.padding + stride = self.stride + dilation = self.dilation + groups = self.groups + attr_code = f""" + op.jt_name = "conv2dbackward"; + ConvAttr *attr = new ConvAttr(); + attr->convStrides = {{ {stride[0]}, {stride[1]} }}; + attr->convPads = {{ {padding[0]}, {padding[1]} }}; + attr->convDilations = {{ {dilation[0]}, {dilation[1]} }}; + attr->group = {groups}; + attr->convOutPads = {{ 1,1}}; + op.op_attr.reset(attr); + """ + results = acl_cmd("Conv2dBackward", + inputs, + output_dtypes=output_dtypes, + output_shapes=output_shapes, + attr_code=attr_code) + + return results + + class Conv2D(jt.nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True): + if in_channels <= 0: + raise ValueError( + f"in_channels must be greater than zero, got {in_channels}" + ) + if out_channels <= 0: + raise ValueError( + f"out_channels must be greater than zero, got {out_channels}" + ) + if groups <= 0: + raise ValueError( + f"groups must must be greater than zero, got {groups}") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + if isinstance(kernel_size, tuple): + for size in kernel_size: + if size <= 0: + raise ValueError( + f"kernel_size must be greater than zero, got {kernel_size}" + ) + else: + if kernel_size <= 0: + raise ValueError( + f"kernel_size must be greater than zero, got {kernel_size}" + ) + if isinstance(stride, tuple): + for size in stride: + if size <= 0: + raise ValueError( + f"stride must be greater than zero, got {stride}") + else: + if stride <= 0: + raise ValueError( + f"stride must be greater than zero, got {stride}") + if isinstance(padding, tuple): + for size in padding: + if size < 0: + raise ValueError( + f"padding must be nonnegative, got {padding}") + else: + if padding < 0: + raise ValueError( + f"padding must be nonnegative, got {padding}") + if isinstance(dilation, tuple): + for size in dilation: + if size <= 0: + raise ValueError( + f"dilation must be greater than zero, got {dilation}" + ) + else: + if dilation <= 0: + raise ValueError( + f"dilation must be greater than zero, got {dilation}") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size if isinstance( + kernel_size, tuple) else (kernel_size, kernel_size) + self.stride = stride if isinstance(stride, tuple) else (stride, + stride) + self.padding = padding if isinstance(padding, tuple) else (padding, + padding) + self.dilation = dilation if isinstance( + dilation, tuple) else (dilation, dilation) + self.groups = groups + self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels + if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda: + self.depthwise_conv = jt.nn.DepthwiseConv( + stride, padding, dilation) + Kh, Kw = self.kernel_size + + # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") + self.weight = jt.init.invariant_uniform( + [out_channels, in_channels // groups, Kh, Kw], dtype="float") + if bias: + fan = 1 + for i in self.weight.shape[1:]: + fan *= i + bound = 1 / math.sqrt(fan) + self.bias = jt.init.uniform([out_channels], + dtype="float", + low=-bound, + high=bound) + else: + self.bias = None + + def execute(self, x): + ret = jt.nn.conv2d(x, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + return ret + + def warp(origin_func, new_func): + + def warpper(*args, **kwargs): + if jt.flags.use_acl: + return new_func(*args, **kwargs) + return origin_func(*args, **kwargs) + + return warpper + + jt.triu = warp(jt.triu, TriuACL()) + jt.triu_ = warp(jt.triu, TriuACL()) + jt.Var.triu = lambda x: warp(jt.Var.triu, TriuACL())(x) + jt.Var.triu_ = lambda x: warp(jt.Var.triu_, TriuACL())(x) + jt.nn.conv2d = warp(jt.nn.conv2d, ConvACL()) + jt.nn.Conv2d = warp(jt.nn.Conv2d, Conv2D) + jt.nn.Conv = warp(jt.nn.Conv, Conv2D) diff --git a/python/jittor/extern/acl/acl_error_code.cc b/python/jittor/extern/acl/acl_error_code.cc new file mode 100644 index 00000000..5fd45dbf --- /dev/null +++ b/python/jittor/extern/acl/acl_error_code.cc @@ -0,0 +1,232 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "common.h" +using std::string; +using std::unordered_map; + +typedef int aclError; + +static inline unordered_map gen_map(string s) +{ + unordered_map smap; + for (int i = 0; i < s.size(); i++) + { + if (s[i] == ';') + { + int j = s.rfind(" ", i); + int code = std::stoi(s.substr(j + 1, i - j - 1)); + int k = s.rfind(" ", j - 1); + int l = s.rfind(" ACL_", k - 1); + smap[code] = s.substr(l + 1, k - l - 1); + } + } + return smap; +} + +string acl_error_to_string(aclError error) +{ + + static unordered_map acl_error_map = gen_map(R"( +// from acl_base.h +static const int ACL_ERROR_INVALID_PARAM = 100000; +static const int ACL_ERROR_UNINITIALIZE = 100001; +static const int ACL_ERROR_REPEAT_INITIALIZE = 100002; +static const int ACL_ERROR_INVALID_FILE = 100003; +static const int ACL_ERROR_WRITE_FILE = 100004; +static const int ACL_ERROR_INVALID_FILE_SIZE = 100005; +static const int ACL_ERROR_PARSE_FILE = 100006; +static const int ACL_ERROR_FILE_MISSING_ATTR = 100007; +static const int ACL_ERROR_FILE_ATTR_INVALID = 100008; +static const int ACL_ERROR_INVALID_DUMP_CONFIG = 100009; +static const int ACL_ERROR_INVALID_PROFILING_CONFIG = 100010; +static const int ACL_ERROR_INVALID_MODEL_ID = 100011; +static const int ACL_ERROR_DESERIALIZE_MODEL = 100012; +static const int ACL_ERROR_PARSE_MODEL = 100013; +static const int ACL_ERROR_READ_MODEL_FAILURE = 100014; +static const int ACL_ERROR_MODEL_SIZE_INVALID = 100015; +static const int ACL_ERROR_MODEL_MISSING_ATTR = 100016; +static const int ACL_ERROR_MODEL_INPUT_NOT_MATCH = 100017; +static const int ACL_ERROR_MODEL_OUTPUT_NOT_MATCH = 100018; +static const int ACL_ERROR_MODEL_NOT_DYNAMIC = 100019; +static const int ACL_ERROR_OP_TYPE_NOT_MATCH = 100020; +static const int ACL_ERROR_OP_INPUT_NOT_MATCH = 100021; +static const int ACL_ERROR_OP_OUTPUT_NOT_MATCH = 100022; +static const int ACL_ERROR_OP_ATTR_NOT_MATCH = 100023; +static const int ACL_ERROR_OP_NOT_FOUND = 100024; +static const int ACL_ERROR_OP_LOAD_FAILED = 100025; +static const int ACL_ERROR_UNSUPPORTED_DATA_TYPE = 100026; +static const int ACL_ERROR_FORMAT_NOT_MATCH = 100027; +static const int ACL_ERROR_BIN_SELECTOR_NOT_REGISTERED = 100028; +static const int ACL_ERROR_KERNEL_NOT_FOUND = 100029; +static const int ACL_ERROR_BIN_SELECTOR_ALREADY_REGISTERED = 100030; +static const int ACL_ERROR_KERNEL_ALREADY_REGISTERED = 100031; +static const int ACL_ERROR_INVALID_QUEUE_ID = 100032; +static const int ACL_ERROR_REPEAT_SUBSCRIBE = 100033; +static const int ACL_ERROR_STREAM_NOT_SUBSCRIBE = 100034; +static const int ACL_ERROR_THREAD_NOT_SUBSCRIBE = 100035; +static const int ACL_ERROR_WAIT_CALLBACK_TIMEOUT = 100036; +static const int ACL_ERROR_REPEAT_FINALIZE = 100037; +static const int ACL_ERROR_NOT_STATIC_AIPP = 100038; +static const int ACL_ERROR_COMPILING_STUB_MODE = 100039; +static const int ACL_ERROR_GROUP_NOT_SET = 100040; +static const int ACL_ERROR_GROUP_NOT_CREATE = 100041; +static const int ACL_ERROR_PROF_ALREADY_RUN = 100042; +static const int ACL_ERROR_PROF_NOT_RUN = 100043; +static const int ACL_ERROR_DUMP_ALREADY_RUN = 100044; +static const int ACL_ERROR_DUMP_NOT_RUN = 100045; +static const int ACL_ERROR_PROF_REPEAT_SUBSCRIBE = 148046; +static const int ACL_ERROR_PROF_API_CONFLICT = 148047; +static const int ACL_ERROR_INVALID_MAX_OPQUEUE_NUM_CONFIG = 148048; +static const int ACL_ERROR_INVALID_OPP_PATH = 148049; +static const int ACL_ERROR_OP_UNSUPPORTED_DYNAMIC = 148050; +static const int ACL_ERROR_RELATIVE_RESOURCE_NOT_CLEARED = 148051; + +static const int ACL_ERROR_BAD_ALLOC = 200000; +static const int ACL_ERROR_API_NOT_SUPPORT = 200001; +static const int ACL_ERROR_INVALID_DEVICE = 200002; +static const int ACL_ERROR_MEMORY_ADDRESS_UNALIGNED = 200003; +static const int ACL_ERROR_RESOURCE_NOT_MATCH = 200004; +static const int ACL_ERROR_INVALID_RESOURCE_HANDLE = 200005; +static const int ACL_ERROR_FEATURE_UNSUPPORTED = 200006; +static const int ACL_ERROR_PROF_MODULES_UNSUPPORTED = 200007; + +static const int ACL_ERROR_STORAGE_OVER_LIMIT = 300000; + +static const int ACL_ERROR_INTERNAL_ERROR = 500000; +static const int ACL_ERROR_FAILURE = 500001; +static const int ACL_ERROR_GE_FAILURE = 500002; +static const int ACL_ERROR_RT_FAILURE = 500003; +static const int ACL_ERROR_DRV_FAILURE = 500004; +static const int ACL_ERROR_PROFILING_FAILURE = 500005; + +// from ge_error_codes.h +static const uint32_t ACL_ERROR_GE_PARAM_INVALID = 145000U; +static const uint32_t ACL_ERROR_GE_EXEC_NOT_INIT = 145001U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID = 145002U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ID_INVALID = 145003U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID = 145006U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID = 145007U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID = 145008U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_REPEATED = 145009U; +static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID = 145011U; +static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID = 145012U; +static const uint32_t ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID = 145013U; +static const uint32_t ACL_ERROR_GE_AIPP_BATCH_EMPTY = 145014U; +static const uint32_t ACL_ERROR_GE_AIPP_NOT_EXIST = 145015U; +static const uint32_t ACL_ERROR_GE_AIPP_MODE_INVALID = 145016U; +static const uint32_t ACL_ERROR_GE_OP_TASK_TYPE_INVALID = 145017U; +static const uint32_t ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID = 145018U; +static const uint32_t ACL_ERROR_GE_PLGMGR_PATH_INVALID = 145019U; +static const uint32_t ACL_ERROR_GE_FORMAT_INVALID = 145020U; +static const uint32_t ACL_ERROR_GE_SHAPE_INVALID = 145021U; +static const uint32_t ACL_ERROR_GE_DATATYPE_INVALID = 145022U; +static const uint32_t ACL_ERROR_GE_MEMORY_ALLOCATION = 245000U; +static const uint32_t ACL_ERROR_GE_MEMORY_OPERATE_FAILED = 245001U; +static const uint32_t ACL_ERROR_GE_INTERNAL_ERROR = 545000U; +static const uint32_t ACL_ERROR_GE_LOAD_MODEL = 545001U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_PARTITION_FAILED = 545002U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED = 545003U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_TASK_PARTITION_FAILED = 545004U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_KERNEL_PARTITION_FAILED = 545005U; +static const uint32_t ACL_ERROR_GE_EXEC_RELEASE_MODEL_DATA = 545006U; +static const uint32_t ACL_ERROR_GE_COMMAND_HANDLE = 545007U; +static const uint32_t ACL_ERROR_GE_GET_TENSOR_INFO = 545008U; +static const uint32_t ACL_ERROR_GE_UNLOAD_MODEL = 545009U; + + +static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid +static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id +static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null +static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context +static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context +static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model +static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid +static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal +static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned +static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed +static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed +static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream +static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread +static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set +static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create +static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream +static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type +static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle +static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type +static const int32_t ACL_ERROR_RT_WAIT_TIMEOUT = 107019; // wait timeout + +static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support +static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error +static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error +static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow +static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device +static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail +static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission +static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource +static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource +static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource +static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource +static const int32_t ACL_ERROR_RT_NO_CDQ_RESOURCE = 207011; // no cdq resource +static const int32_t ACL_ERROR_RT_OVER_LIMIT = 207012; // over limit +static const int32_t ACL_ERROR_RT_QUEUE_EMPTY = 207013; // queue is empty +static const int32_t ACL_ERROR_RT_QUEUE_FULL = 207014; // queue is full +static const int32_t ACL_ERROR_RT_REPEATED_INIT = 207015; // repeated init +static const int32_t ACL_ERROR_RT_AIVEC_OVER_FLOW = 207016; // aivec over flow + +static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error +static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internel error +static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream +static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream +static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete +static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence +static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete +static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error +static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error +static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support +static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat +static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed +static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout +static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error +static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout +static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception +static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception +static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout +static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception +static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error +static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error +static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error +static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error +static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal +static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering +static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init +static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data +static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error +static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate +static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed +static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed +static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context +static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out +static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error +static const int32_t ACL_ERROR_RT_VECTOR_CORE_TIMEOUT = 507034; // vector core timeout +static const int32_t ACL_ERROR_RT_VECTOR_CORE_EXCEPTION = 507035; // vector core exception +static const int32_t ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION = 507036; // vector core trap exception +static const int32_t ACL_ERROR_RT_CDQ_BATCH_ABNORMAL = 507037; // cdq alloc batch abnormal +static const int32_t ACL_ERROR_RT_DIE_MODE_CHANGE_ERROR = 507038; // can not change die mode +static const int32_t ACL_ERROR_RT_DIE_SET_ERROR = 507039; // single die mode can not set die +static const int32_t ACL_ERROR_RT_INVALID_DIEID = 507040; // invalid die id +static const int32_t ACL_ERROR_RT_DIE_MODE_NOT_SET = 507041; // die mode not set + +static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error +static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error +static const int32_t ACL_ERROR_RT_SOCKET_CLOSE = 507901; // hdc disconnect + +)"); + if (acl_error_map.count(error)) + return acl_error_map[error]; + return "unknown " + std::to_string((int)error); +} \ No newline at end of file diff --git a/python/jittor/extern/acl/acl_jittor.cc b/python/jittor/extern/acl/acl_jittor.cc new file mode 100644 index 00000000..f14b0d53 --- /dev/null +++ b/python/jittor/extern/acl/acl_jittor.cc @@ -0,0 +1,302 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "utils/str_utils.h" +#include +#include + +namespace jittor +{ + + uint64_t acl_jittor_tid; + int acl_jittor_thread_running = 0; + aclrtStream aclstream; + +#define CHECK_ACL(x) ASSERTop(x, ==, 0) + + static void *acl_jittor_process_callback(void *) + { + acl_jittor_thread_running = 1; + int deviceId = 0; + + while (acl_jittor_thread_running) + { + // LOGir << "acl_jittor_process_callback"; + auto ret = aclrtProcessReport(1000); + if (ret) + { + if (acl_jittor_thread_running && ret != ACL_ERROR_RT_REPORT_TIMEOUT && ret != ACL_ERROR_RT_THREAD_SUBSCRIBE) + LOGir << "aclrtProcessReport:" << ret << acl_error_to_string(ret); + break; + } + } + acl_jittor_thread_running = 0; + return (void *)0; + } + + struct acl_jittor_initer + { + int32_t deviceId; + acl_jittor_initer() + { + CHECK_ACL(aclInit(nullptr)); + uint device_count = 0; + deviceId = 0; + // 获取可用的Device数量 + CHECK_ACL(aclrtGetDeviceCount(&device_count)); + LOGi << "Found ACL device number:" << device_count; + CHECK_ACL(aclrtSetDevice(deviceId)); + CHECK_ACL(aclrtCreateStream(&aclstream)); + // pthread_create(&acl_jittor_tid, nullptr, acl_jittor_process_callback, 0); + } + + ~acl_jittor_initer() + { + acl_jittor_thread_running = 0; + // CHECK_ACL(aclrtUnSubscribeReport(acl_jittor_tid, 0)); + aclrtDestroyStream(aclstream); + aclrtResetDevice(deviceId); + CHECK_ACL(aclFinalize()); + } + + } _acl_jittor_initer; + + string process_acl(const string &src, const string &name, const map &kargs) + { + if (endswith(name, "_jittor.cc")) + return src; + // static vector dont_compile = {"fp16_emu.cc"}; + // for (auto& s : dont_compile) + // if (endswith(name, s)) + // return " "; + static unordered_set cuda_headers = { + "cuda_runtime", "cudnn", "driver_types", + "cuda_fp16", "cuda_runtime_api", "fp16_emu", + "cudnn_rnn_descriptor", "cublas_v2", "cublas_wrapper", + "curand", "curand_wrapper", "cufft", "cufftXt", + "CudaUtils", "cutt", "cudnn_wrapper", "cuda_bf16"}; + static unordered_set fake_class = { + "cudnnHandle_t", "cudnnConvolutionBwdFilterAlgo_t", + "cudnnConvolutionBwdDataAlgo_t", "cudnnConvolutionFwdAlgo_t", + "cufftHandle"}; + try + { + auto tokens = token_split(src); + int edit = 0; + for (int i = 0; i < tokens.size(); i++) + { + auto &token = tokens[i]; + if (cuda_headers.count(token)) + token = "acl_jittor", edit++; + else if (fake_class.count(token)) + token = "int", edit++; + else if (token == "CUDA") + token = "ACL", edit++; + else if (startswith(token, "cuda")) + { + if (token.size() >= 5 && token[4] >= 'A' && token[4] <= 'Z') + { + if (token == "cudaGetDeviceCount") + { + token_replace(tokens, i, "($1);", "((uint*)$1);"); + } + else if (token == "cudaLaunchHostFunc") + { + // ACL_CALLBACK_BLOCK for 310 + token_replace(tokens, i, "LaunchHostFunc($1,$2,$3)", + "LaunchCallback($2,$3,ACL_CALLBACK_NO_BLOCK,$1)"); + } + else if (token == "cudaMemcpy") + token_replace(tokens, i, "cudaMemcpy($1,$2,$3,", + "aclrtMemcpy($1,$3,$2,$3,"); + else if (token == "cudaMemcpyAsync") + token_replace(tokens, i, "cudaMemcpyAsync($1,$2,$3,", + "aclrtMemcpyAsync($1,$3,$2,$3,"); + else if (token == "cudaMemcpyDeviceToHost") + token = "ACL_MEMCPY_DEVICE_TO_HOST"; + else if (token == "cudaMemcpyDefault") + token = "ACL_MEMCPY_HOST_TO_DEVICE"; + else if (token == "cudaMemcpyHostToDevice") + token = "ACL_MEMCPY_HOST_TO_DEVICE"; + else if (token == "cudaMemcpyDeviceToDevice") + token = "ACL_MEMCPY_DEVICE_TO_DEVICE"; + else if (token == "cudaMallocManaged" || token == "cudaMalloc") + { + // unified address not supported + token = "aclrtMalloc"; + token_replace(tokens, i, "($1,$2)", + "($1,$2,ACL_MEM_MALLOC_HUGE_FIRST)"); + } + else if (token == "cudaMemGetInfo") + token_replace(tokens, i, "cudaMemGetInfo($1,$2)", + "aclrtGetMemInfo(ACL_DDR_MEM,$1,$2)"); + else if (token == "cudaGetLastError") + token_replace(tokens, i, "cudaGetLastError()", "0"); + else if (token == "cudaStreamCreateWithFlags") + token_replace(tokens, i - 1, + "(cudaStreamCreateWithFlags($1,$2));", + "(aclrtCreateStream($1)); checkAclErrors(aclrtSubscribeReport(acl_jittor_tid,*$1));"); + else if (token == "cudaEventCreate") + token_replace(tokens, i, + "cudaEventCreate($1,$2)", + "aclrtCreateEvent($1)"); + else if (token == "cudaDeviceSynchronize") + token = "aclrtSynchronizeDevice"; + else if (token == "cudaStreamDestroy") + token_replace(tokens, i, "cudaStreamDestroy($1)", + "(aclrtUnSubscribeReport(acl_jittor_tid,$1), aclrtDestroyStream($1))"); + else if (token == "cudaEventDestroy") + token = "aclrtDestroyEvent"; + else if (token == "cudaEventRecord") + token = "aclrtRecordEvent"; + else if (token == "cudaStreamWaitEvent") + token_replace(tokens, i, + "cudaStreamWaitEvent($1,$2,$3)", + "aclrtStreamWaitEvent($1,$2)"); + + if (token.size() && token[0] == 'c') + token = "aclrt" + token.substr(4); + if (endswith(token, "_t")) + token = token.substr(0, token.size() - 2); + edit++; + } + } + else if (token == "_cudaGetErrorEnum") + { + token_replace(tokens, i, "_cudaGetErrorEnum($1)", "(acl_error_to_string($1))"); + edit++; + } + else if (token == "checkCudaErrors") + token = "checkAclErrors"; + else if (token == "JPU") + { + edit++; + string new_code; + if (tokens[i + 2] == "op_compiler") + token_replace(tokens, i, + "JPU(op_compiler($1,$2,$3))", + "acl_jittor_op_compiler($1,$2,$3)"); + else if (tokens[i + 2] == "header") + new_code = "#include \"acl_jittor.h\""; + if (new_code.size()) + token_replace(tokens, i, "JPU($1)", new_code); + } + else if (token == "use_cuda_managed_allocator" && tokens[i + 1][0] == ',') + { + tokens[i + 2] = "0"; // disable unified address + } + } + if (!edit) + return src; + string new_src = join(tokens, ""); + // if (name == "executor.cc") { + // new_src = string("#include \n#include \n#include \n")+ + // "namespace jittor { void acl_op_exec(Op*); }\n" + + // replace(new_src, "op->do_run_after_prepare(jkl);", + // R"({ + // acl_op_exec(op); + // })"); + // } + if (name == "profiler.cc") + { + new_src = token_replace_all(new_src, ".cc", ".tikcc"); + } + // LOGir << name << (name == "pass_manager.cc"); + if (name == "pass_manager.cc") + { + LOGir << "replace" << name; + new_src = token_replace_all(new_src, "run_pass();", "WTF"); + } + // ???????? + return new_src; + } + catch (const std::exception &e) + { + LOGe << "process acl error:" << e.what(); + LOGe << "name:" << name; + throw; + } + } + + void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags) + { + if (!is_acl) + return; + string new_src = process_acl(src, "", {}); + new_src = replace(new_src, R"(#include "misc/cuda_atomic.h")", ""); + new_src = replace(new_src, R"(#include "misc/cuda_limits.h")", ""); + new_src = replace(new_src, "__global__", "__ai_device_entry__"); + new_src = token_replace_all(new_src, "__launch_bounds__($1)", ""); + new_src = token_replace_all(new_src, "int thread_num = $1;", "int thread_num = 1;"); + new_src = token_replace_all(new_src, "tn0=std::max(tn0, $1);", ""); + new_src = token_replace_all(new_src, "<<<$1>>>", "<<<1,0>>>"); + new_src = token_replace_all(new_src, "int thread_id = $1;", "int thread_id = 1;"); + // for inc error + new_src = token_replace_all(new_src, "for ($1+=$2)", "for ($1++)"); + // bit op error + new_src = token_replace_all(new_src, "int tnum$1;", ""); + new_src = token_replace_all(new_src, "int p1$1;", ""); + new_src = token_replace_all(new_src, "int p2$1;", ""); + new_src = token_replace_all(new_src, "int tn$1=$2;", "int tn$1=0;"); + new_src = token_replace_all(new_src, "int tid$1=$2;", "int tid$1=0;"); + src = new_src; + + new_src = token_replace_all(new_src, "atomicAdd(&$1,$2);", "$1=$1+$2;"); + // new_src = token_replace_all(new_src, "bool", "int8"); + new_src = token_replace_all(new_src, "::numeric_min()", "-1e30"); + new_src = token_replace_all(new_src, "::numeric_max()", "1e30"); + // TODO: support max + unordered_map opmap = { + // {"::max","tikcc::scalar_max"}, + {"::sqrtf", "tikcc::scalar_sqrt"}}; + auto ss = split(new_src, ";"); + for (auto &s : ss) + { + if (s.find("?") != string::npos) + { + s = token_replace_all(s + ";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); + } + if (s.find("::max") != string::npos) + { + if (s.find("auto") == string::npos) + { + s = token_replace_all(s + ";", " $1=$4::max($2,$3);", " $1=$2;if ($2 < $3) $1=$3;"); + } + else + { + s = token_replace_all(s + ";", "auto $1=$4::max($2,$3);", "auto $1=$2;if ($2 < $3) $1=$3;"); + } + } + for (auto &kv : opmap) + { + if (s.find(kv.first) != string::npos) + { + if (s.find("auto") == string::npos) + { + // $1 = op($2) --> op($1, $2) + s = token_replace_all(s + ";", " $1= " + kv.first + "($2);", kv.second + "($1, $2);"); + } + else + { + // auto $1 = op($2) --> float32 $1; op($1, $2); + s = token_replace_all(s + ";", "auto $1= " + kv.first + "($2);", "float32 $1; " + kv.second + "($1, $2);"); + } + } + } + // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); + // s = token_replace_all(s+";", "auto $1=$2?$3:$4;", "auto $1=$3;if (!($2)) $1=$4;"); + // if (s.find("::max") != string::npos) { + // s = token_replace_all(s+";", " $1= ::max($2);", "tikcc::scalar_max($1, $2);"); + // } + } + new_src = join(ss, ";"); + src = new_src; + } + +} diff --git a/python/jittor/extern/acl/acl_jittor.h b/python/jittor/extern/acl/acl_jittor.h new file mode 100644 index 00000000..18978f4e --- /dev/null +++ b/python/jittor/extern/acl/acl_jittor.h @@ -0,0 +1,241 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include + +std::string acl_error_to_string(aclError error); + +namespace jittor +{ + + EXTERN_LIB uint64_t acl_jittor_tid; + EXTERN_LIB aclrtStream aclstream; + + void acl_jittor_op_compiler(string &filename, string &src, bool is_acl, string &extra_flags); + + struct AclOpFunctions + { + // for Unary + std::function getWorkspaceSizeFuncUnary; + // for Cast + std::function getWorkspaceSizeFuncCast; + // for Bianry + std::function getWorkspaceSizeFuncBinary; + // for Add and Sub + std::function getWorkspaceSizeFuncAdd; + // for Expand and permute + std::function getWorkspaceSizeFuncExpand; + // for bmm and matmul + std::function getWorkspaceSizeFuncMatmul; + // for conv + std::function getWorkspaceSizeFuncConv; + // for reducesum, mean + std::function getWorkspaceSizeFuncReduceSum; + // for amax and amin + std::function getWorkspaceSizeFuncAmax; + // for conv backward + std::function getWorkspaceSizeFuncConvBackward; + // for proddim + std::function getWorkspaceSizeFuncProdDim; + // for select + std::function getWorkspaceSizeFuncSelect; + // for random_uniform and random_normal + std::function getWorkspaceSizeFuncRandom; + + std::function executeFunc; + + // 添加一个默认构造函数 + AclOpFunctions() = default; + + // for Unary + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncUnary(gwsf), executeFunc(execf) {} + + // for Cast + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncCast(gwsf), executeFunc(execf) {} + + // for Binary + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncBinary(gwsf), executeFunc(execf) {} + // for Add and Sub + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncAdd(gwsf), executeFunc(execf) {} + + // for Expand + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncExpand(gwsf), executeFunc(execf) {} + + // for Matmul + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncMatmul(gwsf), executeFunc(execf) {} + + // for conv + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncConv(gwsf), executeFunc(execf) {} + + // for reducesum, mean + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncReduceSum(gwsf), executeFunc(execf) {} + + // for amax amin + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncAmax(gwsf), executeFunc(execf) {} + + // for conv backward + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncConvBackward(gwsf), executeFunc(execf) {} + + // for proddim + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncProdDim(gwsf), executeFunc(execf) {} + + // for select + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncSelect(gwsf), executeFunc(execf) {} + + // for random_normal + AclOpFunctions(std::function gwsf, + std::function execf) + : getWorkspaceSizeFuncRandom(gwsf), executeFunc(execf) {} + }; + + static std::unordered_map aclOpFuncMap = { + {"Abs", AclOpFunctions(aclnnAbsGetWorkspaceSize, aclnnAbs)}, + {"Exp", AclOpFunctions(aclnnExpGetWorkspaceSize, aclnnExp)}, + {"Log", AclOpFunctions(aclnnLogGetWorkspaceSize, aclnnLog)}, + {"Sqrt", AclOpFunctions(aclnnSqrtGetWorkspaceSize, aclnnSqrt)}, + {"Ceil", AclOpFunctions(aclnnCeilGetWorkspaceSize, aclnnCeil)}, + {"Floor", AclOpFunctions(aclnnFloorGetWorkspaceSize, aclnnFloor)}, + {"Round", AclOpFunctions(aclnnRoundGetWorkspaceSize, aclnnRound)}, + {"Sin", AclOpFunctions(aclnnSinGetWorkspaceSize, aclnnSin)}, + {"Cos", AclOpFunctions(aclnnCosGetWorkspaceSize, aclnnCos)}, + {"Tan", AclOpFunctions(aclnnTanGetWorkspaceSize, aclnnTan)}, + {"Asin", AclOpFunctions(aclnnAsinGetWorkspaceSize, aclnnAsin)}, + {"Acos", AclOpFunctions(aclnnAcosGetWorkspaceSize, aclnnAcos)}, + {"Atan", AclOpFunctions(aclnnAtanGetWorkspaceSize, aclnnAtan)}, + {"Sinh", AclOpFunctions(aclnnSinhGetWorkspaceSize, aclnnSinh)}, + {"Cosh", AclOpFunctions(aclnnCoshGetWorkspaceSize, aclnnCosh)}, + {"Tanh", AclOpFunctions(aclnnTanhGetWorkspaceSize, aclnnTanh)}, + {"Asinh", AclOpFunctions(aclnnAsinhGetWorkspaceSize, aclnnAsinh)}, + {"Acosh", AclOpFunctions(aclnnAcoshGetWorkspaceSize, aclnnAcosh)}, + {"Atanh", AclOpFunctions(aclnnAtanhGetWorkspaceSize, aclnnAtanh)}, + {"Sigmoid", AclOpFunctions(aclnnSigmoidGetWorkspaceSize, aclnnSigmoid)}, + {"Erf", AclOpFunctions(aclnnErfGetWorkspaceSize, aclnnErf)}, + {"Erfinv", AclOpFunctions(aclnnErfinvGetWorkspaceSize, aclnnErfinv)}, + {"LogicalNot", AclOpFunctions(aclnnLogicalNotGetWorkspaceSize, aclnnLogicalNot)}, + {"BitwiseNot", AclOpFunctions(aclnnBitwiseNotGetWorkspaceSize, aclnnBitwiseNot)}, + {"Neg", AclOpFunctions(aclnnNegGetWorkspaceSize, aclnnNeg)}, + {"Cast", AclOpFunctions(aclnnCastGetWorkspaceSize, aclnnCast)}, + {"Maximum", AclOpFunctions(aclnnMaximumGetWorkspaceSize, aclnnMaximum)}, + {"Minimum", AclOpFunctions(aclnnMinimumGetWorkspaceSize, aclnnMinimum)}, + {"Add", AclOpFunctions(aclnnAddGetWorkspaceSize, aclnnAdd)}, + {"Sub", AclOpFunctions(aclnnSubGetWorkspaceSize, aclnnSub)}, + {"Mul", AclOpFunctions(aclnnMulGetWorkspaceSize, aclnnMul)}, + {"RealDiv", AclOpFunctions(aclnnDivGetWorkspaceSize, aclnnDiv)}, + {"FloorDiv", AclOpFunctions(aclnnFloorDivideGetWorkspaceSize, aclnnFloorDivide)}, + {"LessEqual", AclOpFunctions(aclnnLeTensorGetWorkspaceSize, aclnnLeTensor)}, + {"Less", AclOpFunctions(aclnnLtTensorGetWorkspaceSize, aclnnLtTensor)}, + {"GreaterEqual", AclOpFunctions(aclnnGeTensorGetWorkspaceSize, aclnnGeTensor)}, + {"Greater", AclOpFunctions(aclnnGtTensorGetWorkspaceSize, aclnnGtTensor)}, + {"Equal", AclOpFunctions(aclnnEqTensorGetWorkspaceSize, aclnnEqTensor)}, + {"NotEqual", AclOpFunctions(aclnnNeTensorGetWorkspaceSize, aclnnNeTensor)}, + {"LogicalAnd", AclOpFunctions(aclnnLogicalAndGetWorkspaceSize, aclnnLogicalAnd)}, + {"LogicalOr", AclOpFunctions(aclnnLogicalOrGetWorkspaceSize, aclnnLogicalOr)}, + {"LogicalXor", AclOpFunctions(aclnnLogicalXorGetWorkspaceSize, aclnnLogicalXor)}, + {"BitwiseAnd", AclOpFunctions(aclnnBitwiseAndTensorGetWorkspaceSize, aclnnBitwiseAndTensor)}, + {"BitwiseOr", AclOpFunctions(aclnnBitwiseOrTensorGetWorkspaceSize, aclnnBitwiseOrTensor)}, + {"BitwiseXor", AclOpFunctions(aclnnBitwiseXorTensorGetWorkspaceSize, aclnnBitwiseXorTensor)}, + {"Pow", AclOpFunctions(aclnnPowTensorTensorGetWorkspaceSize, aclnnPowTensorTensor)}, + {"Expand", AclOpFunctions(aclnnExpandGetWorkspaceSize, aclnnExpand)}, + {"MatMul", AclOpFunctions(aclnnMatmulGetWorkspaceSize, aclnnMatmul)}, + {"BatchMatMul", AclOpFunctions(aclnnBatchMatMulGetWorkspaceSize, aclnnBatchMatMul)}, + {"Conv2D", AclOpFunctions(aclnnConvolutionGetWorkspaceSize, aclnnConvolution)}, + {"ReduceMax", AclOpFunctions(aclnnAmaxGetWorkspaceSize, aclnnAmax)}, + {"ReduceMin", AclOpFunctions(aclnnAminGetWorkspaceSize, aclnnAmin)}, + {"ReduceSum", AclOpFunctions(aclnnReduceSumGetWorkspaceSize, aclnnReduceSum)}, + {"Triu", AclOpFunctions(aclnnTriuGetWorkspaceSize, aclnnTriu)}, + {"Conv2d", AclOpFunctions(aclnnConvolutionGetWorkspaceSize, aclnnConvolution)}, + {"Conv2dBackward", AclOpFunctions(aclnnConvolutionBackwardGetWorkspaceSize, aclnnConvolutionBackward)}, + {"ReduceMean", AclOpFunctions(aclnnMeanGetWorkspaceSize, aclnnMean)}, + // {"ReduceProd", AclOpFunctions(aclnnProdDimGetWorkspaceSize, aclnnProdDim)}, + {"Select", AclOpFunctions(aclnnSWhereGetWorkspaceSize, aclnnSWhere)}, + {"RandomUniform", AclOpFunctions(aclnnInplaceRandomGetWorkspaceSize, aclnnInplaceRandom)}, + {"RandomNormal", AclOpFunctions(aclnnInplaceNormalGetWorkspaceSize, aclnnInplaceNormal)}, + {"Transpose", AclOpFunctions(aclnnPermuteGetWorkspaceSize, aclnnPermute)}, + }; + + struct AclOpAttr + { + virtual ~AclOpAttr() {} + }; + + struct ConvAttr : AclOpAttr + { + vector convStrides; + vector convPads; + vector convOutPads; + vector convDilations; + bool convWithBias; + bool is_transposed; + int64_t group; + + // 析构函数 + ~ConvAttr() + { + convStrides.clear(); + convPads.clear(); + convOutPads.clear(); + convDilations.clear(); + } + }; + + struct ReduceAttr : AclOpAttr + { + vector axes; + // for proddim + int64_t prod_dim; + bool keepdims; + + ~ReduceAttr() + { + axes.clear(); + } + }; + + struct RandomAttr : AclOpAttr + { + int64_t seed, offset; + + ~RandomAttr() + { + } + }; + + struct TriuAttr : AclOpAttr + { + int64_t diagonal; + + ~TriuAttr() + { + } + }; + +} \ No newline at end of file diff --git a/python/jittor/extern/acl/acl_op_exec.cc b/python/jittor/extern/acl/acl_op_exec.cc new file mode 100644 index 00000000..712ab7d3 --- /dev/null +++ b/python/jittor/extern/acl/acl_op_exec.cc @@ -0,0 +1,862 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include +#include +#include "common.h" +#include "op.h" +#include "acl_jittor.h" +#include "ops/random_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/transpose_op.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "fused_op.h" +#include "ops/unary_op.h" +#include "ops/ternary_op.h" +#include "executor.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "op_compiler.h" +#include "ops/op_register.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "aclnn/aclnn.h" + +namespace jittor +{ + int CreateAclTensor(const std::vector &shape, void *deviceAddr, int64_t size, + aclDataType dataType, aclTensor **tensor, bool use_nchw = false) + { + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) + { + strides[i] = shape[i + 1] * strides[i + 1]; + } + if (shape.size() == 0) + strides = {}; + // 调用aclCreateTensor接口创建aclTensor + if (use_nchw) + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_NCHW, + shape.data(), shape.size(), deviceAddr); + else + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), deviceAddr); + return 0; + } + + struct AclOpRunner + { + string name; + string jt_name; + vector in_; + vector out_; + std::unique_ptr op_attr; + + AclOpRunner(const string &name) : name(name) + { + } + + ~AclOpRunner() + { + } + + aclDataType get_dtype(NanoString s) + { + if (s == ns_float32) + return ACL_FLOAT; + if (s == ns_float16) + return ACL_FLOAT16; + if (s == ns_int64) + return ACL_INT64; + if (s == ns_int32) + return ACL_INT32; + if (s == ns_int8) + return ACL_INT8; + if (s == ns_int16) + return ACL_INT16; + if (s == ns_uint8) + return ACL_UINT8; + if (s == ns_uint16) + return ACL_UINT16; + if (s == ns_uint32) + return ACL_UINT32; + if (s == ns_bool) + return ACL_BOOL; + LOGf << "Not supported dtype: " << s; + return ACL_FLOAT; + } + + void add(Var *v, bool is_input) + { + + if (is_input) + { + in_.push_back(v); + } + else + { + out_.push_back(v); + } + return; + } + + template + std::vector createVector(int64_t size) + { + return std::vector(size, 0); + } + + void run() + { + // LOGir << name << " " << jt_name; + auto it = aclOpFuncMap.find(name); + if (it == aclOpFuncMap.end()) + { + LOGir << "Not supported op: " << name; + throw std::runtime_error("Unsupported operation type."); + } + + // 0. 算子的输入、输出、需要的attr定义 + std::vector> inputShapes; + std::vector> outputShapes; + + // for reduce + std::vector axes; + aclIntArray *dim = nullptr; + bool keepdims; + + bool use_nchw = false; + + auto input_num = in_.size(); + auto output_num = out_.size(); + + for (int input_idx = 0; input_idx < input_num; input_idx++) + { + std::vector shape; + for (int j = 0; j < in_[input_idx]->shape.size(); j++) + { + shape.push_back(in_[input_idx]->shape[j]); + } + inputShapes.push_back(shape); + } + for (int output_idx = 0; output_idx < output_num; output_idx++) + { + std::vector shape; + for (int j = 0; j < out_[output_idx]->shape.size(); j++) + { + shape.push_back(out_[output_idx]->shape[j]); + } + outputShapes.push_back(shape); + } + + // 1. 创建aclTensor和aclScalar,不同算子可能不一样,需要根据具体API的接口定义修改 + std::vector inputTensors; + std::vector outputTensors; + + // for add and sub + aclScalar *alpha = nullptr; + + // for expand + aclIntArray *size = nullptr; + + // for add and sub + float alphaValue = 1.0f; + + // for conv + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *outPads = nullptr; + aclIntArray *dilations = nullptr; + int ret = -1; + + if (name == string("Add") || name == string("Sub")) + { + alpha = aclCreateScalar(&alphaValue, get_dtype(in_[0]->dtype())); + CHECK_RET(alpha != nullptr, return); + } + + if (jt_name == "conv" || jt_name == "conv2d" || jt_name == "conv2dbackward") + use_nchw = true; + + for (int idx = 0; idx < input_num; idx++) + { + inputTensors.push_back(nullptr); + auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + + if (jt_name == "reduce" || jt_name == "transpose") + { + auto attr = dynamic_cast(op_attr.get()); + dim = aclCreateIntArray(attr->axes.data(), attr->axes.size()); + keepdims = attr->keepdims; + if (name == string("ReduceMax") || name == string("ReduceMin") || name == string("ReduceMean") || name == string("ReduceProd")) + { + if (attr->axes.size() == in_[0]->shape.size()) + outputShapes[0] = {}; + } + } + for (int idx = 0; idx < output_num; idx++) + { + outputTensors.push_back(nullptr); + auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + + // 2. 调用CANN算子库aclnnxxxGetWorkspaceSize的接口,两段式接口的第一个 + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + if (name == string("Add") || name == string("Sub")) + ret = it->second.getWorkspaceSizeFuncAdd(inputTensors[0], inputTensors[1], alpha, outputTensors[0], &workspaceSize, &executor); + else if (name == string("Expand")) + { + size = aclCreateIntArray(&outputShapes[0][0], outputShapes[0].size()); + ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], size, outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("Cast")) + ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + else if (jt_name == "unary") + ret = it->second.getWorkspaceSizeFuncUnary(inputTensors[0], outputTensors[0], &workspaceSize, &executor); + else if (jt_name == "binary") + ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor); + else if (jt_name == "bmm" || jt_name == "matmul") + ret = it->second.getWorkspaceSizeFuncMatmul(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); + else if (name == string("ReduceSum") || name == string("ReduceMean")) + { + ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, keepdims, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("ReduceMax") || name == string("ReduceMin")) + { + ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, keepdims, outputTensors[0], &workspaceSize, &executor); + } + // else if (name == string("ReduceProd")) + // { + // ret = it->second.getWorkspaceSizeFuncReduceProd(inputTensors[0], dim, false, outputTensors[0], &workspaceSize, &executor); + // } + else if (name == string("RandomUniform") || name == string("RandomNormal")) + { + auto attr = dynamic_cast(op_attr.get()); + ret = it->second.getWorkspaceSizeFuncRandom(outputTensors[0], int64_t(0), int64_t(1), attr->seed, attr->offset, &workspaceSize, &executor); + } + else if (name == string("Select")) + { + ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("Triu")) + { + auto attr = dynamic_cast(op_attr.get()); + ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("Transpose")) + { + ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], dim, outputTensors[0], &workspaceSize, &executor); + } + // else if (name == string("Conv2d")) + // { + // auto attr = dynamic_cast(op_attr.get()); + // strides = aclCreateIntArray(attr->convStrides.data(), 2); + // pads = aclCreateIntArray(attr->convPads.data(), 2); + // outPads = aclCreateIntArray(attr->convOutPads.data(), 2); + // dilations = aclCreateIntArray(attr->convDilations.data(), 2); + + // ret = it->second.getWorkspaceSizeFuncConv(inputTensors[0], inputTensors[1], nullptr, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor); + // } + // else if (name == string("Conv2dBackward")) + // { + // auto attr = dynamic_cast(op_attr.get()); + // strides = aclCreateIntArray(attr->convStrides.data(), 2); + // pads = aclCreateIntArray(attr->convPads.data(), 2); + // outPads = aclCreateIntArray(attr->convOutPads.data(), 2); + // dilations = aclCreateIntArray(attr->convDilations.data(), 2); + // bool outputMask[3] = {true, true, false}; + // LOGir << attr->group; + // aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3); + // ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], nullptr, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], nullptr, &workspaceSize, &executor); + // } + else + LOGf << "not supported op " << jt_name; + + // for debug + if (ret != ACL_SUCCESS) + { + auto tmp_err_msg = aclGetRecentErrMsg(); + LOGir << name << ", " << tmp_err_msg; + } + + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); + + // 4. 根据第一段接口计算出的workspaceSize申请device内存 + void *workspaceAddr = nullptr; + if (workspaceSize > 0) + { + ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: allocate workspace failed. ERROR: %d\n", name.c_str(), ret); return); + } + + // 5. 调用aclnnxx第二段接口 + ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return); + + // 6. (固定写法)同步等待任务执行结束 + ret = aclrtSynchronizeStream(aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return); + + // 7. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改 + // destroy tensor + for (int idx = 0; idx < input_num; idx++) + { + aclDestroyTensor(inputTensors[idx]); + } + for (int idx = 0; idx < output_num; idx++) + { + aclDestroyTensor(outputTensors[idx]); + } + // destroy scalar + aclDestroyScalar(alpha); + + // destroy IntArray + aclDestroyIntArray(size); + aclDestroyIntArray(dim); + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(outPads); + aclDestroyIntArray(dilations); + + // 8. 释放device资源 + if (workspaceSize > 0) + { + aclrtFree(workspaceAddr); + } + return; + } + }; + + void free_var_mem(Var *v); + + unordered_map opname_map = { + // unary op + {ns_cast, "Cast"}, + {ns_negative, "Neg"}, + {ns_abs, "Abs"}, + {ns_exp, "Exp"}, + {ns_log, "Log"}, + {ns_sqrt, "Sqrt"}, + {ns_ceil, "Ceil"}, + {ns_floor, "Floor"}, + {ns_round, "Round"}, + // m(round_int) + // m(floor_int) + // m(ceil_int) + {ns_sin, "Sin"}, + {ns_cos, "Cos"}, + {ns_tan, "Tan"}, + {ns_asin, "Asin"}, + {ns_acos, "Acos"}, + {ns_atan, "Atan"}, + {ns_sinh, "Sinh"}, + {ns_cosh, "Cosh"}, + {ns_tanh, "Tanh"}, + {ns_asinh, "Asinh"}, + {ns_acosh, "Acosh"}, + {ns_atanh, "Atanh"}, + {ns_sigmoid, "Sigmoid"}, + {ns_erf, "Erf"}, + {ns_erfinv, "Erfinv"}, + {ns_logical_not, "LogicalNot"}, + {ns_bitwise_not, "BitwiseNot"}, + // binary op + {ns_pow, "Pow"}, + {ns_maximum, "Maximum"}, + {ns_minimum, "Minimum"}, + {ns_add, "Add"}, + {ns_subtract, "Sub"}, + {ns_multiply, "Mul"}, + {ns_divide, "RealDiv"}, + {ns_floor_divide, "FloorDiv"}, + {ns_mod, "Mod"}, + {ns_less, "Less"}, + {ns_less_equal, "LessEqual"}, + {ns_greater, "Greater"}, + {ns_greater_equal, "GreaterEqual"}, + {ns_equal, "Equal"}, + {ns_not_equal, "NotEqual"}, + {ns_left_shift, "LeftShift"}, + {ns_right_shift, "RightShift"}, + {ns_logical_and, "LogicalAnd"}, + {ns_logical_or, "LogicalOr"}, + {ns_logical_xor, "LogicalXor"}, + {ns_bitwise_and, "BitwiseAnd"}, + {ns_bitwise_or, "BitwiseOr"}, + {ns_bitwise_xor, "BitwiseXor"}, + + }; + + void fallback_cpu(Op *op) + { + LOGy << "!!! fallback_cpu " << op; + use_cuda = 0; + for (auto v : op->inputs()) + { + if (v->mem_ptr && v->allocator->is_cuda()) + { + migrate_to_cpu(v, exe.allocator); + } + } + for (auto v : op->outputs()) + { + if (v->mem_ptr && v->allocator->is_cuda()) + { + migrate_to_cpu(v, exe.allocator); + } + } + op->flags.set(NodeFlags::_cpu); + op->flags.set(NodeFlags::_cuda, 0); + if (op->name() == string("fused")) + { + auto fop = (FusedOp *)op; + for (auto op : fop->ops) + { + op->flags.set(NodeFlags::_cpu); + op->flags.set(NodeFlags::_cuda, 0); + } + } + op->do_run(); + use_cuda = 1; + } + + /* + check compile + if compiled: exec + else: compile + check is fused + check is relay + else + compile func = try exec + if failed: fallback_cpu + else + try compile + if failed: fallback_cpu + */ + + extern jit_op_entry_t (*do_compile_hook)(Op *); + jit_op_entry_t do_compile_inner(Op *op); + + void try_exec_and_fallback_cpu(Op *op) + { + LOGv << "try_exec_and_fallback_cpu " << op; + auto fop = (FusedOp *)op; + + vector new_alloced; + int fallback = 0; + try + { + for (Op *op : fop->ops) + { + for (auto out : op->outputs()) + { + if (out->mem_ptr) + continue; + out->alloc(exe.temp_allocator); + new_alloced.push_back(out); + } + if (op->name() == string("unary")) + { + auto uop = (UnaryOp *)op; + AclOpRunner op("..."); + op.add(uop->x, true); + op.add(uop->y, false); + auto iter = opname_map.find(uop->ns); + ASSERT(iter != opname_map.end()) << "op " << uop->ns << " not found"; + op.name = iter->second; + op.jt_name = uop->name(); + op.run(); + } + else if (op->name() == string("binary")) + { + auto bop = (BinaryOp *)op; + AclOpRunner op("..."); + op.add(bop->x, true); + op.add(bop->y, true); + op.add(bop->z, false); + auto iter = opname_map.find(bop->ns); + ASSERT(iter != opname_map.end()) << "op " << bop->ns << " not found"; + op.name = iter->second; + op.jt_name = bop->name(); + + if (bop->x->dtype() == ns_bool and bop->y->dtype() == ns_bool) + { + // BitwiseOr, BitwiseAnd, BitwiseXor -> LogicalOr, LogicalAnd, LogicalXor + if (bop->ns == ns_bitwise_or) + { + op.name = "LogicalOr"; + } + else if (bop->ns == ns_bitwise_and) + { + op.name = "LogicalAnd"; + } + else if (bop->ns == ns_bitwise_xor) + { + op.name = "LogicalXor"; + } + } + op.run(); + } + else if (op->name() == string("ternary")) + { + auto top = (TernaryOp *)op; + AclOpRunner op("Select"); + op.add(top->cond, true); + op.add(top->x, true); + op.add(top->y, true); + op.add(top->z, false); + op.run(); + } + else if (op->name() == string("array")) + { + auto aop = (ArrayOp *)op; + aclrtMemcpy(aop->output->mem_ptr, aop->output->size, aop->ptr(), aop->output->size, ACL_MEMCPY_HOST_TO_DEVICE); + } + else if (op->name() == string("reduce")) + { + auto rop = (ReduceOp *)op; + AclOpRunner op(""); + if (rop->ns == ns_add) + op.name = "ReduceSum"; + else if (rop->ns == ns_multiply) + // TODO unsupported the multi dim + op.name = "ReduceProd"; + else if (rop->ns == ns_maximum) + op.name = "ReduceMax"; + else if (rop->ns == ns_minimum) + op.name = "ReduceMin"; + else if (rop->ns == ns_mean) + op.name = "ReduceMean"; + else + LOGf << "op " << rop->ns << " not supported"; + op.jt_name = "reduce"; + op.add(rop->x, true); + + ReduceAttr *attr = new ReduceAttr(); + for (int i = 0; i < rop->x->shape.size(); i++) + if (rop->reduce_mask & (1 << i)) + attr->axes.push_back(i); + if (rop->x->shape.size() == rop->y->shape.size()) + attr->keepdims = true; + else + attr->keepdims = false; + + op.op_attr.reset(attr); + op.add(rop->y, false); + op.run(); + } + else if (op->name() == string("broadcast_to")) + { + auto bop = (BroadcastToOp *)op; + AclOpRunner op("Expand"); + op.jt_name = "expand"; + + NanoVector xshape, xshape_bk = bop->x->shape; + NanoVector zshape = bop->z->shape; + for (int i = 0; i < zshape.size(); i++) + { + if (bop->bcast_mask & (1 << i)) + { + xshape.push_back(1); + } + else + { + xshape.push_back(zshape[i]); + } + } + bop->x->shape = xshape; + op.add(bop->x, true); + // bop->x->shape = xshape_bk; + op.add(bop->z, false); + op.run(); + bop->x->shape = xshape_bk; + } + else if (op->name() == string("fuse_transpose")) + { + // replace fuse_transpose with transpose + auto top = (TransposeOp *)op; + AclOpRunner op("Transpose"); + op.add(top->x, true); + op.add(top->y, false); + op.jt_name = "transpose"; + + ReduceAttr *attr = new ReduceAttr(); + for (int i = 0; i < top->axes.size(); i++) + attr->axes.push_back(top->axes[i]); + op.op_attr.reset(attr); + + op.run(); + } + else + { + LOGf << "op " << op->name() << " not supported"; + } + } + } + catch (std::exception &e) + { + fallback = 1; + LOGir << "fallback cpu" << e.what(); + } + for (auto v : new_alloced) + { + free_var_mem(v); + } + if (fallback) + { + fallback_cpu(op); + } + } + + extern int current_seed; + extern int64 current_offset; + + static unordered_map> acl_ops = { + {"curand_random", [¤t_seed, ¤t_offset](Op *op) + { + auto _op = (RandomOp *)op; + AclOpRunner runner(_op->type == ns_uniform ? "RandomUniform" : "RandomNormal"); + auto out = op->output(0); + RandomAttr *attr = new RandomAttr(); + attr->seed = current_seed; + attr->offset = current_offset; + runner.jt_name = "random"; + runner.op_attr.reset(attr); + + runner.add(out, false); + runner.run(); + current_offset += out->numel(); + }}, + {"cublas_matmul", [&](Op *op) + { + struct MatmulOp : Op + { + Var *a, *b, *c; + bool trans_a, trans_b; + }; + auto _op = (MatmulOp *)op; + AclOpRunner runner("MatMul"); + runner.jt_name = "matmul"; + runner.add(_op->a, true); + runner.add(_op->b, true); + runner.add(_op->c, false); + runner.run(); + }}, + {"cublas_batched_matmul", [&](Op *op) + { + struct BatchedMatmulOp : Op + { + Var *a, *b, *c; + bool adj_x1, adj_x2; + }; + auto _op = (BatchedMatmulOp *)op; + AclOpRunner runner("BatchMatMul"); + runner.jt_name = "bmm"; + runner.add(_op->a, true); + runner.add(_op->b, true); + runner.add(_op->c, false); + runner.run(); + }}, + // {"cudnn_conv", [](Op *op) + // { + // struct ConvOp : Op + // { + // Var *x, *w, *y; + // int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + // string xformat, wformat, yformat; + // void run_acl() + // { + // AclOpRunner runner("Conv2D"); + // runner.jt_name = "conv"; + // runner.add(x, true); + // runner.add(w, true); + // runner.add(y, false); + // ConvAttr *attr = new ConvAttr(); + + // attr->convStrides = {strideh, stridew, 1, 1}; + // attr->convPads = {paddingh, paddingh, paddingw, paddingw}; + // attr->convOutPads = {1, 1, 1, 1}; + // attr->convDilations = {dilationh, dilationw, 1, 1}; + // attr->group = groups; + // runner.op_attr.reset(attr); + + // runner.run(); + // } + // }; + // auto _op = (ConvOp *)op; + // _op->run_acl(); + // }}, + // {"cudnn_conv_backward_x", [](Op *op) + // { + // struct ConvBackwardXOp : Op + // { + // Var *w, *dy, *dx; + // int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + // string xformat, wformat, yformat; + // void run_acl() + // { + // /* + // AclOpRunner runner("Conv2DBackpropInput"); + // runner.add_input_host_nv32(dx->shape); // 10,3,50,50 + // // runner.add_input_host_nv32(dy->shape); // 10,3,50,50 + // runner.add(w, true, ACL_FORMAT_NCHW); // 4,3,3,3 + // aclSetTensorDescName(runner.input_desc.back(), "filter"); + // runner.add(dy, true, ACL_FORMAT_NCHW); // 10,4,48,48 + // aclSetTensorDescName(runner.input_desc.back(), "out_backprop"); + // runner.add(dx, false, ACL_FORMAT_NCHW); // 10,3,50,50 + // aclSetTensorDescName(runner.input_desc.back(), "y"); + // runner.set_attr("strides", vector{1,1,strideh,stridew}); + // runner.set_attr("pads", vector{paddingh,paddingh,paddingw,paddingw}); + // runner.set_attr("dilations", vector{1,1,dilationh,dilationw}); + // runner.set_attr("groups", groups); + // runner.set_attr("data_format", "NCHW"); + // // runner.set_attr("dataFormat", "NCHW"); + // // runner.set_attr("data_format", "NCHW"); + // ASSERT(xformat=="abcd" && yformat=="abcd" && wformat=="oihw"); + // runner.run();*/ + // } + // }; + // auto _op = (ConvBackwardXOp *)op; + // _op->run_acl(); + // }}, + // {"cudnn_conv_backward_w", [](Op *op) + // { + // struct ConvBackwardWOp : Op + // { + // Var *x, *dy, *dw; + // int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + // string xformat, wformat, yformat; + // void run_acl() + // { + // /* + // AclOpRunner runner("Conv2DBackpropFilter"); + // runner.add(x, true, ACL_FORMAT_NCHW); + // runner.add_input_host_nv32(dw->shape); + // runner.add(dy, true, ACL_FORMAT_NCHW); + // runner.add(dw, false, ACL_FORMAT_NCHW); + // runner.set_attr("strides", vector{1, 1, strideh, stridew}); + // runner.set_attr("pads", vector{paddingh, paddingh, paddingw, paddingw}); + // runner.set_attr("dilations", vector{1, 1, dilationh, dilationw}); + // runner.set_attr("groups", groups); + // runner.set_attr("data_format", "NCHW"); + // // runner.set_attr("dataFormat", "NCHW"); + // // runner.set_attr("data_format", "NCHW"); + // // runner.set_attr("data_origin_format", "NCHW"); + // ASSERT(xformat == "abcd" && yformat == "abcd" && wformat == "oihw"); + // runner.run(); + // */ + // } + // }; + // auto _op = (ConvBackwardWOp *)op; + // _op->run_acl(); + // }}, + // {"cub_arg_reduce", } + }; + + static void exec_mapped_acl_ops(Op *op) + { + auto iter = acl_ops.find(op->name()); + if (iter != acl_ops.end()) + { + LOGv << "exec acl op " << op->name() << op; + iter->second(op); + } + else + { + LOGf << "op " << op->name() << " not supported"; + } + } + + static jit_op_entry_t acl_do_compile(Op *op) + { + LOGv << "compile" << op; + OpCompiler oc(op); + string *src = &oc.src; + for (auto op_type : op_types) + op_type->post_pass(&oc); + string src_after_passes; + // if is fused op + if (oc.op) + { + TunerManager tm(&oc); + src_after_passes = tm.tune(); + src = &src_after_passes; + } + op->compile_optimize(*src); + if (!op->flags.get(NodeFlags::_cuda)) + { + LOGv << "compile cpu"; + return oc.compile(op->get_jit_key(get_jk()), *src); + } + if (op->name() == string("fused")) + { + FusedOp *fop = (FusedOp *)op; + // if is a relayed op + if (fop->context->vrm.relay_groups.size()) + { + LOGv << "relay fused op"; + return oc.compile(op->get_jit_key(get_jk()), *src); + } + else + { + return &try_exec_and_fallback_cpu; + } + } + else if (op->name() == string("code")) + { + CodeOp *cop = (CodeOp *)op; + if (cop->cuda_src.find("acl") != string::npos) + { + LOGv << "compile acl op"; + return oc.compile(op->get_jit_key(get_jk()), *src); + } + else + { + return &exec_mapped_acl_ops; + } + } + else + { + LOGv << "compile finish" << op; + return &exec_mapped_acl_ops; + } + return do_compile_inner(op); + } + + // from op_register.cc + extern unordered_map op_info_map; + + void init_acl_ops() + { + do_compile_hook = acl_do_compile; + vector to_erase; + for (auto &kv : op_info_map) + { + if (startswith(kv.first, "cu") && acl_ops.count(kv.first) == 0) + { + to_erase.push_back(kv.first); + } + } + for (auto &k : to_erase) + { + LOGv << "op not supported: " << k << ", erase it."; + op_info_map.erase(k); + } + } + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/acl/aclnn/aclnn.cc b/python/jittor/extern/acl/aclnn/aclnn.cc new file mode 100644 index 00000000..c1e089e7 --- /dev/null +++ b/python/jittor/extern/acl/aclnn/aclnn.cc @@ -0,0 +1,58 @@ +#include +#include +#include "aclnn.h" + +int64_t GetShapeSize(const std::vector& shape) { + int64_t shapeSize = 1; + for (auto i : shape) { + shapeSize *= i; + } + return shapeSize; +} + +void PrintOutResult(std::vector &shape, void** deviceAddr) { + auto size = GetShapeSize(shape); + std::vector resultData(size, 0); + auto ret = aclrtMemcpy(resultData.data(), resultData.size() * sizeof(resultData[0]), + *deviceAddr, size * sizeof(resultData[0]), ACL_MEMCPY_DEVICE_TO_HOST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", ret); return); + for (int64_t i = 0; i < size; i++) { + LOG_PRINT("mean result[%ld] is: %d\n", i, resultData[i]); + } +} + +int Init(int32_t deviceId) { + // 固定写法,AscendCL初始化 + auto ret = aclInit(nullptr); + CHECK_RET(ret == ACL_SUCCESS or ret == 100002, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret); + ret = aclrtSetDevice(deviceId); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret); + //ret = aclrtCreateStream(stream); + //CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret); + return 0; +} + +/* +template +int CreateAclTensor(const std::vector& hostData, const std::vector& shape, void** deviceAddr, + aclDataType dataType, aclTensor** tensor) { + auto size = GetShapeSize(shape) * sizeof(T); + // 调用aclrtMalloc申请device侧内存 + auto ret = aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMalloc failed. ERROR: %d\n", ret); return ret); + // 调用aclrtMemcpy将host侧数据拷贝到device侧内存上 + ret = aclrtMemcpy(*deviceAddr, size, hostData.data(), size, ACL_MEMCPY_HOST_TO_DEVICE); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy failed. ERROR: %d\n", ret); return ret); + + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + + // 调用aclCreateTensor接口创建aclTensor + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), *deviceAddr); + return 0; +}*/ + diff --git a/python/jittor/extern/acl/aclnn/aclnn.h b/python/jittor/extern/acl/aclnn/aclnn.h new file mode 100644 index 00000000..5f5315de --- /dev/null +++ b/python/jittor/extern/acl/aclnn/aclnn.h @@ -0,0 +1,94 @@ +#include +#include +#include "acl.h" +// unary +#include "aclnnop/aclnn_abs.h" +#include "aclnnop/aclnn_neg.h" +#include "aclnnop/aclnn_exp.h" +#include "aclnnop/aclnn_log.h" +#include "aclnnop/aclnn_sqrt.h" +#include "aclnnop/aclnn_ceil.h" +#include "aclnnop/aclnn_floor.h" +#include "aclnnop/aclnn_round.h" +#include "aclnnop/aclnn_sin.h" +#include "aclnnop/aclnn_cos.h" +#include "aclnnop/aclnn_tan.h" +#include "aclnnop/aclnn_asin.h" +#include "aclnnop/aclnn_acos.h" +#include "aclnnop/aclnn_atan.h" +#include "aclnnop/aclnn_sinh.h" +#include "aclnnop/aclnn_cosh.h" +#include "aclnnop/aclnn_tanh.h" +#include "aclnnop/aclnn_asinh.h" +#include "aclnnop/aclnn_acosh.h" +#include "aclnnop/aclnn_atanh.h" +#include "aclnnop/aclnn_sigmoid.h" +#include "aclnnop/aclnn_erf.h" +#include "aclnnop/aclnn_erfinv.h" +#include "aclnnop/aclnn_logical_not.h" +#include "aclnnop/aclnn_bitwise_not.h" +#include "aclnnop/aclnn_cast.h" +// binary +#include "aclnnop/aclnn_maximum.h" +#include "aclnnop/aclnn_minimum.h" +#include "aclnnop/aclnn_add.h" +#include "aclnnop/aclnn_sub.h" +#include "aclnnop/aclnn_mul.h" +#include "aclnnop/aclnn_div.h" +#include "aclnnop/aclnn_floor_divide.h" +#include "aclnnop/aclnn_le_tensor.h" +#include "aclnnop/aclnn_lt_tensor.h" +#include "aclnnop/aclnn_ge_tensor.h" +#include "aclnnop/aclnn_gt_tensor.h" +#include "aclnnop/aclnn_eq_tensor.h" +#include "aclnnop/aclnn_ne_tensor.h" +#include "aclnnop/aclnn_logical_and.h" +#include "aclnnop/aclnn_logical_or.h" +#include "aclnnop/aclnn_logical_xor.h" +#include "aclnnop/aclnn_bitwise_and_tensor.h" +#include "aclnnop/aclnn_bitwise_or_tensor.h" +#include "aclnnop/aclnn_bitwise_xor_tensor.h" +#include "aclnnop/aclnn_pow_tensor_tensor.h" +#include "aclnnop/aclnn_expand.h" +#include "aclnnop/aclnn_matmul.h" +#include "aclnnop/aclnn_batch_matmul.h" +#include "aclnnop/aclnn_convolution.h" +#include "aclnnop/aclnn_convolution_backward.h" +#include "aclnnop/aclnn_reduce_sum.h" +#include "aclnnop/aclnn_amax.h" +#include "aclnnop/aclnn_amin.h" +#include "aclnnop/aclnn_mean.h" +#include "aclnnop/aclnn_prod.h" +#include "aclnnop/aclnn_triu.h" +#include "aclnnop/aclnn_s_where.h" +#include "aclnnop/aclnn_random.h" +#include "aclnnop/aclnn_normal.h" +#include "aclnnop/aclnn_permute.h" + + +#define CHECK_RET(cond, return_expr) \ + do \ + { \ + if (!(cond)) \ + { \ + return_expr; \ + } \ + } while (0) + +#define LOG_PRINT(message, ...) \ + do \ + { \ + printf(message, ##__VA_ARGS__); \ + } while (0) + +int64_t GetShapeSize(const std::vector &shape); + +void PrintOutResult(std::vector &shape, void **deviceAddr); + +int Init(int32_t deviceId); + +/* +template +int CreateAclTensor(const std::vector& hostData, const std::vector& shape, void** deviceAddr, + aclDataType dataType, aclTensor** tensor); +*/ diff --git a/python/jittor/extern/acl/tmp_file.cpp b/python/jittor/extern/acl/tmp_file.cpp new file mode 100644 index 00000000..6a80d235 --- /dev/null +++ b/python/jittor/extern/acl/tmp_file.cpp @@ -0,0 +1,306 @@ +#include +#include +#include +#include + +namespace jittor +{ + int CreateAclTensor(const std::vector &shape, void *deviceAddr, int64_t size, + aclDataType dataType, aclTensor **tensor, bool use_nchw = false) + { + // 计算连续tensor的strides + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) + { + strides[i] = shape[i + 1] * strides[i + 1]; + } + if (shape.size() == 0) + strides = {}; + // 调用aclCreateTensor接口创建aclTensor + if (use_nchw) + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_NCHW, + shape.data(), shape.size(), deviceAddr); + else + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), deviceAddr); + return 0; + } + + struct AclOpRunner + { + string name; + string jt_name; + vector in_; + vector out_; + std::unique_ptr op_attr; + + AclOpRunner(const string &name) : name(name) + { + } + + ~AclOpRunner() + { + } + + aclDataType get_dtype(NanoString s) + { + if (s == ns_float32) + return ACL_FLOAT; + if (s == ns_float16) + return ACL_FLOAT16; + if (s == ns_int64) + return ACL_INT64; + if (s == ns_int32) + return ACL_INT32; + if (s == ns_int8) + return ACL_INT8; + if (s == ns_int16) + return ACL_INT16; + if (s == ns_uint8) + return ACL_UINT8; + if (s == ns_uint16) + return ACL_UINT16; + if (s == ns_uint32) + return ACL_UINT32; + if (s == ns_bool) + return ACL_BOOL; + LOGf << "Not supported dtype: " << s; + return ACL_FLOAT; + } + + void add(Var *v, bool is_input) + { + + if (is_input) + { + in_.push_back(v); + } + else + { + out_.push_back(v); + } + return; + } + + template + std::vector createVector(int64_t size) + { + return std::vector(size, 0); + } + + void run() + { + // LOGir << name << " " << jt_name; + auto it = aclOpFuncMap.find(name); + if (it == aclOpFuncMap.end()) + { + LOGir << "Not supported op: " << name; + throw std::runtime_error("Unsupported operation type."); + } + + // 0. 算子的输入、输出、需要的attr定义 + std::vector> inputShapes; + std::vector> outputShapes; + + // for reduce + std::vector axes; + aclIntArray *dim = nullptr; + + bool use_nchw = false; + + auto input_num = in_.size(); + + auto output_num = out_.size(); + + for (int input_idx = 0; input_idx < input_num; input_idx++) + { + std::vector shape; + for (int j = 0; j < in_[input_idx]->shape.size(); j++) + { + shape.push_back(in_[input_idx]->shape[j]); + } + inputShapes.push_back(shape); + } + for (int output_idx = 0; output_idx < output_num; output_idx++) + { + std::vector shape; + for (int j = 0; j < out_[output_idx]->shape.size(); j++) + { + shape.push_back(out_[output_idx]->shape[j]); + } + outputShapes.push_back(shape); + } + + // 1. 创建aclTensor和aclScalar,不同算子可能不一样,需要根据具体API的接口定义修改 + std::vector inputTensors; + std::vector outputTensors; + + // for add and sub + aclScalar *alpha = nullptr; + + // for expand + aclIntArray *size = nullptr; + + // for add and sub + float alphaValue = 1.0f; + + // for conv + aclIntArray *strides = nullptr; + aclIntArray *pads = nullptr; + aclIntArray *outPads = nullptr; + aclIntArray *dilations = nullptr; + int ret = -1; + + if (name == string("Add") || name == string("Sub")) + { + alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT); + CHECK_RET(alpha != nullptr, return); + } + + if (jt_name == "conv" || jt_name == "conv2d" || jt_name == "conv2dbackward") + use_nchw = true; + + for (int idx = 0; idx < input_num; idx++) + { + inputTensors.push_back(nullptr); + auto ret = CreateAclTensor(inputShapes[idx], in_[idx]->mem_ptr, in_[idx]->size, get_dtype(in_[idx]->dtype()), &inputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + + if (jt_name == "reduce") + { + auto attr = dynamic_cast(op_attr.get()); + dim = aclCreateIntArray(attr->axes.data(), attr->axes.size()); + + if (name == string("ReduceMax") || name == string("ReduceMin") || name == string("ReduceMean") || name == string("ReduceProd")) + { + if (attr->axes.size() == in_[0]->shape.size()) + outputShapes[0] = {}; + } + } + for (int idx = 0; idx < output_num; idx++) + { + outputTensors.push_back(nullptr); + auto ret = CreateAclTensor(outputShapes[idx], out_[idx]->mem_ptr, out_[idx]->size, get_dtype(out_[idx]->dtype()), &outputTensors[idx], use_nchw); + CHECK_RET(ret == ACL_SUCCESS, return); + } + + // 2. 调用CANN算子库aclnnxxxGetWorkspaceSize的接口,两段式接口的第一个 + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + + if (name == string("Add") || name == string("Sub")) + ret = it->second.getWorkspaceSizeFuncAdd(inputTensors[0], inputTensors[1], alpha, outputTensors[0], &workspaceSize, &executor); + else if (name == string("Expand")) + { + size = aclCreateIntArray(&outputShapes[0][0], outputShapes[0].size()); + ret = it->second.getWorkspaceSizeFuncExpand(inputTensors[0], size, outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("Cast")) + ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + else if (jt_name == "unary") + ret = it->second.getWorkspaceSizeFuncUnary(inputTensors[0], outputTensors[0], &workspaceSize, &executor); + else if (jt_name == "binary") + ret = it->second.getWorkspaceSizeFuncBinary(inputTensors[0], inputTensors[1], outputTensors[0], &workspaceSize, &executor); + else if (jt_name == "bmm" || jt_name == "matmul") + ret = it->second.getWorkspaceSizeFuncMatmul(inputTensors[0], inputTensors[1], outputTensors[0], 1, &workspaceSize, &executor); + else if (name == string("ReduceSum") || name == string("ReduceMean")) + { + ret = it->second.getWorkspaceSizeFuncReduceSum(inputTensors[0], dim, false, get_dtype(out_[0]->dtype()), outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("ReduceMax") || name == string("ReduceMin")) + { + ret = it->second.getWorkspaceSizeFuncAmax(inputTensors[0], dim, false, outputTensors[0], &workspaceSize, &executor); + } + // else if (name == string("ReduceProd")) + // { + // ret = it->second.getWorkspaceSizeFuncReduceProd(inputTensors[0], dim, false, outputTensors[0], &workspaceSize, &executor); + // } + else if (name == string("Select")) + { + ret = it->second.getWorkspaceSizeFuncSelect(inputTensors[0], inputTensors[1], inputTensors[2], outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("Triu")) + { + auto attr = dynamic_cast(op_attr.get()); + ret = it->second.getWorkspaceSizeFuncCast(inputTensors[0], aclDataType(attr->diagonal), outputTensors[0], &workspaceSize, &executor); + } + else if (name == string("Conv2d")) + { + auto attr = dynamic_cast(op_attr.get()); + strides = aclCreateIntArray(attr->convStrides.data(), 2); + pads = aclCreateIntArray(attr->convPads.data(), 2); + outPads = aclCreateIntArray(attr->convOutPads.data(), 2); + dilations = aclCreateIntArray(attr->convDilations.data(), 2); + + ret = it->second.getWorkspaceSizeFuncConv(inputTensors[0], inputTensors[1], nullptr, strides, pads, dilations, false, outPads, attr->group, outputTensors[0], 0, &workspaceSize, &executor); + } + else if (name == string("Conv2dBackward")) + { + auto attr = dynamic_cast(op_attr.get()); + strides = aclCreateIntArray(attr->convStrides.data(), 2); + pads = aclCreateIntArray(attr->convPads.data(), 2); + outPads = aclCreateIntArray(attr->convOutPads.data(), 2); + dilations = aclCreateIntArray(attr->convDilations.data(), 2); + bool outputMask[3] = {true, true, false}; + aclBoolArray *outMask = aclCreateBoolArray(outputMask, 3); + ret = it->second.getWorkspaceSizeFuncConvBackward(inputTensors[0], inputTensors[1], inputTensors[2], nullptr, strides, pads, dilations, false, outPads, attr->group, outMask, 0, outputTensors[0], outputTensors[1], nullptr, &workspaceSize, &executor); + } + else + LOGf << "not supported op " << jt_name; + + // for debug + if (ret != ACL_SUCCESS) + { + auto tmp_err_msg = aclGetRecentErrMsg(); + LOGir << tmp_err_msg; + } + + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxxGetWorkspaceSize failed. ERROR: %d\n", name.c_str(), ret); return); + + // 4. 根据第一段接口计算出的workspaceSize申请device内存 + void *workspaceAddr = nullptr; + if (workspaceSize > 0) + { + ret = aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: allocate workspace failed. ERROR: %d\n", name.c_str(), ret); return); + } + + // 5. 调用aclnnxx第二段接口 + ret = it->second.executeFunc(workspaceAddr, workspaceSize, executor, aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclnnxxx failed. ERROR: %d\n", name.c_str(), ret); return); + + // 6. (固定写法)同步等待任务执行结束 + ret = aclrtSynchronizeStream(aclstream); + CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("%s: aclrtSynchronizeStream failed. ERROR: %d\n", name.c_str(), ret); return); + + // 7. 释放aclTensor和aclScalar,需要根据具体API的接口定义修改 + // destroy tensor + for (int idx = 0; idx < input_num; idx++) + { + aclDestroyTensor(inputTensors[idx]); + } + for (int idx = 0; idx < output_num; idx++) + { + aclDestroyTensor(outputTensors[idx]); + } + // destroy scalar + aclDestroyScalar(alpha); + + // destroy IntArray + aclDestroyIntArray(size); + aclDestroyIntArray(dim); + aclDestroyIntArray(strides); + aclDestroyIntArray(pads); + aclDestroyIntArray(outPads); + aclDestroyIntArray(dilations); + + // 8. 释放device资源 + if (workspaceSize > 0) + { + aclrtFree(workspaceAddr); + } + return; + } + }; +} \ No newline at end of file diff --git a/python/jittor/extern/corex/corex_compiler.py b/python/jittor/extern/corex/corex_compiler.py new file mode 100644 index 00000000..e28d1128 --- /dev/null +++ b/python/jittor/extern/corex/corex_compiler.py @@ -0,0 +1,98 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +from jittor_utils import env_or_try_find +import jittor_utils +import ctypes +import glob +import jittor.compiler as compiler + +has_corex = 0 +cc_flags = "" +compiler.has_corex = has_corex + +def install(): + import jittor.compiler as compiler + global has_corex, cc_flags + acl_compiler_home = os.path.dirname(__file__) + cc_files = sorted(glob.glob(acl_compiler_home+"/**/*.cc", recursive=True)) + jittor_utils.LOG.i("COREX detected") + + mod = jittor_utils.compile_module(''' +#include "common.h" +#include "utils/str_utils.h" + +namespace jittor { +// @pyjt(process) +string process_acl(const string& src, const string& name, const map& kargs) { + auto new_src = src; + new_src = replace(new_src, "helper_cuda.h", "../inc/helper_cuda.h"); + if (name == "string_view_map.h") + new_src = replace(new_src, "using std::string_view;", "using string_view = string;"); + if (name == "nan_checker.cu") + new_src = replace(new_src, "__trap()", "assert(0)"); + if (name == "jit_compiler.cc") { + // remove asm tuner + new_src = token_replace_all(new_src, "cmd = python_path$1;", ""); + new_src = token_replace_all(new_src, "JPU(op_compiler($1));", + R"(JPU(op_compiler($1)); + *extra_flags2 = replace(*extra_flags2, "--extended-lambda", ""); + *extra_flags2 = replace(*extra_flags2, "--expt-relaxed-constexpr", ""); + )"); + new_src = token_replace_all(new_src, + "if (is_cuda_op && $1 != string::npos)", + "if (is_cuda_op)"); + } + if (name == "where_op.cc") { + // default where kernel cannot handle 64 warp size, use cub_where instead + new_src = token_replace_all(new_src, "if (cub_where$1) {", "if (cub_where) {"); + } + if (name == "loop_var_analyze_pass.cc") { + new_src = token_replace_all(new_src, "DEFINE_FLAG($1, para_opt_level,$2,$3);", + "DEFINE_FLAG($1, para_opt_level, 4,$3);"); + } + return new_src; +} +}''', compiler.cc_flags + " " + " ".join(cc_files) + cc_flags) + jittor_utils.process_jittor_source("corex", mod.process) + # def nvcc_flags_to_corex(nvcc_flags): + # nvcc_flags = nvcc_flags.replace("--cudart=shared", "") + # nvcc_flags = nvcc_flags.replace("--cudart=shared", "") + + has_corex = 1 + compiler.has_corex = has_corex + corex_home = "/usr/local/corex" + compiler.nvcc_path = corex_home + "/bin/clang++" + compiler.cc_path = compiler.nvcc_path + compiler.cc_flags = compiler.cc_flags.replace("-fopenmp", "") + # compiler.nvcc_flags = cc_flags_to_corex(compiler.cc_flags) + compiler.nvcc_flags = compiler.cc_flags + " -x cu -Ofast -DNO_ATOMIC64 -Wno-c++11-narrowing " + compiler.convert_nvcc_flags = lambda x:x + compiler.is_cuda = 0 + os.environ["use_cutt"] = "0" + compiler.cc_type = "clang" + + +def install_extern(): + return False + + +def check(): + global has_corex, cc_flags + if os.path.isdir("/usr/local/corex"): + try: + install() + except Exception as e: + jittor_utils.LOG.w(f"load COREX failed, exception: {e}") + has_corex = 0 + if not has_corex: return False + return True + +def post_process(): + if not has_corex: return + import jittor.compiler as compiler + compiler.flags.cc_flags = compiler.flags.cc_flags.replace("-fopenmp", "") \ No newline at end of file diff --git a/python/jittor/extern/cuda/cub/inc/cub_test.h b/python/jittor/extern/cuda/cub/inc/cub_test.h new file mode 100644 index 00000000..e8924a77 --- /dev/null +++ b/python/jittor/extern/cuda/cub/inc/cub_test.h @@ -0,0 +1,253 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/****************************************************************************** + * Simple example of DeviceRadixSort::SortPairs(). + * + * Sorts an array of float keys paired with a corresponding array of int values. + * + * To compile using the command line: + * nvcc -arch=sm_XX example_device_radix_sort.cu -I../.. -lcudart -O3 + * + ******************************************************************************/ + +// Ensure printing of CUDA runtime errors to console +#define CUB_STDERR + +#include +#include +#include + +#include +#include + +#include + +using namespace cub; + + +//--------------------------------------------------------------------- +// Globals, constants and typedefs +//--------------------------------------------------------------------- + +bool g_verbose = false; // Whether to display input/output to console + + +//--------------------------------------------------------------------- +// Test generation +//--------------------------------------------------------------------- + +/** + * Simple key-value pairing for floating point types. Distinguishes + * between positive and negative zero. + */ +struct Pair +{ + float key; + int value; + + bool operator<(const Pair &b) const + { + if (key < b.key) + return true; + + if (key > b.key) + return false; + + // Return true if key is negative zero and b.key is positive zero + unsigned int key_bits = *reinterpret_cast(const_cast(&key)); + unsigned int b_key_bits = *reinterpret_cast(const_cast(&b.key)); + unsigned int HIGH_BIT = 1u << 31; + + return ((key_bits & HIGH_BIT) != 0) && ((b_key_bits & HIGH_BIT) == 0); + } +}; + + +/** + * Initialize key-value sorting problem. + */ +void Initialize( + float *h_keys, + int *h_values, + float *h_reference_keys, + int *h_reference_values, + int num_items) +{ + Pair *h_pairs = new Pair[num_items]; + + for (int i = 0; i < num_items; ++i) + { + RandomBits(h_keys[i]); + h_values[i] = i; + h_pairs[i].key = h_keys[i]; + h_pairs[i].value = h_values[i]; + } + + if (g_verbose) + { + printf("Input keys:\n"); + DisplayResults(h_keys, num_items); + printf("\n\n"); + + printf("Input values:\n"); + DisplayResults(h_values, num_items); + printf("\n\n"); + } + + std::stable_sort(h_pairs, h_pairs + num_items); + + for (int i = 0; i < num_items; ++i) + { + h_reference_keys[i] = h_pairs[i].key; + h_reference_values[i] = h_pairs[i].value; + } + + if (g_verbose) + { + printf("std Output keys:\n"); + DisplayResults(h_reference_keys, num_items); + printf("\n\n"); + + printf("std Output values:\n"); + DisplayResults(h_reference_values, num_items); + printf("\n\n"); + } + delete[] h_pairs; +} + + +//--------------------------------------------------------------------- +// Main +//--------------------------------------------------------------------- + +/** + * Main + */ +int cub_test_entry(int argc, char** argv) +{ + CachingDeviceAllocator g_allocator(true); // Caching allocator for device memory + + int num_items = 150; + + // Initialize command line + CommandLineArgs args(argc, argv); + g_verbose = args.CheckCmdLineFlag("v"); + args.GetCmdLineArgument("n", num_items); + + // Print usage + if (args.CheckCmdLineFlag("help")) + { + printf("%s " + "[--n= " + "[--device=] " + "[--v] " + "\n", argv[0]); + exit(0); + } + + // Initialize device + CubDebugExit(args.DeviceInit()); + + printf("cub::DeviceRadixSort::SortPairs() %d items (%d-byte keys %d-byte values)\n", + num_items, int(sizeof(float)), int(sizeof(int))); + fflush(stdout); + + // Allocate host arrays + float *h_keys = new float[num_items]; + float *h_reference_keys = new float[num_items]; + int *h_values = new int[num_items]; + int *h_reference_values = new int[num_items]; + + // Initialize problem and solution on host + Initialize(h_keys, h_values, h_reference_keys, h_reference_values, num_items); + + // Allocate device arrays + DoubleBuffer d_keys; + DoubleBuffer d_values; + CubDebugExit(g_allocator.DeviceAllocate((void**)&d_keys.d_buffers[0], sizeof(float) * num_items)); + CubDebugExit(g_allocator.DeviceAllocate((void**)&d_keys.d_buffers[1], sizeof(float) * num_items)); + CubDebugExit(g_allocator.DeviceAllocate((void**)&d_values.d_buffers[0], sizeof(int) * num_items)); + CubDebugExit(g_allocator.DeviceAllocate((void**)&d_values.d_buffers[1], sizeof(int) * num_items)); + + // Allocate temporary storage + size_t temp_storage_bytes = 0; + void *d_temp_storage = NULL; + + CubDebugExit(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items)); + CubDebugExit(g_allocator.DeviceAllocate(&d_temp_storage, temp_storage_bytes)); + + // Initialize device arrays + CubDebugExit(cudaMemcpy(d_keys.d_buffers[d_keys.selector], h_keys, sizeof(float) * num_items, cudaMemcpyHostToDevice)); + CubDebugExit(cudaMemcpy(d_values.d_buffers[d_values.selector], h_values, sizeof(int) * num_items, cudaMemcpyHostToDevice)); + + // Run + CubDebugExit(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, d_keys, d_values, num_items)); + + // Check for correctness (and display results, if specified) + std::unique_ptr d_keys_ptr(new float[num_items]); + std::unique_ptr d_values_ptr(new int[num_items]); + std::unique_ptr origin(new float[num_items]); + cudaMemcpy(d_keys_ptr.get(), d_keys.Current(), sizeof(float) * num_items, cudaMemcpyDeviceToHost); + cudaMemcpy(d_values_ptr.get(), d_values.Current(), sizeof(int) * num_items, cudaMemcpyDeviceToHost); + bool ok = true; + for (int i=0; i +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "cub_arg_reduce_op.h" +#include +#include "executor.h" +#include "ops/arg_reduce_op.h" +#ifdef JIT_cuda +#include +#include +#endif + +namespace jittor { + +#ifndef JIT +CubArgReduceOp::CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims) + : x(x), offsets(offsets), op(op), keepdims(keepdims) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + ASSERT(offsets->dtype()==ns_int32); + y = create_output(nullptr, ns_int32); + y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); +} + +VarPtr CubArgReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return ArgReduceOp::get_grad(out, dout, v, v_index, v->shape.size()-1, y); +} + +void CubArgReduceOp::infer_shape() { + int n = 1; + for (int i = 0; i < x->shape.size(); ++i) { + if (i < x->shape.size() - 1) { + n *= x->shape[i]; + } + } + ASSERT(offsets->shape.size() == 1); + ASSERT(offsets->shape[0] == n + 1); + NanoVector shape; + for (int i = 0; i < x->shape.size() - 1; ++i) { + shape.push_back(x->shape[i]); + } + if (keepdims) { + shape.push_back(1); + } + if (shape.size() == 0) + shape.push_back(1); + y->set_shape(shape); + y_key->set_shape(shape); +} + +void CubArgReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Toffsets:" << offsets->dtype(); + jk << "«FUNC:"; + if (op==ns_minimum) + jk << "ArgMin"; + else + jk << "ArgMax"; +} + +#else // JIT +#ifdef JIT_cuda + +static __global__ void split(cub::KeyValuePair* a, Tx* key, int* val, int n) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + for (int i=tid; iptr(); + auto* __restrict__ offsetsp = offsets->ptr(); + + int num_segments = 1; + for (int i = 0; i < x->shape.size() - 1; ++i) { + num_segments *= x->shape[i]; + } + size_t allocation_dout; + cub::KeyValuePair *d_out = (cub::KeyValuePair *)exe.temp_allocator->alloc(sizeof(cub::KeyValuePair) * num_segments, allocation_dout); + + // Determine temporary device storage requirementse = NULL; + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceSegmentedReduce::@FUNC@@(d_temp_storage, temp_storage_bytes, + xp, d_out, num_segments, offsetsp, offsetsp + 1); + // Allocate temporary storage + size_t allocation; + d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation); + // Run sorting operation + cub::DeviceSegmentedReduce::@FUNC@@(d_temp_storage, temp_storage_bytes, + xp, d_out, num_segments, offsetsp, offsetsp + 1); + + auto* __restrict__ yp = y->ptr(); + auto* __restrict__ y_keyp = y_key->ptr(); + split<<>>(d_out, y_keyp, yp, num_segments); + + exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation); + exe.temp_allocator->free(d_out, sizeof(cub::KeyValuePair) * num_segments, allocation_dout); +} +#endif // JIT_cuda +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.h b/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.h new file mode 100644 index 00000000..8bef8691 --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_arg_reduce_op.h @@ -0,0 +1,29 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct CubArgReduceOp : Op { + Var* x, * offsets, * y, * y_key; + NanoString op; + bool keepdims; + // @attrs(multiple_outputs) + CubArgReduceOp(Var* x, Var* offsets, NanoString op, bool keepdims); + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + + const char* name() const override { return "cub_arg_reduce"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc new file mode 100644 index 00000000..d0503528 --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.cc @@ -0,0 +1,100 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "cub_argsort_op.h" +#include +#include "executor.h" +#include "ops/argsort_op.h" +#ifdef JIT_cuda +#include +#endif + +namespace jittor { + +#ifndef JIT +CubArgsortOp::CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending, NanoString dtype) + : x(x), indexes(indexes), offsets(offsets), descending(descending) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + ASSERT(offsets->dtype()==ns_int32); + y = create_output(nullptr, dtype); + y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); +} + +VarPtr CubArgsortOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return ArgsortOp::get_grad(out, dout, v, v_index, v->shape.size()-1, y); +} + +void CubArgsortOp::infer_shape() { + ASSERT(x->shape.size() == indexes->shape.size()); + int n = 1; + for (int i = 0; i < x->shape.size(); ++i) { + ASSERT(x->shape[i] == indexes->shape[i]); + if (i < x->shape.size() - 1) { + n *= x->shape[i]; + } + } + ASSERT(offsets->shape.size() == 1); + ASSERT(offsets->shape[0] == n + 1); + y->set_shape(x->shape); + y_key->set_shape(x->shape); +} + +void CubArgsortOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Tindexes:" << indexes->dtype(); + jk << "«Toffsets:" << offsets->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«FUNC:"; + if (descending) + jk << "SortPairsDescending"; + else + jk << "SortPairs"; +} + +#else // JIT +#ifdef JIT_cuda +void CubArgsortOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ indexesp = indexes->ptr(); + auto* __restrict__ offsetsp = offsets->ptr(); + + int num_items = 1, num_segments = 1; + for (int i = 0; i < x->shape.size(); ++i) { + num_items *= x->shape[i]; + if (i < x->shape.size() - 1) { + num_segments *= x->shape[i]; + } + } + auto* __restrict__ yp = y->ptr(); + auto* __restrict__ y_keyp = y_key->ptr(); + + // Determine temporary device storage requirementse = NULL; + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + cub::DeviceSegmentedRadixSort::@FUNC@@(d_temp_storage, temp_storage_bytes, + xp, y_keyp, indexesp, yp, + num_items, num_segments, offsetsp, offsetsp + 1); + // Allocate temporary storage + size_t allocation; + d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, allocation); + // Run sorting operation + cub::DeviceSegmentedRadixSort::@FUNC@@(d_temp_storage, temp_storage_bytes, + xp, y_keyp, indexesp, yp, + num_items, num_segments, offsetsp, offsetsp + 1); + exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, allocation); +} +#endif // JIT_cuda +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cub/ops/cub_argsort_op.h b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.h new file mode 100644 index 00000000..47659101 --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_argsort_op.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct CubArgsortOp : Op { + Var* x, * indexes, * offsets, * y, * y_key; + bool descending; + // @attrs(multiple_outputs) + CubArgsortOp(Var* x, Var* indexes, Var* offsets, bool descending=false, NanoString dtype=ns_int32); + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + + const char* name() const override { return "cub_argsort"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.cc b/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.cc new file mode 100644 index 00000000..ca8d0687 --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.cc @@ -0,0 +1,132 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "cub_cumsum_op.h" +#include +#include "executor.h" +#include "ops/op_register.h" +#ifdef JIT_cuda +#include +#include +#include +#endif + +namespace jittor { + +#ifndef JIT + +static auto make_cub_cumsum = get_op_info("cub_cumsum") + .get_constructor(); + +CubCumsumOp::CubCumsumOp(Var* x, bool reverse) : x(x),reverse(reverse) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void CubCumsumOp::infer_shape() { + ASSERT(x->shape.size() == 1 || x->shape.size() == 2); //TODO:support batch_cumsum + y->set_shape(x->shape); +} + +void CubCumsumOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«reverse:" << reverse; +} + +VarPtr CubCumsumOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return make_cub_cumsum(dout, !reverse); + // return ArgsortOp::get_grad(out, dout, v, v_index, v->shape.size()-1, y); +} + +#else // JIT +#ifdef JIT_cuda + +#define ITEMS_PER_THREAD 4 +#define BLOCK_THREADS 1024 + +__global__ void BlockScanKernel(Tx* __restrict__ xp, Ty* __restrict__ yp, int batch_num, int num_items) { + typedef cub::BlockScan BlockScanT; + __shared__ typename BlockScanT::TempStorage temp_storage; + + int batch_id = blockIdx.x; + int offset = threadIdx.x * ITEMS_PER_THREAD; + __shared__ Tx prefix_sum[1]; + prefix_sum[0] = 0; + + for (int block_offset = offset; block_offset < num_items; block_offset += BLOCK_THREADS * ITEMS_PER_THREAD) { + int items = ITEMS_PER_THREAD; + if (block_offset + ITEMS_PER_THREAD > num_items) { + items = num_items - block_offset; + } + Tx thread_data[ITEMS_PER_THREAD] = {0}; + #pragma unroll + for (int i = 0; i < ITEMS_PER_THREAD; ++i) { + if (iptr(); + auto* __restrict__ yp = y->ptr(); + if (x->shape.size() == 1){ + int num_items = x->shape[0]; + + // Determine temporary device storage requirements for inclusive prefix sum + void *d_temp_storage = NULL; + size_t temp_storage_bytes = 0, temp_storage_allocation; + cub::DeviceScan::InclusiveSum(NULL, temp_storage_bytes, xp, yp, num_items); + d_temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, temp_storage_allocation); + // Allocate temporary storage for inclusive prefix sum + // cudaMalloc(&d_temp_storage, temp_storage_bytes); + // Run inclusive prefix sum + if (reverse) { + auto xp_ = thrust::make_reverse_iterator(xp + num_items); + auto yp_ = thrust::make_reverse_iterator(yp + num_items); + cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, xp_, yp_, num_items); + } else { + cub::DeviceScan::InclusiveSum(d_temp_storage, temp_storage_bytes, xp, yp, num_items); + } + // yp <-- [8, 14, 21, 26, 29, 29, 38] + exe.temp_allocator->free(d_temp_storage, temp_storage_bytes, temp_storage_allocation); + } else { + int batch_num = x->shape[0]; + int num_items = x->shape[1]; + BlockScanKernel<<>>(xp, yp, batch_num, num_items); + } +} +#endif // JIT_cuda +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.h b/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.h new file mode 100644 index 00000000..aa17a6d5 --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_cumsum_op.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct CubCumsumOp : Op { + Var* x, * y; + bool reverse; + + CubCumsumOp(Var* x, bool reverse=false); + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + + void infer_shape() override; + const char* name() const override { return "cub_cumsum"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cub/ops/cub_test_op.cc b/python/jittor/extern/cuda/cub/ops/cub_test_op.cc new file mode 100644 index 00000000..42395b5e --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_test_op.cc @@ -0,0 +1,44 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "cub_test_op.h" +#include "utils/str_utils.h" + +#ifdef JIT +#include "cub_test.h" +#endif + +namespace jittor { + +#ifndef JIT +CubTestOp::CubTestOp(string cmd) : cmd(cmd) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + output = create_output(1, ns_float32); +} + +void CubTestOp::jit_prepare(JK& jk) { + jk << "«T:float32"; +} + +#else // JIT +#ifdef JIT_cuda +void CubTestOp::jit_run() { + auto args = split(cmd, " "); + if (!cmd.size()) args.clear(); + vector v(args.size()); + for (uint i=0; iptr()[0] = 123; +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cub/ops/cub_test_op.h b/python/jittor/extern/cuda/cub/ops/cub_test_op.h new file mode 100644 index 00000000..be8aa97d --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_test_op.h @@ -0,0 +1,22 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CubTestOp : Op { + Var* output; + string cmd; + + CubTestOp(string cmd); + + const char* name() const override { return "cub_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cub/ops/cub_where_op.cc b/python/jittor/extern/cuda/cub/ops/cub_where_op.cc new file mode 100644 index 00000000..a6c538bc --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_where_op.cc @@ -0,0 +1,116 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Xiangli Li <1905692338@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cub_where_op.h" +#ifdef JIT_cuda +#include "executor.h" +#include +#include "helper_cuda.h" +#include +#include +#include +#include +#include +#endif + +namespace jittor { + +#ifndef JIT +CubWhereOp::CubWhereOp(Var* cond, NanoString dtype) : cond(cond) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + auto ndim = cond->shape.size(); + outs.reset(new Var*[ndim]); + for (uint i=0; ishape.size(); + auto num = -cond->num; + for (uint i=0; iset_shape({num}); +} + +void CubWhereOp::jit_prepare(JK& jk) { + jk << "«Ti:" << cond->dtype(); + jk << "«To:" << outs[0]->dtype(); + jk << "«NDIM=" << JK::hex1(cond->shape.size()); +} + +#else // JIT +#ifdef JIT_cuda + +template +struct NonZeroOp +{ + __host__ __device__ __forceinline__ bool operator()(const T& a) const { + return (a!=T(0)); + } +}; + +__global__ static void where_kernel( + int n, + To* input + @for(i, 0, NDIM, 1, ,index_t shape_@i, To* out_@i) +) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = gridDim.x * blockDim.x; + for (index_t i=tid; inum; + size_t temp_storage_bytes=0; + size_t num_nonzeros_allocation; + auto num_nonzeros = exe.temp_allocator->alloc(sizeof(To), num_nonzeros_allocation); + + size_t temp_storage_allocation; + void* temp_storage; + + To* out_temp = outs[0]->ptr(); + + cub::CountingInputIterator counting_itr(0); + cub::TransformInputIterator, Ti*> itr(cond->ptr(), NonZeroOp()); + temp_storage_bytes = 0; + checkCudaErrors(cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, out_temp, (To*)num_nonzeros, N)); + temp_storage = exe.temp_allocator->alloc(temp_storage_bytes, temp_storage_allocation); + checkCudaErrors(cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (To*)num_nonzeros, N)); + exe.temp_allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation); + + To num_nonzeros_h; + cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(To), cudaMemcpyDeviceToHost); + @for(i, 0, NDIM, outs[@i]->set_shape({num_nonzeros_h});) + + if (num_nonzeros_h > 0 && NDIM > 1) { + int thread_num = std::min(1024, num_nonzeros_h); + int block_num = std::max(1, num_nonzeros_h/1024); + where_kernel<<>>( + num_nonzeros_h, + out_temp + @for(i, 0, NDIM, 1, , cond->shape[@i], outs[@i]->ptr()) + ); + } + exe.temp_allocator->free(num_nonzeros, sizeof(int), num_nonzeros_allocation); + +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cub/ops/cub_where_op.h b/python/jittor/extern/cuda/cub/ops/cub_where_op.h new file mode 100644 index 00000000..0497ffb1 --- /dev/null +++ b/python/jittor/extern/cuda/cub/ops/cub_where_op.h @@ -0,0 +1,41 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Xiangli Li <1905692338@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct CubWhereOp : Op { + Var* cond; + unique_ptr outs; + /** + Where Operator generate index of true condition. + + * [in] cond: condition for index generation + + * [in] dtype: type of return indexes + + * [out] out: return an array of indexes, same length with number of dims of cond + + Example:: + + jt.where([[0,0,1],[1,0,0]]) + # return ( [0,2], [1,0] ) + */ + // @attrs(multiple_outputs) + + CubWhereOp(Var* cond, NanoString dtype=ns_int32); + void infer_shape() override; + const char* name() const override { return "cub_where"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h new file mode 100644 index 00000000..7667e495 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/inc/cublas_wrapper.h @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include + +#include "utils/log.h" +#include "helper_cuda.h" +#include "fp16_emu.h" +#include "common.h" +#include "misc/nano_string.h" + +namespace jittor { + +EXTERN_LIB cublasHandle_t cublas_handle; + +static inline cudaDataType get_dtype(NanoString dtype) { + if (dtype == ns_float32) return CUDA_R_32F; + if (dtype == ns_float64) return CUDA_R_64F; + if (dtype == ns_float16) return CUDA_R_16F; + #ifndef IS_ROCM + if (dtype == ns_bfloat16) return CUDA_R_16BF; + #endif + LOGf << "not support type" << dtype; + return CUDA_R_32F; +} + +} // jittor diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.cc new file mode 100644 index 00000000..7ccfa00e --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.cc @@ -0,0 +1,138 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "var.h" +#include "cublas_acc_matmul_op.h" +#include "cublas_wrapper.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +#ifndef JIT + +CublasAccMatmulOp::CublasAccMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b, int stride_a, int stride_b, int offset_a, int offset_b) + : a(a), b(b), trans_a(trans_a), trans_b(trans_b),stride_a(stride_a),stride_b(stride_b),offset_a(offset_a),offset_b(offset_b) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + a->flags.set(NodeFlags::_needed_by_backward); + b->flags.set(NodeFlags::_needed_by_backward); + // TODO: support int8 * int8 + ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; + // TODO: support diffrent input type + ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same"; + c = create_output(nullptr, a->dtype()); +} + +void CublasAccMatmulOp::infer_shape() { + ASSERTop(a->shape.size(),==,2); + ASSERTop(b->shape.size(),==,2); + int n = a->shape[0], m = a->shape[1]; + int m_ = b->shape[0], k = b->shape[1]; + if (trans_a) { + swap(n, m); + } + if (trans_b) { + swap(m_, k); + } + ASSERTop(m,==,m_); + if(stride_a != -1) + n = stride_a; + if(stride_b != -1) + k = stride_b; + c->set_shape({n, k}); +} + +void CublasAccMatmulOp::jit_prepare(JK& jk) { + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); + jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); +} + +#else // JIT +#pragma clang diagnostic ignored "-Wtautological-compare" + +void CublasAccMatmulOp::jit_run() { + cublasHandle_t& handle_ = cublas_handle; + const T alpha = 1.0f; + const T beta = 0.0f; + + const auto& as = a->shape; + const auto& bs = b->shape; + auto n = as[0]; + auto m = as[1]; + auto k = bs[1]; + if ('@Trans_a'=='T') { + n = as[1]; + m = as[0]; + } + if ('@Trans_b'=='T') { + k = bs[0]; + } + bool has_fp16_or_bf16 = a->dtype() == ns_float16 + || b->dtype() == ns_float16 || c->dtype() == ns_float16 + || a->dtype() == ns_bfloat16 + || b->dtype() == ns_bfloat16 || c->dtype() == ns_bfloat16; + + // a: [n,m], b: [m,k], c: [n,k] + #if CUDART_VERSION >= 11000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; + if (use_tensorcore>=3) { + computeType = CUBLAS_COMPUTE_32F_FAST_16F; + } else if (use_tensorcore==2) { + computeType = CUBLAS_COMPUTE_32F_FAST_16BF; + } else if (use_tensorcore==1) { + computeType = CUBLAS_COMPUTE_32F_FAST_TF32; + } + if (has_fp16_or_bf16) { + computeType = CUBLAS_COMPUTE_16F; + } + #else + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + cudaDataType_t computeType = get_dtype(c->dtype()); + if (use_tensorcore) { + algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + if (has_fp16_or_bf16) { + computeType = CUDA_R_16F; + algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + #endif + int ldb, lda; + ldb = '@Trans_b' == 'N' ? k : m; + lda = '@Trans_a' == 'N' ? m : n; + if(stride_b != -1) + k = stride_b; + // if(stride_a != -1) + // n = stride_a; + checkCudaErrors(cublasGemmEx(handle_, + CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, + k, n, m, &alpha, + b->ptr() + offset_b,get_dtype(b->dtype()), ldb, + a->ptr() + offset_a,get_dtype(a->dtype()), lda, &beta, + c->ptr(),get_dtype(c->dtype()), k, + computeType, algo)); + // checkCudaErrors(cublas@op@@gemm(handle_, + // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, + // k, n, m, &alpha, + // b->ptr(), '@Trans_b' == 'N' ? k : m, + // a->ptr(), '@Trans_a' == 'N' ? m : n, &beta, + // c->ptr(), k)); + + +} +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.h b/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.h new file mode 100644 index 00000000..b24acdb7 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_acc_matmul_op.h @@ -0,0 +1,27 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CublasAccMatmulOp : Op { + Var* a, * b, * c; + bool trans_a, trans_b; + int stride_a, stride_b; + int offset_a, offset_b; + CublasAccMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b, int stride_a=-1, int stride_b=-1, int offset_a=0, int offset_b=0); + + const char* name() const override { return "cublas_acc_matmul"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc new file mode 100644 index 00000000..a69006b4 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.cc @@ -0,0 +1,180 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Meng-Hao Guo +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + + +// cublas_batched_matmul_op.cc +#include "var.h" + +#include "cublas_batched_matmul_op.h" +#include "cublas_wrapper.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +#ifndef JIT + +static auto make_cublas_batched_matmul = get_op_info("cublas_batched_matmul") + .get_constructor(); + +CublasBatchedMatmulOp::CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b) + : a(a), b(b), trans_a(trans_a), trans_b(trans_b) { + // TODO: support int8 * int8 + ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; + // TODO: support diffrent input type + ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same"; + c = create_output(nullptr, a->dtype()); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_manual_set_vnbb); + a->flags.set(NodeFlags::_needed_by_backward); + b->flags.set(NodeFlags::_needed_by_backward); +} + + +VarPtr CublasBatchedMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) { + // a [b,n,m] b [b,m,k], c[b,n,k] + // c = a*b + if (v_index == 0) { + if (trans_a) + return make_cublas_batched_matmul(b, dout, trans_b, 1); + else + // da = dc*b^T + return make_cublas_batched_matmul(dout, b, 0, trans_b^1); + } else { + if (trans_b) + return make_cublas_batched_matmul(dout, a, 1, trans_a); + else + // db = a^T*dc + return make_cublas_batched_matmul(a, dout, trans_a^1, 0); + } +} + +void CublasBatchedMatmulOp::infer_shape(){ + auto adim = a->shape.size(); + auto bdim = b->shape.size(); + ASSERTop(adim,>=,3); + ASSERTop(bdim,>=,3); + ASSERTop(adim,==,bdim); + + auto n = a->shape[adim-2], m = a->shape[adim-1]; + auto m_ = b->shape[adim-2], k = b->shape[adim-1]; + + NanoVector c_shape; + + for (int i=0; ishape[i],==,b->shape[i]); + c_shape.push_back(a->shape[i]); + } + if (trans_a) { + swap(n, m); + } + if (trans_b) { + swap(m_, k); + } + ASSERTop(m,==,m_); + c_shape.push_back(n); + c_shape.push_back(k); + + c->set_shape(c_shape); +} + +void CublasBatchedMatmulOp::jit_prepare(JK& jk) { + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); + jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); +} + +#else // JIT +#ifdef JIT_cuda +#pragma clang diagnostic ignored "-Wtautological-compare" +void CublasBatchedMatmulOp::jit_run() { + cublasHandle_t& handle_ = cublas_handle; + const T alpha = 1.0f; + const T beta = 0.0f; + const float alpha_f = 1.0f; + const float beta_f = 0.0f; + void* alpha_p = (void*)&alpha_f; + void* beta_p = (void*)&beta_f; + + const auto& as = a->shape; + const auto& bs = b->shape; + auto adim = as.size(); + auto batch_size = as[0]; + for (int i=1; idtype() == ns_float16 + || b->dtype() == ns_float16 || c->dtype() == ns_float16 + || a->dtype() == ns_bfloat16 + || b->dtype() == ns_bfloat16 || c->dtype() == ns_bfloat16; + // a: [b,n,m], b: [b,m,k], c: [b,n,k] + #if CUDART_VERSION >= 11000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; + if (use_tensorcore>=3) { + computeType = CUBLAS_COMPUTE_32F_FAST_16F; + } else if (use_tensorcore==2) { + computeType = CUBLAS_COMPUTE_32F_FAST_16BF; + } else if (use_tensorcore==1) { + computeType = CUBLAS_COMPUTE_32F_FAST_TF32; + } + if (has_fp16_or_bf16) { + computeType = use_tensorcore ? CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F; + algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + if (computeType == CUBLAS_COMPUTE_16F) { + alpha_p = (void*)α + beta_p = (void*)β + } + #else + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + cudaDataType_t computeType = CUDA_R_32F; + if (use_tensorcore) { + algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + if (has_fp16_or_bf16) { + computeType = CUDA_R_16F; + algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + #endif + checkCudaErrors(cublasGemmStridedBatchedEx(handle_, + CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, + k, n, m, alpha_p, + b->ptr(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m, k * m, + a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, n * m, beta_p, + c->ptr(),get_dtype(c->dtype()), k, k * n, + batch_size,computeType,algo)); + // checkCudaErrors(cublas@op@@gemmStridedBatched(handle_, + // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, + // k, n, m, &alpha, + // b->ptr(), '@Trans_b' == 'N' ? k : m, k * m, + // a->ptr(), '@Trans_a' == 'N' ? m : n, n * m, &beta, + // c->ptr(), k, k * n, + // batch_size)); +} +#endif +#endif // JIT + +} // jittor + + diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.h b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.h new file mode 100644 index 00000000..1be0fda2 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_batched_matmul_op.h @@ -0,0 +1,31 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Meng-Hao Guo +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + + +// cublas_batched_matmul_op.h +#pragma once +#include "op.h" +#include "ops/op_register.h" +#include "var.h" + +namespace jittor { + +struct CublasBatchedMatmulOp : Op { + Var* a, * b, * c; + bool trans_a, trans_b; + CublasBatchedMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); + + const char* name() const override { return "cublas_batched_matmul"; } + void infer_shape() override; + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc new file mode 100644 index 00000000..4a38742c --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc @@ -0,0 +1,135 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "var.h" +#include "cublas_matmul_op.h" +#include "cublas_wrapper.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +#ifndef JIT + +CublasMatmulOp::CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b) + : a(a), b(b), trans_a(trans_a), trans_b(trans_b) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + a->flags.set(NodeFlags::_needed_by_backward); + b->flags.set(NodeFlags::_needed_by_backward); + // TODO: support int8 * int8 + ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; + // TODO: support diffrent input type + ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same"; + c = create_output(nullptr, a->dtype()); +} + +void CublasMatmulOp::infer_shape() { + ASSERTop(a->shape.size(),==,2); + ASSERTop(b->shape.size(),==,2); + int n = a->shape[0], m = a->shape[1]; + int m_ = b->shape[0], k = b->shape[1]; + if (trans_a) { + swap(n, m); + } + if (trans_b) { + swap(m_, k); + } + ASSERTop(m,==,m_); + c->set_shape({n, k}); +} + +void CublasMatmulOp::jit_prepare(JK& jk) { + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); + jk << "«op:" << (a->dtype().dsize() == 2? 'H' : (a->dtype().dsize() == 4 ? 'S' : 'D')); +} + +#else // JIT +#pragma clang diagnostic ignored "-Wtautological-compare" + +void CublasMatmulOp::jit_run() { + cublasHandle_t& handle_ = cublas_handle; + const T alpha = 1.0f; + const T beta = 0.0f; + const float alpha_f = 1.0f; + const float beta_f = 0.0f; + void* alpha_p = (void*)&alpha_f; + void* beta_p = (void*)&beta_f; + + const auto& as = a->shape; + const auto& bs = b->shape; + auto n = as[0]; + auto m = as[1]; + auto k = bs[1]; + if ('@Trans_a'=='T') { + n = as[1]; + m = as[0]; + } + if ('@Trans_b'=='T') { + k = bs[0]; + } + bool has_fp16_or_bf16 = a->dtype() == ns_float16 + || b->dtype() == ns_float16 || c->dtype() == ns_float16 + || a->dtype() == ns_bfloat16 + || b->dtype() == ns_bfloat16 || c->dtype() == ns_bfloat16; + // a: [n,m], b: [m,k], c: [n,k] + #if CUDART_VERSION >= 11000 + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; + if (use_tensorcore>=3) { + computeType = CUBLAS_COMPUTE_32F_FAST_16F; + } else if (use_tensorcore==2) { + computeType = CUBLAS_COMPUTE_32F_FAST_16BF; + } else if (use_tensorcore==1) { + computeType = CUBLAS_COMPUTE_32F_FAST_TF32; + } + if (has_fp16_or_bf16) { + computeType = use_tensorcore ? CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F; + algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + if (computeType == CUBLAS_COMPUTE_16F) { + alpha_p = (void*)α + beta_p = (void*)β + } + #else + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT; + cudaDataType_t computeType = get_dtype(c->dtype()); + if (use_tensorcore) { + algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + if (has_fp16_or_bf16) { + computeType = CUDA_R_16F; + algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } + #endif + checkCudaErrors(cublasGemmEx(handle_, + CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, + k, n, m, alpha_p, + b->ptr(),get_dtype(b->dtype()), '@Trans_b' == 'N' ? k : m, + a->ptr(),get_dtype(a->dtype()), '@Trans_a' == 'N' ? m : n, beta_p, + c->ptr(),get_dtype(c->dtype()), k, + computeType, algo)); + // checkCudaErrors(cublas@op@@gemm(handle_, + // CUBLAS_OP_@Trans_b, CUBLAS_OP_@Trans_a, + // k, n, m, &alpha, + // b->ptr(), '@Trans_b' == 'N' ? k : m, + // a->ptr(), '@Trans_a' == 'N' ? m : n, &beta, + // c->ptr(), k)); + + +} +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h new file mode 100644 index 00000000..7f1d6f50 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CublasMatmulOp : Op { + Var* a, * b, * c; + bool trans_a, trans_b; + CublasMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); + + const char* name() const override { return "cublas_matmul"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc new file mode 100644 index 00000000..970e2920 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_test_op.cc @@ -0,0 +1,34 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "cublas_test_op.h" + +int cublas_test_entry(int); + +namespace jittor { + +#ifndef JIT +CublasTestOp::CublasTestOp(int size_mult) : size_mult(size_mult) { + output = create_output(1, ns_float32); +} + +void CublasTestOp::jit_prepare(JK& jk) { + jk << "«T:float32"; +} + +#else // JIT +#ifdef JIT_cpu +void CublasTestOp::jit_run() { + ASSERT(cublas_test_entry(size_mult)==0); + output->ptr()[0] = 123; +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_test_op.h b/python/jittor/extern/cuda/cublas/ops/cublas_test_op.h new file mode 100644 index 00000000..68767efb --- /dev/null +++ b/python/jittor/extern/cuda/cublas/ops/cublas_test_op.h @@ -0,0 +1,22 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CublasTestOp : Op { + Var* output; + int size_mult; + + CublasTestOp(int size_mult); + + const char* name() const override { return "cublas_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cublas/src/cublas_matmul_test.cc b/python/jittor/extern/cuda/cublas/src/cublas_matmul_test.cc new file mode 100644 index 00000000..8ee441e8 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/src/cublas_matmul_test.cc @@ -0,0 +1,353 @@ +//////////////////////////////////////////////////////////////////////////// +// +// Copyright 1993-2015 NVIDIA Corporation. All rights reserved. +// +// Please refer to the NVIDIA end user license agreement (EULA) associated +// with this source code for terms and conditions that govern your use of +// this software. Any use, reproduction, disclosure, or distribution of +// this software and related documentation outside the terms of the EULA +// is strictly prohibited. +// +//////////////////////////////////////////////////////////////////////////// + +// +// Matrix multiplication: C = A * B. +// Host code. +// +// This sample implements matrix multiplication as described in Chapter 3 +// of the programming guide and uses the CUBLAS library to demonstrate +// the best performance. + +// SOME PRECAUTIONS: +// IF WE WANT TO CALCULATE ROW-MAJOR MATRIX MULTIPLY C = A * B, +// WE JUST NEED CALL CUBLAS API IN A REVERSE ORDER: cublasSegemm(B, A)! +// The reason is explained as follows: + +// CUBLAS library uses column-major storage, but C/C++ use row-major storage. +// When passing the matrix pointer to CUBLAS, the memory layout alters from +// row-major to column-major, which is equivalent to an implicit transpose. + +// In the case of row-major C/C++ matrix A, B, and a simple matrix multiplication +// C = A * B, we can't use the input order like cublasSgemm(A, B) because of +// implicit transpose. The actual result of cublasSegemm(A, B) is A(T) * B(T). +// If col(A(T)) != row(B(T)), equal to row(A) != col(B), A(T) and B(T) are not +// multipliable. Moreover, even if A(T) and B(T) are multipliable, the result C +// is a column-based cublas matrix, which means C(T) in C/C++, we need extra +// transpose code to convert it to a row-based C/C++ matrix. + +// To solve the problem, let's consider our desired result C, a row-major matrix. +// In cublas format, it is C(T) actually (because of the implicit transpose). +// C = A * B, so C(T) = (A * B) (T) = B(T) * A(T). Cublas matrice B(T) and A(T) +// happen to be C/C++ matrice B and A (still because of the implicit transpose)! +// We don't need extra transpose code, we only need alter the input order! +// +// CUBLAS provides high-performance matrix multiplication. +// See also: +// V. Volkov and J. Demmel, "Benchmarking GPUs to tune dense linear algebra," +// in Proc. 2008 ACM/IEEE Conf. on Supercomputing (SC '08), +// Piscataway, NJ: IEEE Press, 2008, pp. Art. 31:1-11. +// + +// Utilities and system includes +#include +#include // helper for shared functions common to CUDA Samples + +// CUDA runtime +#include +#include + +// CUDA and CUBLAS functions +#include +#include "utils/log.h" +#include "helper_cuda.h" + +#ifndef min +#define min(a,b) ((a < b) ? a : b) +#endif +#ifndef max +#define max(a,b) ((a > b) ? a : b) +#endif + +typedef struct _matrixSize // Optional Command-line multiplier for matrix sizes +{ + unsigned int uiWA, uiHA, uiWB, uiHB, uiWC, uiHC; +} sMatrixSize; + +//////////////////////////////////////////////////////////////////////////////// +//! Compute reference data set matrix multiply on CPU +//! C = A * B +//! @param C reference data, computed but preallocated +//! @param A matrix A as provided to device +//! @param B matrix B as provided to device +//! @param hA height of matrix A +//! @param wB width of matrix B +//////////////////////////////////////////////////////////////////////////////// +void +matrixMulCPU(float *C, const float *A, const float *B, unsigned int hA, unsigned int wA, unsigned int wB) +{ + for (unsigned int i = 0; i < hA; ++i) + for (unsigned int j = 0; j < wB; ++j) + { + double sum = 0; + + for (unsigned int k = 0; k < wA; ++k) + { + double a = A[i * wA + k]; + double b = B[k * wB + j]; + sum += a * b; + } + + C[i * wB + j] = (float)sum; + } +} + +// Allocates a matrix with random float entries. +void randomInit(float *data, int size) +{ + for (int i = 0; i < size; ++i) + data[i] = rand() / (float)RAND_MAX; +} + +void printDiff(float *data1, float *data2, int width, int height, int iListLength, float fListTol) +{ + printf("Listing first %d Differences > %.6f...\n", iListLength, fListTol); + int i,j,k; + int error_count=0; + + for (j = 0; j < height; j++) + { + if (error_count < iListLength) + { + printf("\n Row %d:\n", j); + } + + for (i = 0; i < width; i++) + { + k = j * width + i; + float fDiff = fabs(data1[k] - data2[k]); + + if (fDiff > fListTol) + { + if (error_count < iListLength) + { + printf(" Loc(%d,%d)\tCPU=%.5f\tGPU=%.5f\tDiff=%.6f\n", i, j, data1[k], data2[k], fDiff); + } + + error_count++; + } + } + } + + printf(" \n Total Errors = %d\n", error_count); +} + +void initializeCUDA(int &devID, int &iSizeMultiple, sMatrixSize &matrix_size) +{ + // By default, we use device 0, otherwise we override the device ID based on what is provided at the command line + cudaError_t error; + devID = 0; + + iSizeMultiple = min(iSizeMultiple, 100); + iSizeMultiple = max(iSizeMultiple, 1); + + cudaDeviceProp deviceProp; + + error = cudaGetDeviceProperties(&deviceProp, devID); + + if (error != cudaSuccess) + { + printf("cudaGetDeviceProperties returned error code %d, line(%d)\n", error, __LINE__); + exit(EXIT_FAILURE); + } + + printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", devID, deviceProp.name, deviceProp.major, deviceProp.minor); + + int block_size = 32; + + matrix_size.uiWA = 3 * block_size * iSizeMultiple; + matrix_size.uiHA = 4 * block_size * iSizeMultiple; + matrix_size.uiWB = 2 * block_size * iSizeMultiple; + matrix_size.uiHB = 3 * block_size * iSizeMultiple; + matrix_size.uiWC = 2 * block_size * iSizeMultiple; + matrix_size.uiHC = 4 * block_size * iSizeMultiple; + + printf("MatrixA(%u,%u), MatrixB(%u,%u), MatrixC(%u,%u)\n", + matrix_size.uiHA, matrix_size.uiWA, + matrix_size.uiHB, matrix_size.uiWB, + matrix_size.uiHC, matrix_size.uiWC); + + if( matrix_size.uiWA != matrix_size.uiHB || + matrix_size.uiHA != matrix_size.uiHC || + matrix_size.uiWB != matrix_size.uiWC) + { + printf("ERROR: Matrix sizes do not match!\n"); + exit(-1); + } +} + +//////////////////////////////////////////////////////////////////////////////// +//! Run a simple test matrix multiply using CUBLAS +//////////////////////////////////////////////////////////////////////////////// +int matrixMultiply(int devID, sMatrixSize &matrix_size) +{ + cudaDeviceProp deviceProp; + + checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID)); + + int block_size = 32; + + // set seed for rand() + srand(2006); + + // allocate host memory for matrices A and B + unsigned int size_A = matrix_size.uiWA * matrix_size.uiHA; + unsigned int mem_size_A = sizeof(float) * size_A; + float *h_A = (float *)malloc(mem_size_A); + unsigned int size_B = matrix_size.uiWB * matrix_size.uiHB; + unsigned int mem_size_B = sizeof(float) * size_B; + float *h_B = (float *)malloc(mem_size_B); + + // set seed for rand() + srand(2006); + + // initialize host memory + randomInit(h_A, size_A); + randomInit(h_B, size_B); + + // allocate device memory + float *d_A, *d_B, *d_C; + unsigned int size_C = matrix_size.uiWC * matrix_size.uiHC; + unsigned int mem_size_C = sizeof(float) * size_C; + + // allocate host memory for the result + float *h_C = (float *) malloc(mem_size_C); + float *h_CUBLAS = (float *) malloc(mem_size_C); + + checkCudaErrors(cudaMalloc((void **) &d_A, mem_size_A)); + checkCudaErrors(cudaMalloc((void **) &d_B, mem_size_B)); + checkCudaErrors(cudaMemcpy(d_A, h_A, mem_size_A, cudaMemcpyHostToDevice)); + checkCudaErrors(cudaMemcpy(d_B, h_B, mem_size_B, cudaMemcpyHostToDevice)); + checkCudaErrors(cudaMalloc((void **) &d_C, mem_size_C)); + + // setup execution parameters + dim3 threads(block_size, block_size); + dim3 grid(matrix_size.uiWC / threads.x, matrix_size.uiHC / threads.y); + + // create and start timer + printf("Computing result using CUBLAS..."); + + // execute the kernel + int nIter = 30; + + // CUBLAS version 2.0 + { + const float alpha = 1.0f; + const float beta = 0.0f; + cublasHandle_t handle; + cudaEvent_t start, stop; + + checkCudaErrors(cublasCreate(&handle)); + + //Perform warmup operation with cublas + checkCudaErrors(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, matrix_size.uiWB, matrix_size.uiHA, matrix_size.uiWA, &alpha, d_B, matrix_size.uiWB, d_A, matrix_size.uiWA, &beta, d_C, matrix_size.uiWB)); + + // Allocate CUDA events that we'll use for timing + checkCudaErrors(cudaEventCreate(&start)); + checkCudaErrors(cudaEventCreate(&stop)); + + // Record the start event + checkCudaErrors(cudaEventRecord(start, NULL)); + + for (int j = 0; j < nIter; j++) + { + //note cublas is column primary! + //need to transpose the order + checkCudaErrors(cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, matrix_size.uiWB, matrix_size.uiHA, matrix_size.uiWA, &alpha, d_B, matrix_size.uiWB, d_A, matrix_size.uiWA, &beta, d_C, matrix_size.uiWB)); + + } + + printf("done.\n"); + + // Record the stop event + checkCudaErrors(cudaEventRecord(stop, NULL)); + + // Wait for the stop event to complete + checkCudaErrors(cudaEventSynchronize(stop)); + + float msecTotal = 0.0f; + checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop)); + + // Compute and print the performance + float msecPerMatrixMul = msecTotal / nIter; + double flopsPerMatrixMul = 2.0 * (double)matrix_size.uiHC * (double)matrix_size.uiWC * (double)matrix_size.uiHB; + double gigaFlops = (flopsPerMatrixMul * 1.0e-9f) / (msecPerMatrixMul / 1000.0f); + printf( + "Performance= %.2f GFlop/s, Time= %.3f msec, Size= %.0f Ops\n", + gigaFlops, + msecPerMatrixMul, + flopsPerMatrixMul); + + // copy result from device to host + checkCudaErrors(cudaMemcpy(h_CUBLAS, d_C, mem_size_C, cudaMemcpyDeviceToHost)); + + // Destroy the handle + checkCudaErrors(cublasDestroy(handle)); + } + + // compute reference solution + printf("Computing result using host CPU..."); + float *reference = (float *)malloc(mem_size_C); + bool resCUBLAS = true; + // only compare with cpu when size smaller than 1000 + if (matrix_size.uiHA < 1000) { + matrixMulCPU(reference, h_A, h_B, matrix_size.uiHA, matrix_size.uiWA, matrix_size.uiWB); + printf("done.\n"); + + // check result (CUBLAS) + resCUBLAS = sdkCompareL2fe(reference, h_CUBLAS, size_C, 1.0e-6f); + + if (resCUBLAS != true) + { + printDiff(reference, h_CUBLAS, matrix_size.uiWC, matrix_size.uiHC, 100, 1.0e-5f); + } + + printf("Comparing CUBLAS Matrix Multiply with CPU results: %s\n", (true == resCUBLAS) ? "PASS" : "FAIL"); + + printf("\nNOTE: The CUDA Samples are not meant for performance measurements. Results may vary when GPU Boost is enabled.\n"); + + } + // clean up memory + free(h_A); + free(h_B); + free(h_C); + free(reference); + checkCudaErrors(cudaFree(d_A)); + checkCudaErrors(cudaFree(d_B)); + checkCudaErrors(cudaFree(d_C)); + + if (resCUBLAS == true) + { + return EXIT_SUCCESS; // return value = 1 + } + else + { + return EXIT_FAILURE; // return value = 0 + } +} + +//////////////////////////////////////////////////////////////////////////////// +// Program main +//////////////////////////////////////////////////////////////////////////////// +int cublas_test_entry(int sizeMult) +{ + printf("[Matrix Multiply CUBLAS] - Starting...\n"); + + int devID = 0; + sMatrixSize matrix_size; + + initializeCUDA(devID, sizeMult, matrix_size); + + int matrix_result = matrixMultiply(devID, matrix_size); + + return matrix_result; +} \ No newline at end of file diff --git a/python/jittor/extern/cuda/cublas/src/cublas_wrapper.cc b/python/jittor/extern/cuda/cublas/src/cublas_wrapper.cc new file mode 100644 index 00000000..3ed347ef --- /dev/null +++ b/python/jittor/extern/cuda/cublas/src/cublas_wrapper.cc @@ -0,0 +1,34 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "cublas_wrapper.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +cublasHandle_t cublas_handle; + +struct cublas_initer { + +inline cublas_initer() { + if (!get_device_count()) return; + checkCudaErrors(cublasCreate(&cublas_handle)); + LOGv << "cublasCreate finished" << (void*)cublas_handle; +} + +inline ~cublas_initer() { + if (!get_device_count()) return; + LOGv << "cublasDestroy:" << (void*)cublas_handle; + checkCudaErrors(cublasDestroy(cublas_handle)); + LOGv << "cublasDestroy finished"; +} + +} init; + +} // jittor diff --git a/python/jittor/extern/cuda/cublas/src/helper_cublas.cc b/python/jittor/extern/cuda/cublas/src/helper_cublas.cc new file mode 100644 index 00000000..9b6fcfe3 --- /dev/null +++ b/python/jittor/extern/cuda/cublas/src/helper_cublas.cc @@ -0,0 +1,56 @@ +/** + * Copyright 1993-2017 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +//////////////////////////////////////////////////////////////////////////////// +// These are CUDA Helper functions for initialization and error checking + +#include +#include +#include "helper_cuda.h" + +#ifdef CUBLAS_API_H_ +// cuBLAS API errors +const char *_cudaGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + } + + return ""; +} +#endif diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h new file mode 100644 index 00000000..ca6de5c5 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_rnn_descriptor.h @@ -0,0 +1,154 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "init.h" + + +namespace jittor { + +static inline cudnnRNNMode_t rnn_string_to_rnn_mode(string mode) { + if (mode == "relu") + return CUDNN_RNN_RELU; + if (mode == "tanh") + return CUDNN_RNN_TANH; + if (mode == "lstm") + return CUDNN_LSTM; + ASSERT(mode == "gru") << "rnn mode must be relu, tanh, lstm, or gru, but got " << mode; + return CUDNN_GRU; +} + +static inline int rnn_string_to_num_linear_layers(string mode) { + if (mode == "relu") + return 2; + if (mode == "tanh") + return 2; + if (mode == "lstm") + return 8; + ASSERT(mode == "gru") << "mode must be relu, tanh, lstm, or gru, but got " << mode; + return 6; +} + +/** A wrapper for CUDNN dropout descriptor + */ +struct DropoutDescriptor { + cudnnDropoutDescriptor_t desc; + size_t stateSize, stateAllocation; + float dropout; + void *stateSpace; + + DropoutDescriptor(cudnnHandle_t handle, float dropout) + : dropout(dropout), stateSpace(nullptr) { + checkCudaErrors(cudnnCreateDropoutDescriptor(&desc)); + if (dropout > 0) { + checkCudaErrors(cudnnDropoutGetStatesSize(handle, &stateSize)); + stateSpace = exe.temp_allocator->alloc(stateSize, stateAllocation); + checkCudaErrors(cudnnSetDropoutDescriptor( + desc, + cudnn_handle, + dropout, + stateSpace, + stateSize, + get_seed() + )); + } else { + checkCudaErrors(cudnnSetDropoutDescriptor( + desc, handle, 0, nullptr, 0, 0 + )); + } + } + ~DropoutDescriptor() { + checkCudaErrors(cudnnDestroyDropoutDescriptor(desc)); + if (stateSpace) + exe.temp_allocator->free(stateSpace, stateSize, stateAllocation); + } +}; + +/** A wrapper for CUDNN RNN descriptor + */ +struct RnnDescriptor { + cudnnHandle_t handle; + cudnnRNNDescriptor_t desc; + DropoutDescriptor dropoutDesc; + + RnnDescriptor(cudnnHandle_t handle, string mode, int hidden_size, int num_layers, + float dropout, bool bidirectional) : handle(handle), dropoutDesc(handle, dropout) { + checkCudaErrors(cudnnCreateRNNDescriptor(&desc)); + checkCudaErrors(cudnnSetRNNDescriptor_v6( + handle, + desc, + hidden_size, + num_layers, + dropoutDesc.desc, + CUDNN_LINEAR_INPUT, + bidirectional ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, + rnn_string_to_rnn_mode(mode), + CUDNN_RNN_ALGO_STANDARD, + CUDNN_DATA_FLOAT + )); + } + + ~RnnDescriptor() { + checkCudaErrors(cudnnDestroyRNNDescriptor(desc)); + } + + size_t weight_space_size(const cudnnTensorDescriptor_t &xDesc) { + size_t size; + checkCudaErrors(cudnnGetRNNParamsSize( + handle, desc, xDesc, &size, CUDNN_DATA_FLOAT + )); + return size; + } + + size_t work_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) { + size_t size; + checkCudaErrors(cudnnGetRNNWorkspaceSize( + handle, desc, seq_length, xDesc, &size + )); + return size; + } + + size_t reserve_space_size(const cudnnTensorDescriptor_t *xDesc, int seq_length) { + size_t size; + checkCudaErrors(cudnnGetRNNTrainingReserveSize( + handle, desc, seq_length, xDesc, &size + )); + return size; + } +}; + +/** + */ +struct RnnWeightDescriptor { + cudnnFilterDescriptor_t desc; + size_t size; + RnnWeightDescriptor(size_t size) : size(size) { + int dimW[3] = {(int) (size / sizeof(float)), 1, 1}; + checkCudaErrors(cudnnCreateFilterDescriptor(&desc)); + checkCudaErrors(cudnnSetFilterNdDescriptor(desc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, 3, dimW)); + } + ~RnnWeightDescriptor() { + cudnnDestroyFilterDescriptor(desc); + } +}; + +/** + Returns offsets of RNN linear parameters in a flatten array. + + Returns + ======= + list: [total size, param #1 offset, param #2 offset, ...] + + TODO: support cudnn rnn-v8; support proj_size + */ +// @pyjt(cudnn_rnn_weight_offset) +vector cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional); + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h new file mode 100644 index 00000000..73fb3233 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/inc/cudnn_wrapper.h @@ -0,0 +1,39 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include +#ifndef IS_ROCM +#include +#endif +#include "utils/log.h" +#include "helper_cuda.h" +#include "fp16_emu.h" +#include "common.h" + +namespace jittor { + +EXTERN_LIB cudnnHandle_t cudnn_handle; +EXTERN_LIB int max_cache_size; +EXTERN_LIB float max_workspace_ratio; + +// @pyjt(set_algorithm_cache_size) +void set_algorithm_cache_size(int size); + +// @pyjt(set_max_workspace_ratio) +void set_max_workspace_ratio(float64 ratio); + + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } +#ifndef IS_ROCM +template <> __inline__ cudnnDataType_t getDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; } +#endif + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc new file mode 100644 index 00000000..a49338e0 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.cc @@ -0,0 +1,297 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator.h" +#include "var.h" +#include "cudnn_conv3d_backward_w_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" +#include "mem/mem_info.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnConv3dBackwardWOp::CudnnConv3dBackwardWOp(Var* x, Var* dy, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat) + : x(x), dy(dy), kd(kd), kh(kh), kw(kw), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); + dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); +} + +void CudnnConv3dBackwardWOp::infer_shape() { + ASSERTop(x->shape.size(),==,5); + ASSERTop(dy->shape.size(),==,5); + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + + if (xformat == "ncdhw") { + x->shape.unpack(xn, xc, xd, xh, xw); + dy->shape.unpack(yn, yc, yd, yh, yw); + } else { + x->shape.unpack(xn, xd, xh, xw, xc); + dy->shape.unpack(yn, yd, yh, yw, yc); + } + wco = yc, wci = xc / groups; + wh = kh; + ww = kw; + wd = kd; + dw->set_shape(NanoVector(wco, wci, wd, wh, ww)); +} + +void CudnnConv3dBackwardWOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << dw->dtype(); +} + +static auto make_conv3d = get_op_info("cudnn_conv3d") + .get_constructor(); +static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x") + .get_constructor(); + + +VarPtr CudnnConv3dBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + + if (xformat == "ncdhw") { + x->shape.unpack(xn, xc, xd, xh, xw); + dy->shape.unpack(yn, yc, yd, yh, yw); + } else { + x->shape.unpack(xn, xd, xh, xw, xc); + dy->shape.unpack(yn, yd, yh, yw, yc); + } + + if (v_index == 0) { + return make_backwardx(dout, dy, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } else { + return make_conv3d(x, dout, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } +} + +// unordered_map bwdw_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +EXTERN_LIB unordered_map bwdw_algo_cache; + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnConv3dBackwardWOp::jit_run() { + auto w = dw; + auto y = dy; + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + int sx[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1]; + int strideX[5]; + if (xformat == "ncdhw") { + x->shape.unpack(xn, xc, xd, xh, xw); + int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]}; + memcpy(strideX, tmp, sizeof(tmp)); + } else { + x->shape.unpack(xn, xd, xh, xw, xc); + int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]}; + memcpy(strideX, tmp, sizeof(tmp)); + } + int dimX[] = {xn, xc, xd, xh, xw}; + // dimX: ncdhw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 5, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + // filterFormat_@WFORMAT, 5, dimW + filterFormat_oihw, 5, dimW + )); + + int padA[] = {paddingd, paddingh, paddingw}; + int convstrideA[] = {strided, strideh, stridew}; + int dilationA[] = {dilationd, dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, 3, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + // MIOpen requires groups to be set after descriptor initialization + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + // using tensor core + if(use_tensorcore){ + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + + + int sy[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1]; + int strideY[5]; + if (xformat == "ncdhw") { + y->shape.unpack(yn, yc, yd, yh, yw); + int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]}; + memcpy(strideY, tmp, sizeof(tmp)); + } else { + y->shape.unpack(yn, yd, yh, yw, yc); + int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]}; + memcpy(strideY, tmp, sizeof(tmp)); + } + int dimY[] = {yn, yc, yd, yh, yw}; + + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 5, dimY, strideY + )); + + cudnnConvolutionBwdFilterAlgo_t algos[] = { + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, + }; + int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; + int perf_count; + STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos); + cudnnConvolutionBwdFilterAlgo_t algo; + bool benchmark=true; + + JK& jk = get_jk(); + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; + jk << paddingd << paddingh << paddingw << "," << strided << strideh <second; + else { + if (bwdw_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz); + // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx( + handle_, + cudnnIdesc, x->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + cudnnFdesc, w->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7( + handle_, + cudnnIdesc, + cudnnOdesc, + cudnnConvDesc, + cudnnFdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + bwdw_algo_cache[jk.to_string()] = algo; + if (bwdw_algo_cache.size()==max_cache_size) + LOGw << "backward w algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, + cudnnFdesc, algo, &workSpaceSize)); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionBackwardFilter( + handle_, + (void*)(&alpha), + cudnnIdesc, x->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnFdesc, w->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.h new file mode 100644 index 00000000..e442e902 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_w_op.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConv3dBackwardWOp : Op { + Var* x, * dy, * dw; + int kd, kh, kw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups; + string xformat; + + CudnnConv3dBackwardWOp(Var* x, Var* y, int kd, int kh, int kw, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw"); + + const char* name() const override { return "cudnn_conv3d_backward_w"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc new file mode 100644 index 00000000..76b958a7 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.cc @@ -0,0 +1,287 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator.h" +#include "var.h" +#include "cudnn_conv3d_backward_x_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" +#include "mem/mem_info.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnConv3dBackwardXOp::CudnnConv3dBackwardXOp(Var* w, Var* dy, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat) + : w(w), dy(dy), xd(depth), xh(height), xw(width), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + w->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); + dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); +} + +void CudnnConv3dBackwardXOp::infer_shape() { + ASSERTop(w->shape.size(),==,5); + ASSERTop(dy->shape.size(),==,5); + int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + w->shape.unpack(wco, wci, wd, wh, ww); + if (xformat == "ncdhw") + dy->shape.unpack(yn, yc, yd, yh, yw); + else + dy->shape.unpack(yn, yd, yh, yw, yc); + xn = yn, xc = wci * groups; + if (xformat == "ncdhw") + dx->set_shape(NanoVector(xn, xc, xd, xh, xw)); + else + dx->set_shape(NanoVector(xn, xd, xh, xw, xc)); +} + +void CudnnConv3dBackwardXOp::jit_prepare(JK& jk) { + jk << "«Tx:" << dx->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << w->dtype(); +} + + +static auto make_conv3d = get_op_info("cudnn_conv3d") + .get_constructor(); +static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w") + .get_constructor(); + + +VarPtr CudnnConv3dBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + w->shape.unpack(wco, wci, wd, wh, ww); + + if (v_index == 0) { + return make_backwardw(dout, dy, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } else { + return make_conv3d(dout, w, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } +} +// unordered_map bwdx_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +EXTERN_LIB unordered_map bwdx_algo_cache; + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnConv3dBackwardXOp::jit_run() { + auto x = dx; + auto y = dy; + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + int sx[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1]; + int strideX[5]; + if (xformat == "ncdhw") { + x->shape.unpack(xn, xc, xd, xh, xw); + int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]}; + memcpy(strideX, tmp, sizeof(tmp)); + } else { + x->shape.unpack(xn, xd, xh, xw, xc); + int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]}; + memcpy(strideX, tmp, sizeof(tmp)); + } + int dimX[] = {xn, xc, xd, xh, xw}; + // dimX: ncdhw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 5, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + // filterFormat_@WFORMAT, 5, dimW + filterFormat_oihw, 5, dimW + )); + + int padA[] = {paddingd, paddingh, paddingw}; + int convstrideA[] = {strided, strideh, stridew}; + int dilationA[] = {dilationd, dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, 3, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + // MIOpen requires groups to be set after descriptor initialization + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + // using tensor core + if(use_tensorcore){ + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + + + int sy[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1]; + int strideY[5]; + if (xformat == "ncdhw") { + y->shape.unpack(yn, yc, yd, yh, yw); + int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]}; + memcpy(strideY, tmp, sizeof(tmp)); + } else { + y->shape.unpack(yn, yd, yh, yw, yc); + int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]}; + memcpy(strideY, tmp, sizeof(tmp)); + } + int dimY[] = {yn, yc, yd, yh, yw}; + + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 5, dimY, strideY + )); + + cudnnConvolutionBwdDataAlgo_t algos[] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED + }; + int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + int perf_count; + STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos); + cudnnConvolutionBwdDataAlgo_t algo; + bool benchmark=true; + + JK& jk = get_jk(); + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; + jk << paddingd << paddingh << paddingw << "," << strided << strideh <second; + else { + if (bwdx_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz); + // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx( + handle_, + cudnnFdesc, w->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + cudnnIdesc, x->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7( + handle_, + cudnnFdesc, + cudnnOdesc, + cudnnConvDesc, + cudnnIdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + bwdx_algo_cache[jk.to_string()] = algo; + if (bwdx_algo_cache.size()==max_cache_size) + LOGw << "backward x algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionBackwardDataWorkspaceSize( + handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, + cudnnIdesc, algo, &workSpaceSize)); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionBackwardData( + handle_, + (void*)(&alpha), + cudnnFdesc, w->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnIdesc, x->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.h new file mode 100644 index 00000000..c60b4f66 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_backward_x_op.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConv3dBackwardXOp : Op { + Var* w, * dy, * dx; + int xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups; + string xformat; + + CudnnConv3dBackwardXOp(Var* w, Var* y, int depth, int height, int width, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups=1, string xformat="ncdhw"); + + const char* name() const override { return "cudnn_conv3d_backward_x"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc new file mode 100644 index 00000000..52e9bbae --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.cc @@ -0,0 +1,292 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cudnn_conv3d_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" +#include "mem/mem_info.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnConv3dOp::CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd, int dilationh, int dilationw, int groups, string xformat) + : x(x), w(w), strided(strided), strideh(strideh), stridew(stridew), paddingd(paddingd), paddingh(paddingh), paddingw(paddingw), dilationd(dilationd), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + w->flags.set(NodeFlags::_needed_by_backward); + y = create_output(nullptr, dtype_infer(x->ns, w->ns)); +} + +void CudnnConv3dOp::infer_shape() { + ASSERTop(x->shape.size(),==,5); + ASSERTop(w->shape.size(),==,5); + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + if (xformat == "ncdhw") + x->shape.unpack(xn, xc, xd, xh, xw); + else + x->shape.unpack(xn, xd, xh, xw, xc); + w->shape.unpack(wco, wci, wd, wh, ww); + ASSERTop(wci * groups,==,xc); + yn = xn, yc = wco; + yd = (xd+paddingd*2-wd*dilationd+dilationd-1)/strided+1; + yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1; + yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1; + if (xformat == "ncdhw") + y->set_shape(NanoVector(yn, yc, yd, yh, yw)); + else + y->set_shape(NanoVector(yn, yd, yh, yw, yc)); +} + +void CudnnConv3dOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); +} + +static auto make_backwardx = get_op_info("cudnn_conv3d_backward_x") + .get_constructor(); +static auto make_backwardw = get_op_info("cudnn_conv3d_backward_w") + .get_constructor(); + +VarPtr CudnnConv3dOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + if (xformat == "ncdhw") + x->shape.unpack(xn, xc, xd, xh, xw); + else + x->shape.unpack(xn, xd, xh, xw, xc); + w->shape.unpack(wco, wci, wd, wh, ww); + if (v_index == 0) { + return make_backwardx(w, dout, xd, xh, xw, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } else { + return make_backwardw(x, dout, wd, wh, ww, strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups, xformat); + } +} + +// unordered_map fwd_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +EXTERN_LIB unordered_map fwd_algo_cache; + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnConv3dOp::jit_run() { + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + + int xn, xc, xd, xh, xw, wd, wh, ww, wci, wco, yn, yc, yd, yh, yw; + int sx[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sx[i] = sx[i+1] * x->shape[i+1]; + int strideX[5]; + if (xformat == "ncdhw") { + x->shape.unpack(xn, xc, xd, xh, xw); + int tmp[5] = {sx[0],sx[1],sx[2],sx[3],sx[4]}; + memcpy(strideX, tmp, sizeof(tmp)); + } else { + x->shape.unpack(xn, xd, xh, xw, xc); + int tmp[5] = {sx[0],sx[2],sx[3],sx[4],sx[1]}; + memcpy(strideX, tmp, sizeof(tmp)); + } + int dimX[] = {xn, xc, xd, xh, xw}; + // dimX: ncdhw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 5, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3],(int)ws[4]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + // filterFormat_@WFORMAT, 5, dimW + filterFormat_oihw, 5, dimW + )); + + int padA[] = {paddingd, paddingh, paddingw}; + int convstrideA[] = {strided, strideh, stridew}; + int dilationA[] = {dilationd, dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, 3, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + // MIOpen requires groups to be set after descriptor initialization + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + // using tensor core + if(use_tensorcore){ + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + + + int sy[] = {0,0,0,0,1}; + for (int i=3; i>=0; i--) sy[i] = sy[i+1] * y->shape[i+1]; + int strideY[5]; + if (xformat == "ncdhw") { + y->shape.unpack(yn, yc, yd, yh, yw); + int tmp[5] = {sy[0],sy[1],sy[2],sy[3],sy[4]}; + memcpy(strideY, tmp, sizeof(tmp)); + } else { + y->shape.unpack(yn, yd, yh, yw, yc); + int tmp[5] = {sy[0],sy[2],sy[3],sy[4],sy[1]}; + memcpy(strideY, tmp, sizeof(tmp)); + } + int dimY[] = {yn, yc, yd, yh, yw}; + + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 5, dimY, strideY + )); + + cudnnConvolutionFwdAlgo_t algos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + int perf_count; + STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos); + cudnnConvolutionFwdAlgo_t algo; + bool benchmark=true; + + JK& jk = get_jk(); + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << "," << dimX[4] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << "," << dimW[4] << ","; + jk << paddingd << paddingh << paddingw << "," << strided << strideh <second; + else { + if (fwd_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize( + handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc, + cudnnOdesc, algos[i], &sz); + // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx( + handle_, + cudnnIdesc, x->ptr(), + cudnnFdesc, w->ptr(), + cudnnConvDesc, + cudnnOdesc, y->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7( + handle_, + cudnnIdesc, + cudnnFdesc, + cudnnConvDesc, + cudnnOdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + fwd_algo_cache[jk.to_string()] = algo; + if (fwd_algo_cache.size()==max_cache_size) + LOGw << "forward_ algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionForwardWorkspaceSize( + handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc, + cudnnOdesc, algo, &workSpaceSize) ); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionForward( + handle_, + (void*)(&alpha), + cudnnIdesc, x->ptr(), + cudnnFdesc, w->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnOdesc, y->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.h new file mode 100644 index 00000000..f6b40038 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv3d_op.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConv3dOp : Op { + Var* x, * w, * y; + int strided, strideh, stridew, paddingd, paddingh, paddingw, dilationd, dilationh, dilationw, groups; + string xformat; + CudnnConv3dOp(Var* x, Var* w, int strided, int strideh, int stridew, int paddingd, int paddingh, int paddingw, int dilationd=1, int dilationh=1, int dilationw=1, int groups=1, string xformat="ncdhw"); + + const char* name() const override { return "cudnn_conv3d"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc new file mode 100644 index 00000000..4545584b --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.cc @@ -0,0 +1,314 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator.h" +#include "var.h" +#include "cudnn_conv_backward_w_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" +#include "mem/mem_info.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +static inline int findc(const string& format, const char& c) { + if (c==format[0]) return 0; + if (c==format[1]) return 1; + if (c==format[2]) return 2; + ASSERT(c==format[3]) << "Not a valid format" << format << c; + return 3; +} + +#ifndef JIT +static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) { + auto& shape = x->shape; + a = shape[findc(format, f[0])]; + b = shape[findc(format, f[1])]; + c = shape[findc(format, f[2])]; + d = shape[findc(format, f[3])]; +} + +static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) { + int64 shape[4]; + shape[findc(format, f[0])] = a; + shape[findc(format, f[1])] = b; + shape[findc(format, f[2])] = c; + shape[findc(format, f[3])] = d; + x->set_shape(NanoVector( + shape[0], shape[1], shape[2], shape[3])); +} + +CudnnConvBackwardWOp::CudnnConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), dy(dy), kh(kh), kw(kw), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); + dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); +} + +void CudnnConvBackwardWOp::infer_shape() { + ASSERTop(x->shape.size(),==,4); + ASSERTop(dy->shape.size(),==,4); + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + get_shape(x, "abcd", xformat, xn, xc, xh, xw); + get_shape(dy, "abcd", yformat, yn, yc, yh, yw); + wco = yc, wci = xc / groups; + wh = kh; + ww = kw; + set_shape(dw, "oihw", wformat, wco, wci, wh, ww); +} + +void CudnnConvBackwardWOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << dw->dtype(); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; +} + +static auto make_conv = get_op_info("cudnn_conv") + .get_constructor(); +static auto make_backwardx = get_op_info("cudnn_conv_backward_x") + .get_constructor(); + +VarPtr CudnnConvBackwardWOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + + if (xformat == "nchw") { + x->shape.unpack(xn, xc, xh, xw); + dy->shape.unpack(yn, yc, yh, yw); + } else { + x->shape.unpack(xn, xh, xw, xc); + dy->shape.unpack(yn, yh, yw, yc); + } + + if (v_index == 0) { + return make_backwardx(dout, dy, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat); + } else { + return make_conv(x, dout, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat); + } +} + +unordered_map bwdw_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +EXTERN_LIB unordered_map bwdw_algo_cache; + +void CudnnConvBackwardWOp::jit_run() { + auto w = dw; + auto y = dy; + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + + int dimX[] = { + (int)x->shape[findc("@XFORMAT", 'a')], // n + (int)x->shape[findc("@XFORMAT", 'b')], // c + (int)x->shape[findc("@XFORMAT", 'c')], // h + (int)x->shape[findc("@XFORMAT", 'd')], // w + }; + int _strideX[] = {0,0,0,1}; + for (int i=2; i>=0; i--) _strideX[i] = _strideX[i+1] * x->shape[i+1]; + int strideX[] = { + _strideX[findc("@XFORMAT", 'a')], // n + _strideX[findc("@XFORMAT", 'b')], // c + _strideX[findc("@XFORMAT", 'c')], // h + _strideX[findc("@XFORMAT", 'd')], // w + }; + // dimX: nchw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 4, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + filterFormat_@WFORMAT, 4, dimW + )); + + int padA[] = {paddingh, paddingw}; + int convstrideA[] = {strideh, stridew}; + int dilationA[] = {dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, /*convDim=*/2, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + // MIOpen requires groups to be set after descriptor initialization + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + // using tensor core + if(use_tensorcore){ + // CUDNN_TENSOR_OP_MATH + // The use of Tensor Core operations is permitted but will not actively perform datatype down conversion on tensors in order to utilize Tensor Cores. + // CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION + // The use of Tensor Core operations is permitted and will actively perform datatype down conversion on tensors in order to utilize Tensor Cores. + + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + + int dimY[] = { + (int)y->shape[findc("@YFORMAT", 'a')], // n + (int)y->shape[findc("@YFORMAT", 'b')], // c + (int)y->shape[findc("@YFORMAT", 'c')], // h + (int)y->shape[findc("@YFORMAT", 'd')], // w + }; + int _strideY[] = {0,0,0,1}; + for (int i=2; i>=0; i--) _strideY[i] = _strideY[i+1] * y->shape[i+1]; + int strideY[] = { + _strideY[findc("@YFORMAT", 'a')], // n + _strideY[findc("@YFORMAT", 'b')], // c + _strideY[findc("@YFORMAT", 'c')], // h + _strideY[findc("@YFORMAT", 'd')], // w + }; + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 4, dimY, strideY + )); + + cudnnConvolutionBwdFilterAlgo_t algos[] = { + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, + }; + int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; + int perf_count; + STACK_ALLOC(cudnnConvolutionBwdFilterAlgoPerf_t,perf_results,num_algos); + cudnnConvolutionBwdFilterAlgo_t algo; + bool benchmark=true; + + JK& jk = get_jk(); + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; + jk << paddingh << paddingw << "," <second; + else { + if (bwdw_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz); + // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionBackwardFilterAlgorithmEx( + handle_, + cudnnIdesc, x->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + cudnnFdesc, w->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionBackwardFilterAlgorithm_v7( + handle_, + cudnnIdesc, + cudnnOdesc, + cudnnConvDesc, + cudnnFdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + bwdw_algo_cache[jk.to_string()] = algo; + if (bwdw_algo_cache.size()==max_cache_size) + LOGw << "backward w algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, + cudnnFdesc, algo, &workSpaceSize)); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionBackwardFilter( + handle_, + (void*)(&alpha), + cudnnIdesc, x->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnFdesc, w->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h new file mode 100644 index 00000000..bd173db5 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_w_op.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConvBackwardWOp : Op { + Var* x, * dy, * dw; + int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + string xformat, wformat, yformat; + + CudnnConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + + const char* name() const override { return "cudnn_conv_backward_w"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc new file mode 100644 index 00000000..59e9fcf8 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc @@ -0,0 +1,301 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator.h" +#include "var.h" +#include "cudnn_conv_backward_x_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" +#include "mem/mem_info.h" + +using namespace std; + +static inline int findc(const char* format, const char& c) { + if (c==format[0]) return 0; + if (c==format[1]) return 1; + if (c==format[2]) return 2; + ASSERT(c==format[3]) << "Not a valid format" << format << c; + return 3; +} + +namespace jittor { + +extern int use_tensorcore; + +#ifndef JIT + +static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) { + auto& shape = x->shape; + a = shape[findc(format.c_str(), f[0])]; + b = shape[findc(format.c_str(), f[1])]; + c = shape[findc(format.c_str(), f[2])]; + d = shape[findc(format.c_str(), f[3])]; +} + +static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) { + int64 shape[4]; + shape[findc(format.c_str(), f[0])] = a; + shape[findc(format.c_str(), f[1])] = b; + shape[findc(format.c_str(), f[2])] = c; + shape[findc(format.c_str(), f[3])] = d; + x->set_shape(NanoVector( + shape[0], shape[1], shape[2], shape[3])); +} + +CudnnConvBackwardXOp::CudnnConvBackwardXOp(Var* w, Var* dy, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : w(w), dy(dy), xh(height), xw(width), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + w->flags.set(NodeFlags::_needed_by_backward); + dy->flags.set(NodeFlags::_needed_by_backward); + dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); +} + +void CudnnConvBackwardXOp::infer_shape() { + ASSERTop(w->shape.size(),==,4); + ASSERTop(dy->shape.size(),==,4); + int xn, xc, wh, ww, wci, wco, yn, yc, yh, yw; + get_shape(w, "oihw", wformat, wco, wci, wh, ww); + get_shape(dy, "abcd", yformat, yn, yc, yh, yw); + xn = yn, xc = wci * groups; + set_shape(dx, "abcd", xformat, xn, xc, xh, xw); +} + +void CudnnConvBackwardXOp::jit_prepare(JK& jk) { + jk << "«Tx:" << dx->dtype(); + jk << "«Ty:" << dy->dtype(); + jk << "«Tw:" << w->dtype(); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; +} + +static auto make_conv = get_op_info("cudnn_conv") + .get_constructor(); +static auto make_backwardw = get_op_info("cudnn_conv_backward_w") + .get_constructor(); + +VarPtr CudnnConvBackwardXOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, wh, ww, wci, wco, yn, yc, yd, yh, yw; + w->shape.unpack(wco, wci, wh, ww); + + if (v_index == 0) { + return make_backwardw(dout, dy, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat); + } else { + return make_conv(dout, w, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat); + } +} +unordered_map bwdx_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +EXTERN_LIB unordered_map bwdx_algo_cache; + +void CudnnConvBackwardXOp::jit_run() { + auto x = dx; + auto y = dy; + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + + int dimX[] = { + (int)x->shape[findc("@XFORMAT", 'a')], // n + (int)x->shape[findc("@XFORMAT", 'b')], // c + (int)x->shape[findc("@XFORMAT", 'c')], // h + (int)x->shape[findc("@XFORMAT", 'd')], // w + }; + int _strideX[] = {0,0,0,1}; + for (int i=2; i>=0; i--) _strideX[i] = _strideX[i+1] * x->shape[i+1]; + int strideX[] = { + _strideX[findc("@XFORMAT", 'a')], // n + _strideX[findc("@XFORMAT", 'b')], // c + _strideX[findc("@XFORMAT", 'c')], // h + _strideX[findc("@XFORMAT", 'd')], // w + }; + // dimX: nchw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 4, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_iohw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + filterFormat_@WFORMAT, 4, dimW + )); + + int padA[] = {paddingh, paddingw}; + int convstrideA[] = {strideh, stridew}; + int dilationA[] = {dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, 2, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + // MIOpen requires groups to be set after descriptor initialization + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + // using tensor core + if(use_tensorcore){ + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + + int dimY[] = { + (int)y->shape[findc("@YFORMAT", 'a')], // n + (int)y->shape[findc("@YFORMAT", 'b')], // c + (int)y->shape[findc("@YFORMAT", 'c')], // h + (int)y->shape[findc("@YFORMAT", 'd')], // w + }; + int _strideY[] = {0,0,0,1}; + for (int i=2; i>=0; i--) _strideY[i] = _strideY[i+1] * y->shape[i+1]; + int strideY[] = { + _strideY[findc("@YFORMAT", 'a')], // n + _strideY[findc("@YFORMAT", 'b')], // c + _strideY[findc("@YFORMAT", 'c')], // h + _strideY[findc("@YFORMAT", 'd')], // w + }; + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 4, dimY, strideY + )); + + cudnnConvolutionBwdDataAlgo_t algos[] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED + }; + int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + int perf_count; + STACK_ALLOC(cudnnConvolutionBwdDataAlgoPerf_t,perf_results,num_algos); + cudnnConvolutionBwdDataAlgo_t algo; + bool benchmark=true; + + JK& jk = get_jk(); + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; + jk << paddingh << paddingw << "," <second; + else { + if (bwdx_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz); + // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionBackwardDataAlgorithmEx( + handle_, + cudnnFdesc, w->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + cudnnIdesc, x->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionBackwardDataAlgorithm_v7( + handle_, + cudnnFdesc, + cudnnOdesc, + cudnnConvDesc, + cudnnIdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + bwdx_algo_cache[jk.to_string()] = algo; + if (bwdx_algo_cache.size()==max_cache_size) + LOGw << "backward x algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionBackwardDataWorkspaceSize( + handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, + cudnnIdesc, algo, &workSpaceSize)); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionBackwardData( + handle_, + (void*)(&alpha), + cudnnFdesc, w->ptr(), + cudnnOdesc, y->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnIdesc, x->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h new file mode 100644 index 00000000..0cd36dd3 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConvBackwardXOp : Op { + Var* w, * dy, * dx; + int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + string xformat, wformat, yformat; + + CudnnConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + + const char* name() const override { return "cudnn_conv_backward_x"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc new file mode 100644 index 00000000..87cfefb9 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.cc @@ -0,0 +1,315 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cudnn_conv_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" +#include "mem/mem_info.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +static inline int findc(const char* format, const char& c) { + if (c==format[0]) return 0; + if (c==format[1]) return 1; + if (c==format[2]) return 2; + ASSERT(c==format[3]) << "Not a valid format" << format << c; + return 3; +} + +#ifndef JIT + +static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) { + auto& shape = x->shape; + a = shape[findc(format.c_str(), f[0])]; + b = shape[findc(format.c_str(), f[1])]; + c = shape[findc(format.c_str(), f[2])]; + d = shape[findc(format.c_str(), f[3])]; +} + +static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) { + int64 shape[4]; + shape[findc(format.c_str(), f[0])] = a; + shape[findc(format.c_str(), f[1])] = b; + shape[findc(format.c_str(), f[2])] = c; + shape[findc(format.c_str(), f[3])] = d; + x->set_shape(NanoVector( + shape[0], shape[1], shape[2], shape[3])); +} + +CudnnConvOp::CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), w(w), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_manual_set_vnbb); + x->flags.set(NodeFlags::_needed_by_backward); + w->flags.set(NodeFlags::_needed_by_backward); + y = create_output(nullptr, dtype_infer(x->ns, w->ns)); + if (!this->yformat.size()) + this->yformat = this->xformat; +} + +void CudnnConvOp::infer_shape() { + ASSERTop(x->shape.size(),==,4); + ASSERTop(w->shape.size(),==,4); + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + get_shape(x, "abcd", xformat, xn, xc, xh, xw); + get_shape(w, "oihw", wformat, wco, wci, wh, ww); + ASSERTop(wci * groups,==,xc); + yn = xn, yc = wco; + yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1; + yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1; + set_shape(y, "abcd", yformat, yn, yc, yh, yw); +} + +void CudnnConvOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; +} +static auto make_backwardx = get_op_info("cudnn_conv_backward_x") + .get_constructor(); +static auto make_backwardw = get_op_info("cudnn_conv_backward_w") + .get_constructor(); +VarPtr CudnnConvOp::grad(Var* out, Var* dout, Var* v, int v_index) { + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + if (xformat == "ncdhw") + x->shape.unpack(xn, xc, xh, xw); + else + x->shape.unpack(xn, xh, xw, xc); + w->shape.unpack(wco, wci, wh, ww); + if (v_index == 0) { + return make_backwardx(w, dout, xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat); + } else { + return make_backwardw(x, dout, wh, ww, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups, xformat, wformat, yformat); + } +} + +unordered_map fwd_algo_cache; + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +EXTERN_LIB unordered_map fwd_algo_cache; + +void CudnnConvOp::jit_run() { + cudnnHandle_t& handle_ = cudnn_handle; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors(cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors(cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors(cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + + + int dimX[] = { + (int)x->shape[findc("@XFORMAT", 'a')], // n + (int)x->shape[findc("@XFORMAT", 'b')], // c + (int)x->shape[findc("@XFORMAT", 'c')], // h + (int)x->shape[findc("@XFORMAT", 'd')], // w + }; + int _strideX[] = {0,0,0,1}; + for (int i=2; i>=0; i--) _strideX[i] = _strideX[i+1] * x->shape[i+1]; + int strideX[] = { + _strideX[findc("@XFORMAT", 'a')], // n + _strideX[findc("@XFORMAT", 'b')], // c + _strideX[findc("@XFORMAT", 'c')], // h + _strideX[findc("@XFORMAT", 'd')], // w + }; + // dimX: nchw + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnIdesc, getDataType(), + 4, dimX, strideX + )); + + auto ws = w->shape; + int dimW[] = {(int)ws[0],(int)ws[1],(int)ws[2],(int)ws[3]}; + // cudnn only support this two format + // https://docs.nvidia.com/deeplearning/sdk/cudnn-api/index.html#cudnnSetFilterNdDescriptor + #define filterFormat_oihw CUDNN_TENSOR_NCHW + #define filterFormat_ohwi CUDNN_TENSOR_NHWC + + // dimW: KCRS(oihw) + checkCudaErrors(cudnnSetFilterNdDescriptor( + cudnnFdesc, getDataType(), + filterFormat_@WFORMAT, 4, dimW + )); + + int padA[] = {paddingh, paddingw}; + int convstrideA[] = {strideh, stridew}; + int dilationA[] = {dilationh, dilationw}; + // difference between + // CUDNN_CONVOLUTION and CUDNN_CROSS_CORRELATION + // is the kernel rc order + // currently, No perf difference is observed between + // this two mode + checkCudaErrors(cudnnSetConvolutionNdDescriptor( + cudnnConvDesc, 2, + padA, convstrideA, dilationA, + CUDNN_CROSS_CORRELATION, getDataType() + )); + // MIOpen requires groups to be set after descriptor initialization + checkCudaErrors(cudnnSetConvolutionGroupCount( cudnnConvDesc, groups )); + + // using tensor core + if(use_tensorcore){ + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + bool has_fp16_or_bf16 = x->dtype() == ns_float16 + || y->dtype() == ns_float16 || w->dtype() == ns_float16 + || x->dtype() == ns_bfloat16 + || y->dtype() == ns_bfloat16 || w->dtype() == ns_bfloat16; + + if (has_fp16_or_bf16) { + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + + int dimY[] = { + (int)y->shape[findc("@YFORMAT", 'a')], // n + (int)y->shape[findc("@YFORMAT", 'b')], // c + (int)y->shape[findc("@YFORMAT", 'c')], // h + (int)y->shape[findc("@YFORMAT", 'd')], // w + }; + int _strideY[] = {0,0,0,1}; + for (int i=2; i>=0; i--) _strideY[i] = _strideY[i+1] * y->shape[i+1]; + int strideY[] = { + _strideY[findc("@YFORMAT", 'a')], // n + _strideY[findc("@YFORMAT", 'b')], // c + _strideY[findc("@YFORMAT", 'c')], // h + _strideY[findc("@YFORMAT", 'd')], // w + }; + checkCudaErrors(cudnnSetTensorNdDescriptor( + cudnnOdesc, getDataType(), + 4, dimY, strideY + )); + + cudnnConvolutionFwdAlgo_t algos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + int perf_count; + STACK_ALLOC(cudnnConvolutionFwdAlgoPerf_t,perf_results,num_algos); + cudnnConvolutionFwdAlgo_t algo; + bool benchmark=true; + + JK& jk = get_jk(); + jk.clear(); + jk << dimX[0] << "," << dimX[1] << "," << dimX[2] << "," << dimX[3] << ","; + jk << dimW[0] << "," << dimW[1] << "," << dimW[2] << "," << dimW[3] << ","; + jk << paddingh << paddingw << "," <second; + else { + if (fwd_algo_cache.size()>=max_cache_size) benchmark = false; + if (benchmark) { + size_t max_ws_size = 0; + for (int i = 0; i < num_algos; i++) { + size_t sz; + cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize( + handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc, + cudnnOdesc, algos[i], &sz); + // continue if use too much workspace + if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue; + if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz; + } + size_t allocation; + void* ws = exe.temp_allocator->alloc(max_ws_size, allocation); + checkCudaErrors(cudnnFindConvolutionForwardAlgorithmEx( + handle_, + cudnnIdesc, x->ptr(), + cudnnFdesc, w->ptr(), + cudnnConvDesc, + cudnnOdesc, y->ptr(), + num_algos, + &perf_count, + perf_results, + ws, + max_ws_size)); + exe.temp_allocator->free(ws, max_ws_size, allocation); + } else { + checkCudaErrors(cudnnGetConvolutionForwardAlgorithm_v7( + handle_, + cudnnIdesc, + cudnnFdesc, + cudnnConvDesc, + cudnnOdesc, + num_algos, + &perf_count, + perf_results)); + } + int best_algo_idx=-1; + for (int i = 0; i < perf_count; i++) + if (perf_results[i].status == CUDNN_STATUS_SUCCESS){ + best_algo_idx=i; + break; + } + ASSERT(best_algo_idx!=-1); + algo=perf_results[best_algo_idx].algo; + if (benchmark) { + fwd_algo_cache[jk.to_string()] = algo; + if (fwd_algo_cache.size()==max_cache_size) + LOGw << "forward_ algorithm cache is full"; + } + } + + // TODO: warp work space + void *workSpace = 0; + size_t workSpaceSize; + checkCudaErrors (cudnnGetConvolutionForwardWorkspaceSize( + handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc, + cudnnOdesc, algo, &workSpaceSize) ); + size_t allocation; + if (workSpaceSize > 0) { + workSpace = exe.temp_allocator->alloc(workSpaceSize, allocation); + } + float alpha=1, beta=0; + checkCudaErrors(cudnnConvolutionForward( + handle_, + (void*)(&alpha), + cudnnIdesc, x->ptr(), + cudnnFdesc, w->ptr(), + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnOdesc, y->ptr()) + ); + if (workSpace) + exe.temp_allocator->free(workSpace, workSpaceSize, allocation); + + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnIdesc )); + checkCudaErrors(cudnnDestroyFilterDescriptor( cudnnFdesc )); + checkCudaErrors(cudnnDestroyTensorDescriptor( cudnnOdesc )); + checkCudaErrors(cudnnDestroyConvolutionDescriptor( cudnnConvDesc )); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.h new file mode 100644 index 00000000..c7ecdd54 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_conv_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnConvOp : Op { + Var* x, * w, * y; + int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + string xformat, wformat, yformat; + /* CudnnConvOp: xformat abcd represents nchw */ + CudnnConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); + + const char* name() const override { return "cudnn_conv"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc new file mode 100644 index 00000000..eedd261e --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.cc @@ -0,0 +1,194 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cudnn_rnn_descriptor.h" +#include "cudnn_rnn_backward_x_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var *x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional) + : x(x), hx(hx), cx(cx), y(y), dy(dy), dhy(dhy), dcy(dcy), w(w), reservation(reservation), + mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers), + proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) { + + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + + ASSERTop(mode,==,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +CudnnRnnBackwardXOp::CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional) + : x(x), hx(hx), cx(nullptr), y(y), dy(dy), dhy(dhy), dcy(nullptr), w(w), reservation(reservation), + mode(mode), input_size(input_size), hidden_size(hidden_size), num_layers(num_layers), + proj_size(proj_size), dropout(dropout), bias(bias), bidirectional(bidirectional) { + + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + + ASSERTop(mode,!=,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +void CudnnRnnBackwardXOp::init_rnn() { + dx = create_output(nullptr, ns_float32); + dhx = create_output(nullptr, ns_float32); + + if (mode == "lstm") + dcx = create_output(nullptr, ns_float32); + else + dcx = nullptr; + + dw = create_output(nullptr, dtype_infer(x->ns, y->ns)); + + seq_length = y->shape[0]; + batch_size = y->shape[1]; +} + +void CudnnRnnBackwardXOp::infer_shape() { + dx->set_shape(NanoVector(seq_length, batch_size, input_size)); + + int num_directions = 1 + bidirectional; + if (proj_size > 0) + dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size)); + else + dhx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (dcx) + dcx->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + dw->set_shape(w->shape); +} + +void CudnnRnnBackwardXOp::jit_prepare(JK& jk) { + jk << "«Tx:" << hx->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); +} + +#else // JIT +#ifdef JIT_cuda + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnRnnBackwardXOp::jit_run() { + int num_directions = 1 + bidirectional; + + int in_dims[3] = {batch_size, input_size, 1}; + int out_dims[3] = {batch_size, hidden_size * num_directions, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1}; + int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size}; + int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1}; + + vector xDesc(seq_length), dxDesc(seq_length); + vector yDesc(seq_length), dyDesc(seq_length); + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&dxDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&dyDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dxDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType(), 3, out_dims, out_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dyDesc[i], getDataType(), 3, out_dims, out_strides)); + } + + cudnnTensorDescriptor_t dhyDesc, dcyDesc; + cudnnTensorDescriptor_t hxDesc, cxDesc, dhxDesc, dcxDesc; + checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dhxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dcxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dhyDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&dcyDesc)); + checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dhxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dcxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dhyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(dcyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + RnnWeightDescriptor w_desc(w->size); + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + + void *work_space; + size_t work_space_size = rnn_desc.work_space_size(dxDesc.data(), seq_length); + size_t work_space_allocation; + if (work_space_size > 0) + work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation); + + size_t reserveSpaceSize = reservation->size; + + checkCudaErrors(cudnnRNNBackwardData( + cudnn_handle, rnn_desc.desc, + seq_length, + yDesc.data(), y->ptr(), + dyDesc.data(), dy->ptr(), + dhyDesc, dhy->ptr(), + dcyDesc, mode == "lstm" ? dcy->ptr(): nullptr, + w_desc.desc, w->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + dxDesc.data(), dx->ptr(), + dhxDesc, dhx->ptr(), + dcxDesc, mode == "lstm" ? dcx->ptr() : nullptr, + work_space, work_space_size, + reservation->ptr(), reservation->size + )); + + checkCudaErrors(cudaMemset(dw->ptr(), 0, dw->size)); + + checkCudaErrors(cudnnRNNBackwardWeights( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + yDesc.data(), y->ptr(), + work_space, work_space_size, + w_desc.desc, dw->ptr(), + reservation->ptr(), reservation->size + )); + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(dxDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(dyDesc[i])); + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(dhyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dcyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dhxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(dcxDesc)); + + if (work_space) + exe.temp_allocator->free(work_space, work_space_size, work_space_allocation); +} + +#endif +#endif // JIT +} diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h new file mode 100644 index 00000000..7cf5164a --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_backward_x_op.h @@ -0,0 +1,38 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnRnnBackwardXOp : Op { + Var* x, * hx, * cx; + Var* y, * dy, * dhy, * dcy; + Var* w; + Var* dx, * dhx, * dcx, * dw; + Var* reservation; + string mode; + int input_size, hidden_size, num_layers, proj_size, batch_size; + int seq_length; + float dropout; + bool bias, bidirectional; + + // @attrs(multiple_outputs) + CudnnRnnBackwardXOp(Var* x, Var* hx, Var* cx, Var* y, Var* dy, Var* dhy, Var* dcy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional); + + // @attrs(multiple_outputs) + CudnnRnnBackwardXOp(Var* x, Var* hx, Var* y, Var* dy, Var* dhy, Var* w, Var* reservation, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool bias, bool bidirectional); + + void init_rnn(); + + const char* name() const override { return "cudnn_rnn_backward_x"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc new file mode 100644 index 00000000..c23752f6 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.cc @@ -0,0 +1,234 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cudnn_rnn_descriptor.h" +#include "cudnn_rnn_op.h" +#include "cudnn_wrapper.h" +#include "executor.h" +#include "ops/op_register.h" + +using namespace std; + +namespace jittor { + +#pragma GCC diagnostic ignored "-Wunused-variable" + +#ifndef JIT + +CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional, bool is_train) + : x(x), hx(hx), cx(cx), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size), + num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias), + bidirectional(bidirectional), is_train(is_train) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_grads, 1); + + ASSERTop(mode,==,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +CudnnRnnOp::CudnnRnnOp(Var* x, Var* hx, Var* w, + string mode, int input_size, int hidden_size, int num_layers, int proj_size, + double dropout, bool bias, bool bidirectional, bool is_train) + : x(x), hx(hx), cx(nullptr), w(w), mode(mode), input_size(input_size), hidden_size(hidden_size), + num_layers(num_layers), proj_size(proj_size), dropout(dropout), bias(bias), + bidirectional(bidirectional), is_train(is_train) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_grads, 1); + + ASSERTop(mode,!=,"lstm"); + ASSERTop(proj_size,==,0); + init_rnn(); +} + +void CudnnRnnOp::init_rnn() { + y = create_output(nullptr, dtype_infer(x->ns, w->ns)); + hy = create_output(nullptr, dtype_infer(x->ns, w->ns)); + if (mode == "lstm") + cy = create_output(nullptr, dtype_infer(x->ns, w->ns)); + else + cy = nullptr; + + if (is_train) + reservation = create_output(nullptr, ns_float32); + else + reservation = nullptr; + + seq_length = x->shape[0]; + batch_size = x->shape[1]; +} + +void CudnnRnnOp::infer_shape() { + ASSERTop(x->shape.size(),==,3); + ASSERTop(x->shape[2],==,input_size); + + int num_directions = 1 + bidirectional; + + y->set_shape(NanoVector(seq_length, batch_size, hidden_size * num_directions)); + + if (proj_size > 0) + hy->set_shape(NanoVector(num_layers * num_directions, batch_size, proj_size)); + else + hy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (cy) + cy->set_shape(NanoVector(num_layers * num_directions, batch_size, hidden_size)); + + if (reservation) { + #ifdef IS_CUDA + int in_dims[3] = {batch_size, input_size, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + + vector xDesc(seq_length); + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], CUDNN_DATA_FLOAT, 3, in_dims, in_strides)); + } + reservation->set_shape(rnn_desc.reserve_space_size(xDesc.data(), seq_length)); + #endif + } +} + +void CudnnRnnOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tw:" << w->dtype(); +} + +static auto make_backwardx_with_cx = get_op_info("cudnn_rnn_backward_x") + .get_constructor, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>(); +static auto make_backwardx_without_cx = get_op_info("cudnn_rnn_backward_x") + .get_constructor, Var*, Var*, Var*, Var*, Var*, Var*, Var*, string, int, int, int, int, double, bool, bool>(); +static auto make_number = get_op_info("number") + .get_constructor(); + +void CudnnRnnOp::grads(Var** dout, VarPtr* dins) { + VarPtr dy = dout[0]; + VarPtr dhy = dout[1]; + VarPtr dcy = cx ? dout[2] : nullptr; + if (!dy.ptr) dy = make_number(0.0, y); + if (!dhy.ptr) dhy = make_number(0.0, hy); + if (!dcy.ptr && cx) dcy = make_number(0.0, cy); + + + vector dInput; + if (cx) + dInput = make_backwardx_with_cx(x, hx, cx, y, dy, dhy, dcy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional); + else + dInput = make_backwardx_without_cx(x, hx, y, dy, dhy, w, reservation, mode, input_size, hidden_size, num_layers, proj_size, dropout, bias, bidirectional); + + for (int i = 0; i < 3 + (cx != nullptr); ++i) + dins[i] = move(dInput[i]); +} + +#else // JIT +#ifdef JIT_cuda + +#pragma clang diagnostic ignored "-Wtautological-compare" + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +void CudnnRnnOp::jit_run() { + int num_directions = bidirectional + 1; + int num_linear_layers = rnn_string_to_num_linear_layers(mode); + + int in_dims[3] = {batch_size, input_size, 1}; + int out_dims[3] = {batch_size, hidden_size * num_directions, 1}; + int in_strides[3] = {in_dims[1] * in_dims[2], in_dims[2], 1}; + int out_strides[3] = {out_dims[1] * out_dims[2], out_dims[2], 1}; + int hidden_dims[3] = {num_layers * num_directions, batch_size, hidden_size}; + int hidden_strides[3] = {hidden_dims[1] * hidden_dims[2], hidden_dims[2], 1}; + + vector xDesc(seq_length); + vector yDesc(seq_length); + cudnnTensorDescriptor_t hxDesc, cxDesc; + cudnnTensorDescriptor_t hyDesc, cyDesc; + + for (int i = 0; i < seq_length; ++i) { + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc[i])); + checkCudaErrors(cudnnCreateTensorDescriptor(&yDesc[i])); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc[i], getDataType(), 3, in_dims, in_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(yDesc[i], getDataType(), 3, out_dims, out_strides)); + } + + checkCudaErrors(cudnnCreateTensorDescriptor(&hxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cxDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&hyDesc)); + checkCudaErrors(cudnnCreateTensorDescriptor(&cyDesc)); + + checkCudaErrors(cudnnSetTensorNdDescriptor(hxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cxDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + checkCudaErrors(cudnnSetTensorNdDescriptor(hyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + checkCudaErrors(cudnnSetTensorNdDescriptor(cyDesc, getDataType(), 3, hidden_dims, hidden_strides)); + + RnnDescriptor rnn_desc(cudnn_handle, mode, hidden_size, num_layers, dropout, bidirectional); + + void *work_space; + size_t work_space_size = rnn_desc.work_space_size(xDesc.data(), seq_length); + size_t work_space_allocation; + if (work_space_size > 0) + work_space = exe.temp_allocator->alloc(work_space_size, work_space_allocation); + + RnnWeightDescriptor w_desc(w->size); + + if (is_train) { + checkCudaErrors(cudnnRNNForwardTraining( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + w_desc.desc, w->ptr(), + yDesc.data(), y->ptr(), + hyDesc, hy->ptr(), + cyDesc, mode == "lstm" ? cy->ptr() : nullptr, + work_space, work_space_size, + reservation->ptr(), reservation->size + )); + } else { + checkCudaErrors(cudnnRNNForwardInference( + cudnn_handle, rnn_desc.desc, + seq_length, + xDesc.data(), x->ptr(), + hxDesc, hx->ptr(), + cxDesc, mode == "lstm" ? cx->ptr() : nullptr, + w_desc.desc, w->ptr(), + yDesc.data(), y->ptr(), + hyDesc, hy->ptr(), + cyDesc, mode == "lstm" ? cy->ptr() : nullptr, + work_space, work_space_size + )); + } + + for (int i = 0; i < seq_length; i++) { + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc[i])); + checkCudaErrors(cudnnDestroyTensorDescriptor(yDesc[i])); + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(hxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cxDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(hyDesc)); + checkCudaErrors(cudnnDestroyTensorDescriptor(cyDesc)); + + if (work_space) + exe.temp_allocator->free(work_space, work_space_size, work_space_allocation); +} + +#endif +#endif // JIT + +} // jittor + diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h new file mode 100644 index 00000000..5f87618d --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_rnn_op.h @@ -0,0 +1,36 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnRnnOp : Op { + Var* x, * hx, * cx, * y, * hy, * cy; + Var* w; + Var* reservation; + string mode; + int input_size, hidden_size, num_layers, proj_size; + int seq_length, batch_size; + float dropout; + bool bias, bidirectional, is_train; + + // @attrs(multiple_outputs) + CudnnRnnOp(Var* x, Var* hx, Var* cx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional); + // @attrs(multiple_outputs) + CudnnRnnOp(Var* x, Var* hx, Var* w, string mode, int input_size, int hidden_size, int num_layers, int proj_size, double dropout, bool batch_first, bool bias, bool bidirectional); + + void init_rnn(); + + const char* name() const override { return "cudnn_rnn"; } + void grads(Var** douts, VarPtr* dins) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc b/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc new file mode 100644 index 00000000..1402a524 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.cc @@ -0,0 +1,40 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "cudnn_test_op.h" +#include "utils/str_utils.h" + +int cudnn_test_entry( int argc, char** argv ); + +namespace jittor { + +#ifndef JIT +CudnnTestOp::CudnnTestOp(string cmd) : cmd(move(cmd)) { + output = create_output(1, ns_float32); +} + +void CudnnTestOp::jit_prepare(JK& jk) { + jk << "«T:float32"; +} + +#else // JIT +#ifdef JIT_cpu +void CudnnTestOp::jit_run() { + auto args = split(cmd, " "); + if (!cmd.size()) args.clear(); + vector v(args.size()); + for (uint i=0; iptr()[0] = 123; +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.h b/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.h new file mode 100644 index 00000000..59c6fcb0 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/ops/cudnn_test_op.h @@ -0,0 +1,21 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CudnnTestOp : Op { + Var* output; + string cmd; + CudnnTestOp(string cmd); + + const char* name() const override { return "cudnn_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/src/cudnn_conv_test.cc b/python/jittor/extern/cuda/cudnn/src/cudnn_conv_test.cc new file mode 100644 index 00000000..e85622e8 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/src/cudnn_conv_test.cc @@ -0,0 +1,1001 @@ +// Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// +// This example demonstrates how to use CUDNN library calls cudnnConvolutionForward, +// cudnnConvolutionBackwardData, and cudnnConvolutionBackwardFilter with the option +// to enable Tensor Cores on Volta with cudnnSetConvolutionMathType. +// +// 1. Make sure cuda and cudnn are installed in the same directory. +// +// 2. Run make from the directory of the sample specifying the cuda installation path: +// make CUDA_PATH= +// +// 3. Use the following arguments to run sample with different convolution parameters: +// -c2048 -h7 -w7 -k512 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1 +// -c512 -h28 -w28 -k128 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1 +// -c512 -h28 -w28 -k1024 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2 +// -c512 -h28 -w28 -k256 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2 +// -c256 -h14 -w14 -k256 -r3 -s3 -pad_h1 -pad_w1 -u1 -v1 +// -c256 -h14 -w14 -k1024 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1 +// -c1024 -h14 -w14 -k256 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1 +// -c1024 -h14 -w14 -k2048 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2 +// -c1024 -h14 -w14 -k512 -r1 -s1 -pad_h0 -pad_w0 -u2 -v2 +// -c512 -h7 -w7 -k512 -r3 -s3 -pad_h1 -pad_w1 -u1 -v1 +// -c512 -h7 -w7 -k2048 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1 +// -c2048 -h7 -w7 -k512 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1 +// +// 4. Use the following additional arguments to run the layer with different setup: +// -mathType1 : enable Tensor Cores on Volta. +// -dgrad : run cudnnConvolutionBackwardData() instead of cudnnConvolutionForward(). +// -wgrad : run cudnnConvolutionBackwardFilter() instead of cudnnConvolutionForward(). +// -n : mini batch size. (use -b with large n) +// -b : benchmark mode. Bypass the CPU correctness check. +// -filterFormat1 : Use tensor format CUDNN_TENSOR_NHWC instead of CUDNN_TENSOR_NCHW. +// + +#include +#include +#include +#include +#include +#include + +#include +#include "utils/log.h" +#include "helper_cuda.h" +#include "fp16_dev.h" +#include "fp16_emu.h" + +#define SWITCH_CHAR '-' +#define THRESHOLD 2.0e-2 + +#if defined(__linux__) +#include +#include +#include +#include +static double second (void) +{ + struct timeval tv; + gettimeofday(&tv, NULL); + return (double)tv.tv_sec + (double)tv.tv_usec / 1000000.0; +} + +template __inline__ cudnnDataType_t getDataType(); +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_HALF; } +template <> __inline__ cudnnDataType_t getDataType() { return CUDNN_DATA_FLOAT; } + +//Generate uniform numbers [0,1) +static void initImage(float* image, int imageSize) { + static unsigned seed = 123456789; + for (int index = 0; index < imageSize; index++) { + seed = ( 1103515245 * seed + 12345 ) & 0xffffffff; + image[index] = float(seed)*2.3283064e-10; //2^-32 + } +} + +static void initImage(half1* image, int imageSize) { + static unsigned seed = 123456789; + for (int index = 0; index < imageSize; index++) { + seed = ( 1103515245 * seed + 12345 ) & 0xffffffff; + image[index] = cpu_float2half_rn(float(seed)*2.3283064e-10); //2^-32 + } +} + +static void printPerf( double cudaTime, double cudaGflops, double cudaBandwithGb, + const char *cpuLib, double cpuTime, double cpuGflops, double cpuBandwithGb) +{ + printf( "^^^^ CUDA : elapsed = %g sec, ", cudaTime ); + if (cudaGflops > 0) printf( "Gflops = %.3f ", cudaGflops ); + if (cudaBandwithGb > 0) printf( "Bandwidth = %.3f ", cudaBandwithGb ); + printf( "\n"); + if (cpuLib) { + printf( "^^^^%s : elapsed = %g sec, ", cpuLib, cpuTime ); + if (cpuGflops > 0) printf( "Gflops = %.3f ", cpuGflops ); + if (cpuBandwithGb > 0) printf( "Bandwidth = %.3f, ", cpuBandwithGb ); + printf( "Speedup %.2f\n", cpuTime/cudaTime ); + + } +} + +static void generateStrides(const int* dimA, int* strideA, int nbDims, bool isNchw) { + if (isNchw) { + strideA[nbDims-1] = 1 ; + for(int d = nbDims-2 ; d >= 0 ; d--) { + strideA[d] = strideA[d+1] * dimA[d+1] ; + } + } else { + strideA[1] = 1; + strideA[nbDims-1] = strideA[1]*dimA[1]; + for(int d = nbDims-2 ; d >= 2 ; d--) { + strideA[d] = strideA[d+1] * dimA[d+1] ; + } + strideA[0] = strideA[2]*dimA[2]; + } +} + +// Convert a linear index +// i = d_1 s_1 ... s_n + d_2 s_2 ... s_n + d_n-1 s_n + d_n +// into a multidimensional index +// (d_1, d_2, ..., d_n) +void lin2dim(int id, int* ids, const int* dims, int length) { + int idrem = id ; + int prod = 1 ; // accumulates the product of the dimensions + for(int i = length-1; i >= 0; i--) { + ids[i] = (idrem / prod) % dims[i] ; + idrem = id - ids[i] * prod ; + prod *= dims[i] ; + } +} + +// Convert a multidimensional index +// (d_1, d_2, ..., d_n) +// into a linear index +// i = d_1 s_1 + ... + d_n s_n +static int dim2lin(const int* ids, const int* strides, int length) { + int res = 0 ; + for(int i = 0 ; i < length ; i++) { + res += ids[i] * strides[i]; + } + return res ; +} + +static float doFma(float fval, float ival, float tmp) { + return fval*ival+tmp; +} + +static float doFma(half1 fval, half1 ival, float tmp) { + return cpu_half2float(fval)*cpu_half2float(ival)+tmp; +} + +static void doEpilog(float *out, int idx, float alphaAcc, float beta) { + if( beta == 0.f ) { + out[idx] = alphaAcc; + } else { + out[idx] = alphaAcc + out[idx]*beta; + } +} + +static void doEpilog(half1 *out, int idx, float alphaAcc, float beta) { + if( beta == 0.f ) { + out[idx] = cpu_float2half_rn(alphaAcc); + } else { + out[idx] = cpu_float2half_rn(alphaAcc + cpu_half2float(out[idx])*beta); + } +} + +template +static void conv_cpu_ref ( + const T_ELEM* inputData, + const T_ELEM* filterData, + T_ELEM* outputData, + float alpha, + float beta, + bool isNchw, + const int* inDims, + const int* filDims, + const int* outDims, + const int* inStride, + const int* outStride, + const int* stride, + const int* pad, + const int* dilation, + int nbDims +) { + int imDims = nbDims - 2 ; + + int filStride[8] = {0} ; + generateStrides(filDims, filStride, nbDims, isNchw); + + bool isConv = true; //(CUDNN_CONVOLUTION == mode) ; + // Number of pixels in output + int nPixelsOut = 1 ; + for(int i = 2 ; i < nbDims ; i++) + nPixelsOut *= outDims[i] ; + // Number of pixels in filter + int nPixelsFil = 1 ; + for(int i = 2 ; i < nbDims ; i++) + nPixelsFil *= filDims[i] ; + // Used to store coordinates + int filIds[8] = {0} ; + int outIds[8] = {0} ; + int inIds [8] = {0} ; + int tmpIds[8] = {0} ; + // For each image in the output + for(int ni = 0 ; ni < outDims[0] ; ni++) { + // For each feature layer of the output + for(int ki = 0 ; ki < outDims[1] ; ki++) { + int outputOffset = ni * outStride[0] + ki * outStride[1] ; + // Loop over all entries of the result + for(int outId = 0 ; outId < nPixelsOut ; outId++) { + // Get output pixel ids + lin2dim(outId, outIds, outDims+2, imDims) ; // Skip n and k dimensions + // Now we get the coordinates in input space of the "top left" corner of the filter: multiply by stride and remove pad + for(int d = 0 ; d < imDims ; d++) { + inIds[d] = outIds[d] * stride[d] - pad[d] ; + } + // We then accumulate + float tmp = 0.f; + for(int ci = 0 ; ci < inDims[1] ; ci++) { + int inputOffset = ni * inStride[0] + ci * inStride[1] ; + int filterOffset = ki * filStride[0] + ci * filStride[1] ; + for(int filId = 0 ; filId < nPixelsFil ; filId ++) { + // Get the position of the pixel + lin2dim(filId, filIds, filDims+2, imDims) ; + // Compute the corresponding output pixel + // and check wether we are in the padding area on the fly too (not that for convolution, we flip the image patch (equivalent to flipping the filter patch)) + bool inside = true ; + for(int d = 0 ; d < imDims && inside ; d++) { + if (isConv) { + tmpIds[d] = inIds[d] + dilation[d] * (filDims[2+d]-1 - filIds[d]) ; + } else { + tmpIds[d] = inIds[d] + dilation[d] * filIds[d] ; + } + inside &= (tmpIds[d] >= 0 && tmpIds[d] < inDims[2+d]) ; // If we are in the padding area: stop and skip computations + } + if(inside) { + int actualTmpId = inputOffset + dim2lin(tmpIds, (inStride)+2, imDims) ; + //int actualFilId = filterOffset + filId ; + int actualFilId = filterOffset + dim2lin(filIds, (filStride)+2, imDims) ; + T_ELEM fval = filterData[actualFilId] ; + T_ELEM ival = inputData [actualTmpId] ; + tmp = doFma(fval, ival, tmp); + } + } + } + + // We put the result in the output + int actualOutId = outputOffset + dim2lin(outIds, (outStride)+2, imDims) ; + doEpilog(outputData, actualOutId, alpha*tmp, beta); + } + } + } +} + +template +static void dataGrad_cpu_ref ( + const T_ELEM *weight, + const T_ELEM *top_diff, + T_ELEM *output, + float alpha, + float beta, + bool isNchw, + const int* inDims, + const int* filDims, + const int* outDims, + const int* inStride, + const int* outStride, + const int* stride, + const int* pad, + const int* dilation, + int nbDims ) +{ + + // Sanity checks + // output is n x c x h x w + // diff is n x k x p x q + // filter is k x c x r x s + assert(inDims[0] == outDims[0]); // n + assert(inDims[1] == filDims[0]); // k + assert(outDims[1] == filDims[1]); // c + + int filStride[8] = {0} ; + generateStrides(filDims, filStride, nbDims, isNchw); + + bool isConv = true; //(CUDNN_CONVOLUTION == mode) ; + + // For every output pixel (n x c x h x w) + for(int ni = 0; ni < outDims[0]; ni++) { + for(int ci = 0; ci < outDims[1]; ci++) { + for(int hi = 0; hi < outDims[2]; hi++) { + for(int wi = 0; wi < outDims[3]; wi++) { + int outIdx = ni * outStride[0] + + ci * outStride[1] + + hi * outStride[2] + + wi * outStride[3]; + float val = 0.0; + + // For every diff channel (k) + for(int ki = 0; ki < inDims[1]; ki++) { // Sum over k channels + int offset_filter = ki * filStride[0] + ci * filStride[1]; + int offset_diff = ni * inStride[0] + ki * inStride[1]; + // For every pixel if filter (r x s) + for(int ri = 0; ri < filDims[2]; ri++) { + int p = hi + pad[0]; + if (isConv){ + p -= (filDims[2] - 1 - ri) * dilation[0]; + } else { + p -= ri * dilation[0]; + } + if ( p%stride[0] ) + continue; + p/=stride[0]; + + for(int si = 0; si < filDims[3]; si++) { + int q = wi + pad[1]; + // Fetch the value in filter and diff, product and accumulate + // So basically, for the convolution, we replace r by dim-1-r and s by dim-1-s to "flip" the filter + // We can then just reason in term of correlation + if (isConv){ + q -= (filDims[3] - 1 - si) * dilation[1]; + } else { + q -= si * dilation[1]; + } + //Skip if q or p isn't multiple of strides + if ( q%stride[1] ) + continue; + q/=stride[1]; + int inBounds = ( (p >= 0) && (p < inDims[2]) && (q >= 0) && (q < inDims[3]) ); + if (inBounds) { + int filterIdx = offset_filter + ri * filStride[2] + si * filStride[3]; + int diffIdx = offset_diff + p * inStride[2] + q * inStride[3]; + T_ELEM imTmp = top_diff[diffIdx]; + T_ELEM filTmp = weight[filterIdx]; + val = doFma(filTmp, imTmp, val); + } + } + } + } + doEpilog(output, outIdx, alpha*val, beta); + } + } + } + } +} + +template +static void weightGrad_cpu_ref(/*const TensorNdTestDesc_t *tensorInputDesc,*/ + const T_ELEM *image, + /*const TensorNdTestDesc_t *tensorDiffDesc,*/ + const T_ELEM *diffData, + /*const ConvNdTestDesc_t *convDesc,*/ + /*const TensorNdTestDesc_t *filterOutputDesc,*/ + float alpha, + float beta, + T_ELEM *output, + bool isNchw, + const int* inDims, + const int* filDims, + const int* diffDims, + const int* inStride, + const int* diffStride, + const int* stride, + const int* pad, + const int* dilation, + int nbDims ) +{ + // Some sanity checks + // image is n x c x h x w + // diff is n x k x p x q + // filter is k x c x r x s + assert(inDims[0] == diffDims[0]) ; + assert(inDims[1] == filDims[1]) ; + assert(diffDims[1] == filDims[0]) ; + + // Filter stride + int filterStride[4] ; + generateStrides(filDims, filterStride, nbDims, isNchw); + + bool isConv = true; //(CUDNN_CONVOLUTION == mode) ; + + // For every filter pixel (k x c x r x s) + for(int ci = 0; ci < inDims[1]; ci++) { // Loop over filter output pixels + for(int ri = 0; ri < filDims[2]; ri++) { // ^ + for(int si = 0; si < filDims[3]; si++) { // ^ + for(int ki = 0; ki < filDims[0]; ki++){ // ^ + int filIdx = ki * filterStride[0] + ci * filterStride[1] + ri * filterStride[2] + si * filterStride[3] ; + float val = 0.f ; + // For every image (n) + for(int ni = 0 ; ni < inDims[0]; ni++) { // Sum over the batch + int offset_image = ni * inStride[0] + ci * inStride[1] ; + int offset_diff = ni * diffStride[0] + ki * diffStride[1] ; + // For every pixel in diff (p x q) + for(int pi = 0; pi < diffDims[2] ; pi++ ) { // Sum over the pixels of diff + for(int qi = 0; qi < diffDims[3] ; qi++ ) { // ^ + // Fetch the value in image and diff, product and accumulate + int y = pi * stride[0] - pad[0] ; + int x = qi * stride[1] - pad[1] ; + // Convolution = Correlation with a flipped filter + // So basically, for the convolution, we replace r by dim-1-r and s by dim-1-s to "flip" the filter + // We can then just reason in term of correlation + if (isConv){ + y += (filDims[2] - 1 - ri) * dilation[0] ; + x += (filDims[3] - 1 - si) * dilation[1] ; + } else { + // The effect of dilation on the gradient is to start the "zone of influence" of a given pixel further into the image, so dilation + // only produces a shift in x and y + y += ri * dilation[0] ; + x += si * dilation[1] ; + } + // Image value + int inBounds = ((x >=0)&&(x < inDims[3])&&(y >=0)&&(y < inDims[2])); + if (inBounds) { + int imIdx = offset_image + y * inStride[2] + x * inStride[3] ; + // Diff value + int diffIdx = offset_diff + pi * diffStride[2] + qi * diffStride[3] ; + // Prod and accumulate + T_ELEM imTmp = image[imIdx] ; + T_ELEM diffTmp = diffData[diffIdx]; + val = doFma(diffTmp, imTmp, val); + } + } + } + } + doEpilog(output, filIdx, alpha*val, beta); + } + } + } + } +} + + +float getError(float dev, float ref) { + if (ref > 1.0 || ref < -1.0) + return (dev - ref)/ref; + else + return dev - ref; +} + +float getError(half1 dev, half1 ref) { + if (cpu_half2float(ref) > 1.0 || cpu_half2float(ref) < -1.0) + return (cpu_half2float(dev) - cpu_half2float(ref))/cpu_half2float(ref); + else + return cpu_half2float(dev) - cpu_half2float(ref); +} + +static inline int getFwdConvDilatedFilterDim(int filterDim, + int dilation) +{ + return ( (filterDim - 1) * dilation ) + 1 ; +} + +static inline int getFwdConvPaddedImageDim(int tensorDim, + int pad) +{ + return tensorDim + (2 * pad) ; +} + +static inline int getFwdConvOutputDim( int tensorDim, + int pad, + int filterDim, + int stride, + int dilation) +{ + int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation))/stride + 1; + return(p); +} + +template +int doConv( + cudnnHandle_t handle_, + T_ELEM* devPtrI, + T_ELEM* devPtrF, + T_ELEM* devPtrO, + T_ELEM* hostI, + T_ELEM* hostF, + T_ELEM* hostO, + cudnnTensorDescriptor_t cudnnIdesc, + cudnnFilterDescriptor_t cudnnFdesc, + cudnnTensorDescriptor_t cudnnOdesc, + cudnnConvolutionDescriptor_t cudnnConvDesc, + float alpha, + float beta, + cudnnTensorFormat_t filterFormat, + const int* dimA, + const int* filterdimA, + const int* outdimA, + const int* strideA, + const int* outstrideA, + const int* convstrideA, + const int* padA, + const int* dilationA, + const int benchmark) { + + int outsize = outstrideA[0]*outdimA[0]; + T_ELEM* hostOfromdev = (T_ELEM*)calloc (outsize, sizeof(hostO[0]) ); + + cudnnConvolutionFwdAlgo_t algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + + void *workSpace = 0; + size_t workSpaceSize; + + checkCudaErrors ( cudnnGetConvolutionForwardWorkspaceSize(handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc, + cudnnOdesc, algo, &workSpaceSize) ); + + if (workSpaceSize > 0) { + cudaMalloc(&workSpace, workSpaceSize); + } + double start = second(); + checkCudaErrors ( cudnnConvolutionForward (handle_, + (void*)(&alpha), + cudnnIdesc, devPtrI, + cudnnFdesc, devPtrF, + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnOdesc, devPtrO) ); + checkCudaErrors( cudaDeviceSynchronize() ); + double stop = second(); + printPerf( stop - start, 0, 0, + 0, 0, 0, 0); + checkCudaErrors( cudaMemcpy(hostOfromdev, devPtrO, sizeof(hostO[0]) * outsize, cudaMemcpyDeviceToHost) ); + checkCudaErrors( cudaDeviceSynchronize() ); + if (workSpace) { + cudaFree(workSpace); + workSpace = 0; + } + int numErrors = 0; + if (!benchmark) { + conv_cpu_ref( hostI, hostF, hostO, alpha, beta, (filterFormat == CUDNN_TENSOR_NCHW), dimA, filterdimA, outdimA, strideA, outstrideA, convstrideA, padA, dilationA, 4); + for (int index = 0; index < outsize; index++) { // assuming out data is packed + float diff = getError(hostOfromdev[index], hostO[index]); + if (diff < 0) diff = -diff; + if(diff > THRESHOLD) { + numErrors++; + } + } + } + return numErrors; +} + +template +int doDgrad( + cudnnHandle_t handle_, + T_ELEM* devPtrI, + T_ELEM* devPtrF, + T_ELEM* devPtrO, + T_ELEM* hostI, + T_ELEM* hostF, + T_ELEM* hostO, + cudnnTensorDescriptor_t cudnnIdesc, + cudnnFilterDescriptor_t cudnnFdesc, + cudnnTensorDescriptor_t cudnnOdesc, + cudnnConvolutionDescriptor_t cudnnConvDesc, + float alpha, + float beta, + cudnnTensorFormat_t filterFormat, + const int* dimA, + const int* filterdimA, + const int* outdimA, + const int* strideA, + const int* outstrideA, + const int* convstrideA, + const int* padA, + const int* dilationA, + const int benchmark) { + + int insize = strideA[0]*dimA[0]; + T_ELEM* hostIfromdev = (T_ELEM*)calloc (insize, sizeof(hostI[0]) ); + cudnnConvolutionBwdDataAlgo_t algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + + void *workSpace = 0; + size_t workSpaceSize; + + checkCudaErrors ( cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, + cudnnIdesc, algo, &workSpaceSize) ); + + if (workSpaceSize > 0) { + cudaMalloc(&workSpace, workSpaceSize); + } + double start = second(); + checkCudaErrors ( cudnnConvolutionBackwardData (handle_, + (void*)(&alpha), + cudnnFdesc, devPtrF, + cudnnOdesc, devPtrO, + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnIdesc, devPtrI) ); + checkCudaErrors( cudaDeviceSynchronize() ); + double stop = second(); + printPerf( stop - start, 0, 0, + 0, 0, 0, 0); + checkCudaErrors( cudaMemcpy(hostIfromdev, devPtrI, sizeof(hostI[0]) * insize, cudaMemcpyDeviceToHost) ); + checkCudaErrors( cudaDeviceSynchronize() ); + if (workSpace) { + cudaFree(workSpace); + workSpace = 0; + } + int numErrors = 0; + if (!benchmark) { + dataGrad_cpu_ref(hostF, hostO, hostI, alpha, beta, (filterFormat == CUDNN_TENSOR_NCHW), outdimA, filterdimA, dimA, outstrideA, strideA, convstrideA, padA, dilationA, 4); + for (int index = 0; index < insize; index++) { // assuming in data is packed + float diff = getError(hostIfromdev[index], hostI[index]); + if (diff < 0) diff = -diff; + if(diff > THRESHOLD) { + numErrors++; + } + } + } + return numErrors; +} + +template +int doWgrad( + cudnnHandle_t handle_, + T_ELEM* devPtrI, + T_ELEM* devPtrF, + T_ELEM* devPtrO, + T_ELEM* hostI, + T_ELEM* hostF, + T_ELEM* hostO, + cudnnTensorDescriptor_t cudnnIdesc, + cudnnFilterDescriptor_t cudnnFdesc, + cudnnTensorDescriptor_t cudnnOdesc, + cudnnConvolutionDescriptor_t cudnnConvDesc, + float alpha, + float beta, + cudnnTensorFormat_t filterFormat, + const int* dimA, + const int* filterdimA, + const int* outdimA, + const int* strideA, + const int* outstrideA, + const int* convstrideA, + const int* padA, + const int* dilationA, + const int benchmark) { + + int filsize = filterdimA[0]*filterdimA[1]*filterdimA[2]*filterdimA[3]; + T_ELEM* hostFfromdev = (T_ELEM*)calloc (filsize, sizeof(hostF[0]) ); + cudnnConvolutionBwdFilterAlgo_t algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + + void *workSpace = 0; + size_t workSpaceSize; + + checkCudaErrors ( cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, + cudnnFdesc, algo, &workSpaceSize) ); + + if (workSpaceSize > 0) { + cudaMalloc(&workSpace, workSpaceSize); + } + double start = second(); + checkCudaErrors ( cudnnConvolutionBackwardFilter (handle_, + (void*)(&alpha), + cudnnIdesc, devPtrI, + cudnnOdesc, devPtrO, + cudnnConvDesc, + algo, + workSpace, workSpaceSize, + (void*)(&beta), + cudnnFdesc, devPtrF) ); + checkCudaErrors( cudaDeviceSynchronize() ); + double stop = second(); + printPerf( stop - start, 0, 0, + 0, 0, 0, 0); + checkCudaErrors( cudaMemcpy(hostFfromdev, devPtrF, sizeof(hostF[0]) * filsize, cudaMemcpyDeviceToHost) ); + checkCudaErrors( cudaDeviceSynchronize() ); + if (workSpace) { + cudaFree(workSpace); + workSpace = 0; + } + int numErrors = 0; + if (!benchmark) { + weightGrad_cpu_ref(hostI, hostO, alpha, beta, hostF, (filterFormat == CUDNN_TENSOR_NCHW), dimA, filterdimA, outdimA, strideA, outstrideA, convstrideA, padA, dilationA, 4); + for (int index = 0; index < filsize; index++) { // assuming in data is packed + float diff = getError(hostFfromdev[index], hostF[index]); + if (diff < 0) diff = -diff; + if(diff > THRESHOLD) { + numErrors++; + } + } + } + return numErrors; +} + +template +int doTest(int algo, int* dimA, int* padA, int* convstrideA, int* filterdimA, cudnnTensorFormat_t filterFormat, int mathType, int benchmark) { + + cudnnHandle_t handle_; + T_ELEM* devPtrI; + T_ELEM* devPtrF; + T_ELEM* devPtrO; + T_ELEM* hostI; + T_ELEM* hostF; + T_ELEM* hostO; + + cudnnTensorDescriptor_t cudnnIdesc; + cudnnFilterDescriptor_t cudnnFdesc; + cudnnTensorDescriptor_t cudnnOdesc; + cudnnConvolutionDescriptor_t cudnnConvDesc; + + int convDim = 2; + + float alpha = 1.0f; + float beta = 0.0; + + checkCudaErrors(cudnnCreate(&handle_)); + + checkCudaErrors( cudnnCreateTensorDescriptor( &cudnnIdesc )); + checkCudaErrors( cudnnCreateFilterDescriptor( &cudnnFdesc )); + checkCudaErrors( cudnnCreateTensorDescriptor( &cudnnOdesc )); + checkCudaErrors( cudnnCreateConvolutionDescriptor( &cudnnConvDesc )); + + int dilationA[] = {1, 1}; + + int strideA[] = {8192, 1024, 32, 1}; + generateStrides(dimA, strideA, 4, (filterFormat == CUDNN_TENSOR_NCHW)); + int insize = strideA[0]*dimA[0]; + + int filtersize = filterdimA[0]*filterdimA[1]*filterdimA[2]*filterdimA[3]; + + int outdimA[] = {1, 8, 30, 30}; + outdimA[0] = dimA[0]; + outdimA[1] = filterdimA[0]; + for( int dim = 0; dim < 2; dim++) { + outdimA[dim+2] = getFwdConvOutputDim( dimA[dim+2], + padA[dim], + filterdimA[dim+2], + convstrideA[dim], + dilationA[dim]); + } + + int outstrideA[] = {7200, 900, 30, 1}; + generateStrides(outdimA, outstrideA, 4, (filterFormat == CUDNN_TENSOR_NCHW)); + int outsize = outstrideA[0]*outdimA[0]; + + cudaMalloc ((void**)&(devPtrI), (insize) * sizeof(devPtrI[0]) ); + cudaMalloc ((void**)&(devPtrF), (filtersize) * sizeof(devPtrF[0]) ); + cudaMalloc ((void**)&(devPtrO), (outsize) * sizeof(devPtrO[0]) ); + hostI = (T_ELEM*)calloc (insize, sizeof(hostI[0]) ); + hostF = (T_ELEM*)calloc (filtersize, sizeof(hostF[0]) ); + hostO = (T_ELEM*)calloc (outsize, sizeof(hostO[0]) ); + + initImage(hostI, insize); + initImage(hostF, filtersize); + initImage(hostO, outsize); + + checkCudaErrors( cudaMemcpy(devPtrI, hostI, sizeof(hostI[0]) * insize, cudaMemcpyHostToDevice)); + checkCudaErrors( cudaMemcpy(devPtrF, hostF, sizeof(hostF[0]) * filtersize, cudaMemcpyHostToDevice)); + checkCudaErrors( cudaMemcpy(devPtrO, hostO, sizeof(hostO[0]) * outsize, cudaMemcpyHostToDevice)); + checkCudaErrors( cudaDeviceSynchronize() ); + + checkCudaErrors( cudnnSetTensorNdDescriptor(cudnnIdesc, getDataType(), convDim+2, dimA, strideA) ); + + checkCudaErrors( cudnnSetFilterNdDescriptor(cudnnFdesc, getDataType(), filterFormat, convDim+2, filterdimA)); + + checkCudaErrors( cudnnSetConvolutionNdDescriptor(cudnnConvDesc, + convDim, + padA, + convstrideA, + dilationA, + CUDNN_CONVOLUTION, + CUDNN_DATA_FLOAT) ); + if (mathType == 1) { + checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) ); + } + + checkCudaErrors( cudnnSetTensorNdDescriptor(cudnnOdesc, getDataType(), convDim+2, outdimA, outstrideA) ); + + int numErrors = 0; + if (algo == 0) { + printf("Testing conv\n"); + numErrors = doConv( + handle_, + devPtrI, + devPtrF, + devPtrO, + hostI, + hostF, + hostO, + cudnnIdesc, + cudnnFdesc, + cudnnOdesc, + cudnnConvDesc, + alpha, + beta, + filterFormat, + dimA, + filterdimA, + outdimA, + strideA, + outstrideA, + convstrideA, + padA, + dilationA, + benchmark); + } else if (algo == 1) { + printf("Testing dgrad\n"); + numErrors = doDgrad( + handle_, + devPtrI, + devPtrF, + devPtrO, + hostI, + hostF, + hostO, + cudnnIdesc, + cudnnFdesc, + cudnnOdesc, + cudnnConvDesc, + alpha, + beta, + filterFormat, + dimA, + filterdimA, + outdimA, + strideA, + outstrideA, + convstrideA, + padA, + dilationA, + benchmark); + } else { + printf("Testing wgrad\n"); + numErrors = doWgrad( + handle_, + devPtrI, + devPtrF, + devPtrO, + hostI, + hostF, + hostO, + cudnnIdesc, + cudnnFdesc, + cudnnOdesc, + cudnnConvDesc, + alpha, + beta, + filterFormat, + dimA, + filterdimA, + outdimA, + strideA, + outstrideA, + convstrideA, + padA, + dilationA, + benchmark); + } + + if (!benchmark) { + if (numErrors == 0) { + printf("Test PASSED\n"); + } else { + printf("Test FAILED, num errors = %d\n", numErrors); + } + } + + if (devPtrI) cudaFree (devPtrI); + if (devPtrF) cudaFree (devPtrF); + if (devPtrO) cudaFree (devPtrO); + if (cudnnIdesc) cudnnDestroyTensorDescriptor(cudnnIdesc); + if (cudnnFdesc) cudnnDestroyFilterDescriptor(cudnnFdesc); + if (cudnnOdesc) cudnnDestroyTensorDescriptor(cudnnOdesc); + if (cudnnConvDesc) cudnnDestroyConvolutionDescriptor(cudnnConvDesc); + + return 0; +} + +int cudnn_test_entry( int argc, char** argv ) +{ + int algo = 0; + int mathType = 0; + int benchmark = 0; + + int dimA[] = {1, 8, 32, 32}; + + int padA[] = {0, 0}; + int convstrideA[] = {1, 1}; + + int filterdimA[] = {8, 8, 3, 3}; + + cudnnTensorFormat_t filterFormat = CUDNN_TENSOR_NCHW; + + int error = 0; + while (argc) { + if (*argv[0] == SWITCH_CHAR) { + switch (*(argv[0]+1)) { + case 'b': + benchmark = 1; + break; + case 'c': + dimA[1] = atol(argv[0]+2); + filterdimA[1] = dimA[1]; + break; + case 'd': + if ( strncmp( argv[0]+1, "dgrad" , strlen("dgrad")) == 0) { + algo = 1; + } + break; + case 'f': + if ( strncmp( argv[0]+1, "filterFormat" , strlen("filterFormat")) == 0) { + filterFormat = (cudnnTensorFormat_t)(atoi(argv[0]+ 1 + strlen("filterFormat"))); + } + break; + case 'h': + dimA[2] = atol(argv[0]+2); + break; + case 'k': + filterdimA[0] = atol(argv[0]+2); + break; + case 'm': + if ( strncmp( argv[0]+1, "mathType1" , strlen("mathType1")) == 0) { + mathType = 1; + } + break; + case 'n': + dimA[0] = atol(argv[0]+2); + break; + case 'p': + if ( strncmp( argv[0]+1, "pad_h" , strlen("pad_h")) == 0) { + padA[0] = (int)atol(argv[0]+ 1 + strlen("pad_h")); + } + else if ( strncmp( argv[0]+1, "pad_w" , strlen("pad_w")) == 0) { + padA[1] = (int)atol(argv[0]+ 1 + strlen("pad_w")); + } + break; + case 'r': + filterdimA[2] = atol(argv[0]+2); + break; + case 's': + filterdimA[3] = atol(argv[0]+2); + break; + case 'u': + convstrideA[0] = atol(argv[0]+2); + break; + case 'v': + convstrideA[1] = atol(argv[0]+2); + break; + case 'w': + if ( strncmp( argv[0]+1, "wgrad" , strlen("wgrad")) == 0) { + algo = 2; + } + else dimA[3] = atol(argv[0]+2); + break; + default: + error++; + break; + } + if (error) { + fprintf(stderr, "Unknown switch '%c%s'\n\n", SWITCH_CHAR, argv[0]+1); + return error; + } + } + else { + fprintf(stderr, "Invalid separator '%c' for option '%s'\n\n", *argv[0], argv[0] ); + return 1; + } + argc -= 1; + argv++; + } + + printf("Testing single precision\n"); + doTest(algo, dimA, padA, convstrideA, filterdimA, filterFormat, mathType, benchmark); + printf("Testing half precision (math in single precision)\n"); + doTest(algo, dimA, padA, convstrideA, filterdimA, filterFormat, mathType, benchmark); + + return 0; +} + +#else +int cudnn_test_entry( int argc, char** argv ) { + return 0; +} +#endif \ No newline at end of file diff --git a/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc b/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc new file mode 100644 index 00000000..c49eb319 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/src/cudnn_rnn_descriptor.cc @@ -0,0 +1,74 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Zheng-Ning Liu +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "cudnn_rnn_descriptor.h" + +namespace jittor { + +vector cudnn_rnn_weight_offset(string mode, int input_size, int hidden_size, int num_layers, int proj_size, bool bias, bool bidirectional) { + // A pseudo mini-batch for fetching weight space size. + int dimX[] = {1, input_size, 1}; + int strideX[] = {input_size, 1, 1}; + cudnnTensorDescriptor_t xDesc; + checkCudaErrors(cudnnCreateTensorDescriptor(&xDesc)); + checkCudaErrors(cudnnSetTensorNdDescriptor(xDesc, CUDNN_DATA_FLOAT, 3, dimX, strideX)); + + RnnDescriptor rnn_desc = RnnDescriptor(cudnn_handle, mode, hidden_size, num_layers, 0, bidirectional); + int weightSpaceSize = rnn_desc.weight_space_size(xDesc); + RnnWeightDescriptor w_desc(weightSpaceSize); + + vector weight_offsets; + weight_offsets.push_back(weightSpaceSize / sizeof(float)); + + int num_directions = bidirectional + 1; + int num_linear_layers = rnn_string_to_num_linear_layers(mode); + + for (int layer = 0; layer < num_layers * num_directions; layer++) { + for (int linLayerID = 0; linLayerID < num_linear_layers; linLayerID++) { + cudnnFilterDescriptor_t linLayerMatDesc; + cudnnFilterDescriptor_t linLayerBiasDesc; + float *linLayerMat = nullptr; + float *linLayerBias = nullptr; + + checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerMatDesc)); + checkCudaErrors(cudnnCreateFilterDescriptor(&linLayerBiasDesc)); + + checkCudaErrors(cudnnGetRNNLinLayerMatrixParams( + cudnn_handle, rnn_desc.desc, + layer, + xDesc, + w_desc.desc, + nullptr, + linLayerID, + linLayerMatDesc, + (void **) &linLayerMat + )); + weight_offsets.push_back(linLayerMat - (float *) nullptr); + + if (bias) { + checkCudaErrors(cudnnGetRNNLinLayerBiasParams( + cudnn_handle, rnn_desc.desc, + layer, + xDesc, + w_desc.desc, + nullptr, + linLayerID, + linLayerBiasDesc, + (void **) &linLayerBias + )); + weight_offsets.push_back(linLayerBias - (float *) nullptr); + } + } + } + + checkCudaErrors(cudnnDestroyTensorDescriptor(xDesc)); + + return weight_offsets; +} + + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/src/cudnn_wrapper.cc b/python/jittor/extern/cuda/cudnn/src/cudnn_wrapper.cc new file mode 100644 index 00000000..2f9b27ea --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/src/cudnn_wrapper.cc @@ -0,0 +1,40 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "cudnn_wrapper.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +cudnnHandle_t cudnn_handle; +int max_cache_size = 100; +float max_workspace_ratio = 0.25; + +void set_algorithm_cache_size(int size) { + max_cache_size = size; +} + +void set_max_workspace_ratio(float64 ratio) { + max_workspace_ratio = ratio; +} + +struct cudnn_initer { + +inline cudnn_initer() { + if (!get_device_count()) return; + checkCudaErrors(cudnnCreate(&cudnn_handle)); + LOGv << "cudnnCreate finished"; +} + +inline ~cudnn_initer() { + if (!get_device_count()) return; + checkCudaErrors(cudnnDestroy(cudnn_handle)); + LOGv << "cudnnDestroy finished"; +} + +} init; + +} // jittor diff --git a/python/jittor/extern/cuda/cudnn/src/helper_cudnn.cc b/python/jittor/extern/cuda/cudnn/src/helper_cudnn.cc new file mode 100644 index 00000000..6b3fa2e6 --- /dev/null +++ b/python/jittor/extern/cuda/cudnn/src/helper_cudnn.cc @@ -0,0 +1,7 @@ +#include +#include "utils/log.h" +#include "helper_cuda.h" + +const char *_cudaGetErrorEnum(cudnnStatus_t error) { + return cudnnGetErrorString(error); +} \ No newline at end of file diff --git a/python/jittor/extern/cuda/cufft/inc/cufft_utils.h b/python/jittor/extern/cuda/cufft/inc/cufft_utils.h new file mode 100644 index 00000000..c9b6f18f --- /dev/null +++ b/python/jittor/extern/cuda/cufft/inc/cufft_utils.h @@ -0,0 +1,102 @@ +/* + * Copyright 2020 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#pragma once + +// CUDA API error checking +#ifndef CUDA_RT_CALL +#define CUDA_RT_CALL( call ) \ + { \ + auto status = static_cast( call ); \ + if ( status != cudaSuccess ) \ + fprintf( stderr, \ + "ERROR: CUDA RT call \"%s\" in line %d of file %s failed " \ + "with " \ + "%s (%d).\n", \ + #call, \ + __LINE__, \ + __FILE__, \ + cudaGetErrorString( status ), \ + status ); \ + } +#endif // CUDA_RT_CALL + +// cufft API error chekcing +#ifndef CUFFT_CALL +#define CUFFT_CALL( call ) \ + { \ + auto status = static_cast( call ); \ + if ( status != CUFFT_SUCCESS ) \ + fprintf( stderr, \ + "ERROR: CUFFT call \"%s\" in line %d of file %s failed " \ + "with " \ + "code (%d).\n", \ + #call, \ + __LINE__, \ + __FILE__, \ + status ); \ + } +#endif // CUFFT_CALL + +// template <> struct traits { +// // scalar type +// typedef float T; + +// using input_host_type = std::complex; +// using input_device_type = cufftComplex; + +// using output_host_type = std::complex; +// using output_device_type = cufftComplex; + +// static constexpr cufftType_t transformType = CUDA_R_64F; + +// template inline static T rand(RNG &gen) { +// return make_cuFloatComplex((S)gen(), (S)gen()); +// } +// }; \ No newline at end of file diff --git a/python/jittor/extern/cuda/cufft/inc/cufft_wrapper.h b/python/jittor/extern/cuda/cufft/inc/cufft_wrapper.h new file mode 100644 index 00000000..afe17689 --- /dev/null +++ b/python/jittor/extern/cuda/cufft/inc/cufft_wrapper.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include "cufft_utils.h" + +#include "utils/log.h" +#include "helper_cuda.h" +#include "fp16_emu.h" +#include "common.h" + +namespace jittor { + +EXTERN_LIB unordered_map cufft_handle_cache; + +} // jittor diff --git a/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc new file mode 100644 index 00000000..c8577daa --- /dev/null +++ b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.cc @@ -0,0 +1,101 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "var.h" +#include "init.h" +#include +#include +#include "helper_cuda.h" +#include "cufft_fft_op.h" +#include "cufft_wrapper.h" + +#include +#include +#include +#include +#include +#include "cufft_utils.h" +#include "ops/op_register.h" + + +namespace jittor { + +#ifndef JIT +static auto make_cufft_fft = get_op_info("cufft_fft") + .get_constructor(); +CufftFftOp::CufftFftOp(Var* x, bool inverse) : x(x), inverse(inverse) { + flags.set(NodeFlags::_cuda, 1); + y = create_output(x->shape, x->dtype()); +} + +VarPtr CufftFftOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return make_cufft_fft(dout, !inverse); +} + +void CufftFftOp::jit_prepare(JK& jk) { + if ((y->dtype() != "float32") && (y->dtype() != "float64")){ + printf("not supported fft dtype: %s\n", y->dtype().to_cstring()); + ASSERT(false); + } + jk << "«T:" << y->dtype(); + jk << "«I:" << inverse; + jk << "«TS:\"" << y->dtype()<<"\""; +} + +#else // JIT +#ifdef JIT_cpu +void CufftFftOp::jit_run() { +} +#else // JIT_cuda +void CufftFftOp::jit_run() { + auto* __restrict__ xp = x->mem_ptr; + auto* __restrict__ yp = y->mem_ptr; + + int batch_size = x->shape[0]; + int n1 = x->shape[1], n2 = x->shape[2]; + int fft_size = batch_size * n1 * n2; + std::array fft = {n1, n2}; + + auto op_type = CUFFT_C2C; + if (TS == "float32") { + op_type = CUFFT_C2C; + } else if (TS == "float64") { + op_type = CUFFT_Z2Z; + } + JK& jk = get_jk(); + jk.clear(); + jk << fft[0] << "," << fft[1] << "," << TS << "," << batch_size; + auto iter = cufft_handle_cache.find(jk.to_string()); + cufftHandle plan; + if (iter!=cufft_handle_cache.end()) plan = iter->second; + else { + CUFFT_CALL(cufftCreate(&plan)); + CUFFT_CALL(cufftPlanMany(&plan, 2, fft.data(), + nullptr, 1, fft[0] * fft[1], // *inembed, istride, idist + nullptr, 1, fft[0] * fft[1], // *onembed, ostride, odist + op_type, batch_size)); + CUFFT_CALL(cufftSetStream(plan, 0)); + cufft_handle_cache[jk.to_string()] = plan; + } + /* + * Note: + * Identical pointers to data and output arrays implies in-place transformation + */ + if (TS == "float32") { + CUFFT_CALL(cufftExecC2C(plan, (cufftComplex *)xp, (cufftComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD)); + } else if (TS == "float64") { + CUFFT_CALL(cufftExecZ2Z(plan, (cufftDoubleComplex *)xp, (cufftDoubleComplex *)yp, I ? CUFFT_INVERSE : CUFFT_FORWARD)); + } + +} +#endif // JIT_cpu +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.h b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.h new file mode 100644 index 00000000..61b242ac --- /dev/null +++ b/python/jittor/extern/cuda/cufft/ops/cufft_fft_op.h @@ -0,0 +1,27 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +//TODO: support FFT2D only now. +struct CufftFftOp : Op { + bool inverse; + Var* x, * y; + NanoString type; + CufftFftOp(Var* x, bool inverse=false); + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + + const char* name() const override { return "cufft_fft"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cufft/src/cufft_wrapper.cc b/python/jittor/extern/cuda/cufft/src/cufft_wrapper.cc new file mode 100644 index 00000000..003ba7eb --- /dev/null +++ b/python/jittor/extern/cuda/cufft/src/cufft_wrapper.cc @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "cufft_wrapper.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +unordered_map cufft_handle_cache; + +struct cufft_initer { + +inline cufft_initer() { + if (!get_device_count()) return; + LOGv << "cufftCreate finished"; +} + +inline ~cufft_initer() { + if (!get_device_count()) return; + for (auto it = cufft_handle_cache.begin(); it != cufft_handle_cache.end(); it++) { + CUFFT_CALL(cufftDestroy(it->second)); + } + cufft_handle_cache.clear(); + LOGv << "cufftDestroy finished"; +} + +} init; + +} // jittor diff --git a/python/jittor/extern/cuda/curand/inc/curand_wrapper.h b/python/jittor/extern/cuda/curand/inc/curand_wrapper.h new file mode 100644 index 00000000..7cff699c --- /dev/null +++ b/python/jittor/extern/cuda/curand/inc/curand_wrapper.h @@ -0,0 +1,22 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include + +#include "helper_cuda.h" +#include "fp16_emu.h" +#include "common.h" + +namespace jittor { + +EXTERN_LIB curandGenerator_t gen; + +} // jittor diff --git a/python/jittor/extern/cuda/curand/ops/curand_random_op.cc b/python/jittor/extern/cuda/curand/ops/curand_random_op.cc new file mode 100644 index 00000000..9edc6670 --- /dev/null +++ b/python/jittor/extern/cuda/curand/ops/curand_random_op.cc @@ -0,0 +1,54 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "init.h" +#include +#include +#include "helper_cuda.h" +#include "curand_random_op.h" +#include "curand_wrapper.h" + +namespace jittor { + +#ifndef JIT +CurandRandomOp::CurandRandomOp(NanoVector shape, NanoString dtype, NanoString type) { + flags.set(NodeFlags::_cuda, 1); + output = create_output(shape, dtype); + this->type = type; + ASSERT(type == ns_normal || type == ns_uniform); +} + +void CurandRandomOp::jit_prepare(JK& jk) { + jk << "«T:" << output->dtype(); + jk << "«R:" << type; +} + +#else // JIT +#ifdef JIT_cpu +void CurandRandomOp::jit_run() { +} +#else // JIT_cuda +void CurandRandomOp::jit_run() { + @define(TT,@if(@strcmp(@T,float32)==0,,Double)) + + auto* __restrict__ x = output->ptr(); + index_t num = output->num; + // curand doesn't support even number, we add 1 when it is even + // because allocator will make odd chunks, so this wouldn't cause + // segmentation fault + num += num&1; + @if(@strcmp(@R,uniform)==0, + checkCudaErrors(curandGenerateUniform@TT (gen, x, num));, + checkCudaErrors(curandGenerateNormal@TT (gen, x, num, 0, 1)); + ) +} +#endif // JIT_cpu +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/curand/ops/curand_random_op.h b/python/jittor/extern/cuda/curand/ops/curand_random_op.h new file mode 100644 index 00000000..320bd51e --- /dev/null +++ b/python/jittor/extern/cuda/curand/ops/curand_random_op.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CurandRandomOp : Op { + Var* output; + NanoString type; + CurandRandomOp(NanoVector shape, NanoString dtype=ns_float32, NanoString type=ns_uniform); + + const char* name() const override { return "curand_random"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/curand/src/curand_wrapper.cc b/python/jittor/extern/cuda/curand/src/curand_wrapper.cc new file mode 100644 index 00000000..c30089b9 --- /dev/null +++ b/python/jittor/extern/cuda/curand/src/curand_wrapper.cc @@ -0,0 +1,37 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "curand_wrapper.h" +#include "init.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +curandGenerator_t gen; + +struct curand_initer { + +inline curand_initer() { + if (!get_device_count()) return; + checkCudaErrors( curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT) ); + add_set_seed_callback([](int seed) { + checkCudaErrors( curandSetPseudoRandomGeneratorSeed(gen, seed) ); + }); + LOGv << "curandCreate finished"; +} + +inline ~curand_initer() { + if (!get_device_count()) return; + checkCudaErrors( curandDestroyGenerator(gen) ); + LOGv << "curandDestroy finished"; +} + +} init_; + +} // jittor diff --git a/python/jittor/extern/cuda/curand/src/helper_curand.cc b/python/jittor/extern/cuda/curand/src/helper_curand.cc new file mode 100644 index 00000000..679962a5 --- /dev/null +++ b/python/jittor/extern/cuda/curand/src/helper_curand.cc @@ -0,0 +1,62 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +// These are CUDA Helper functions for initialization and error checking + +#include +#include +#include "utils/log.h" +#include "helper_cuda.h" +#include + +// cuRAND API errors +const char *_cudaGetErrorEnum(curandStatus_t error) { + switch (error) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } + + return ""; +} \ No newline at end of file diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc b/python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc new file mode 100644 index 00000000..490d7e7c --- /dev/null +++ b/python/jittor/extern/cuda/cutt/ops/cutt_test_op.cc @@ -0,0 +1,42 @@ +// *************************************************************** +// Copyright (c) 2019 Dun Liang . All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cutt_test_op.h" +#include "utils/str_utils.h" + +#ifdef JIT +#include "cutt.h" +#endif + +namespace jittor { + +#ifndef JIT +CuttTestOp::CuttTestOp(string cmd) : cmd(cmd) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + output = create_output(1, ns_float32); +} + +void CuttTestOp::jit_prepare(JK& jk) { + jk << "«T:float32"; +} + +#else // JIT +#ifdef JIT_cuda + +void CuttTestOp::jit_run() { + auto args = split(cmd, " "); + if (!cmd.size()) args.clear(); + vector v(args.size()); + for (uint i=0; iptr()[0] = 123; + +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_test_op.h b/python/jittor/extern/cuda/cutt/ops/cutt_test_op.h new file mode 100644 index 00000000..3de115ad --- /dev/null +++ b/python/jittor/extern/cuda/cutt/ops/cutt_test_op.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2019 +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CuttTestOp : Op { + Var* output; + string cmd; + + CuttTestOp(string cmd); + + const char* name() const override { return "cutt_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.cc b/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.cc new file mode 100644 index 00000000..aa19dfbe --- /dev/null +++ b/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.cc @@ -0,0 +1,124 @@ +// *************************************************************** +// Copyright (c) 2019 Dun Liang . All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "cutt_transpose_op.h" +#include "ops/op_register.h" +#include "cutt.h" +#include "cutt_wrapper.h" +#include "misc/stack_vector.h" +#include "helper_cuda.h" + +namespace jittor { + +#ifndef JIT +static auto make_transpose = get_op_info("cutt_transpose") + .get_constructor(); + +CuttTransposeOp::CuttTransposeOp(Var* x, NanoVector axes) : x(x), axes(axes) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + int i=0; + for (; idtype()); + flags.set(NodeFlags::_manual_set_vnbb); +} + +void CuttTransposeOp::infer_shape() { + auto xdim = x->shape.size(); + CHECK(xdim); + if (!axes.size()) { + for (int i=0; i<(int)xdim; i++) + axes.push_back(xdim-1-i); + } else { + CHECKop(axes.size(),==,xdim); + int64_t mask=0; + for (auto i : axes) mask |= 1<shape[axes[i]]); + y->set_shape(shape); +} + +VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) { + NanoVector reverse; + reverse.reserve(axes.size(), axes.size()); + for (uint i=0; i cutt_plan_cache; + +#else // JIT + +EXTERN_LIB unordered_map cutt_plan_cache; + +void CuttTransposeOp::jit_run() { + // Return if x is empty + if (x->num == 0) + return; + + cudaGetLastError(); + auto* __restrict__ xp = x->mem_ptr; + auto* __restrict__ yp = y->mem_ptr; + StackVector x_shape; + StackVector new_shape, new_axes, trans, reverse; + int dim = x->shape.size(); + for (int i=0; ishape[i] != 1) + new_shape.push_back(x->shape[i]); + } + for (int i = 0; i < dim; ++i) { + if (x->shape[axes[i]] != 1) { + new_axes.push_back(trans[axes[i]]); + } + } + dim = new_shape.size(); + for (int i=0; inum==1) { + checkCudaErrors(cudaMemcpyAsync(yp, xp, x->size, cudaMemcpyDeviceToDevice, 0)); + return; + } + JK& jk = get_jk(); + jk.clear(); + jk << dim << ','; + for (int i=0; idtype().dsize() << '.'; + auto iter = cutt_plan_cache.find(jk.to_string()); + LOGvvv << "Run cutt_transpose with key:" << jk.to_string(); + + if (iter!=cutt_plan_cache.end()){ + cuttExecute(iter->second, xp, yp); + } else { + cuttHandle plan; + checkCudaErrors(cudaDeviceSynchronize()); + auto ret = cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0); + CHECK(0==ret) << ret << jk.to_string() << x << y; + cutt_plan_cache[jk.to_string()] = plan; + cuttExecute(plan, xp, yp); + } +} +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.h b/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.h new file mode 100644 index 00000000..95fdc656 --- /dev/null +++ b/python/jittor/extern/cuda/cutt/ops/cutt_transpose_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2019 +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CuttTransposeOp : Op { + Var* x, * y; + NanoVector axes; + CuttTransposeOp(Var* x, NanoVector axes=NanoVector()); + + const char* name() const override { return "cutt_transpose"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_wrapper.cc b/python/jittor/extern/cuda/cutt/ops/cutt_wrapper.cc new file mode 100644 index 00000000..5319d718 --- /dev/null +++ b/python/jittor/extern/cuda/cutt/ops/cutt_wrapper.cc @@ -0,0 +1,36 @@ +// *************************************************************** +// Copyright (c) 2019 +// Dun Liang +// Guowei Yang <471184555@qq.com> +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "cutt_wrapper.h" + + +namespace jittor { + +void jt_alloc(void** p, size_t len, size_t& allocation) { + *p = exe.allocator->alloc(len, allocation); +} + +void jt_free(void* p, size_t len, size_t& allocation) { + exe.allocator->free(p, len, allocation); +} + +struct cutt_initer { + +inline cutt_initer() { + custom_cuda_malloc = jt_alloc; + custom_cuda_free = jt_free; + LOGv << "cuttCreate finished"; +} + +inline ~cutt_initer() { + LOGv << "cuttDestroy finished"; +} + +} cutt_init; + +} // jittor diff --git a/python/jittor/extern/cuda/cutt/ops/cutt_wrapper.h b/python/jittor/extern/cuda/cutt/ops/cutt_wrapper.h new file mode 100644 index 00000000..99380cf2 --- /dev/null +++ b/python/jittor/extern/cuda/cutt/ops/cutt_wrapper.h @@ -0,0 +1,15 @@ +// *************************************************************** +// Copyright (c) 2019 +// Dun Liang +// Guowei Yang <471184555@qq.com> +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "executor.h" +#include "CudaUtils.h" + +void jt_alloc(void** p, size_t len, size_t& allocation); + +void jt_free(void* p, size_t len, size_t& allocation); \ No newline at end of file diff --git a/python/jittor/extern/cuda/inc/fp16_dev.h b/python/jittor/extern/cuda/inc/fp16_dev.h new file mode 100644 index 00000000..ca1d69f1 --- /dev/null +++ b/python/jittor/extern/cuda/inc/fp16_dev.h @@ -0,0 +1,21 @@ +/** +* Copyright 2014 NVIDIA Corporation. All rights reserved. +* +* Please refer to the NVIDIA end user license agreement (EULA) associated +* with this source code for terms and conditions that govern your use of +* this software. Any use, reproduction, disclosure, or distribution of +* this software and related documentation outside the terms of the EULA +* is strictly prohibited. +* +*/ + +#if !defined(_FP16_DEV_H_) +#define _FP16_DEV_H_ + +#include "fp16_emu.h" + +template +void gpu_float2half_rn(int size, const value_type *buffIn, half1 *buffOut); + +#endif // _FP16_DEV_H_ + diff --git a/python/jittor/extern/cuda/inc/fp16_emu.h b/python/jittor/extern/cuda/inc/fp16_emu.h new file mode 100644 index 00000000..a89bb419 --- /dev/null +++ b/python/jittor/extern/cuda/inc/fp16_emu.h @@ -0,0 +1,167 @@ +/* + * Copyright 1993-2014 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +// Conversion from/to 16-bit floating point (half-precision). + +#if !defined(_FP16_EMU_H_) +#define _FP16_EMU_H_ + +#include +#include + +// Necessary to ensure visibility of CUDART_VERSION macro +#include + +// Definition of '__half_raw' was not provided before CUDA 9.0. +// '__half_raw' is our type where the unsigned 16-bit integer +// data member 'x' can be accessed in both CUDA 9.0 and 8.0. +#if CUDART_VERSION < 9000 +typedef __half __half_raw; +#endif + +// Internally, in CUDNN we use half1 struct as the FP16 type. +typedef __half half1; + +#define HLF_EPSILON 4.887581E-04 +#define HLF_MIN 6.103516E-05 +#define HLF_MAX 6.550400E+04 + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +half1 cpu_float2half_rn(float f); + +float cpu_half2float(half1 h); + +static __inline__ __device__ __host__ half1 habs(half1 h) +{ + __half_raw hr = reinterpret_cast<__half_raw&>(h); + hr.x &= 0x7fffU; + return reinterpret_cast(hr); +} + +static __inline__ __device__ __host__ half1 hneg(half1 h) +{ + __half_raw hr = reinterpret_cast<__half_raw&>(h); + hr.x ^= 0x8000U; + return reinterpret_cast(hr); +} + +static __inline__ __device__ __host__ int ishnan(half1 h) +{ + // When input is NaN, exponent is all ones and mantissa is non-zero. + __half_raw hr = reinterpret_cast<__half_raw&>(h); + return (hr.x & 0x7c00U) == 0x7c00U && (hr.x & 0x03ffU) != 0; +} + +static __inline__ __device__ __host__ int ishinf(half1 h) +{ + // When input is +/- inf, exponent is all ones and mantissa is zero. + __half_raw hr = reinterpret_cast<__half_raw&>(h); + return (hr.x & 0x7c00U) == 0x7c00U && (hr.x & 0x03ffU) == 0; +} + +static __inline__ __device__ __host__ int ishequ(half1 x, half1 y) +{ + __half_raw xr = reinterpret_cast<__half_raw&>(x); + __half_raw yr = reinterpret_cast<__half_raw&>(y); + return ishnan(x) == 0 && ishnan(y) == 0 && xr.x == yr.x; +} + +// Returns 0.0000 in FP16 binary form +static __inline__ __device__ __host__ half1 hzero() +{ + __half_raw hr; + hr.x = 0x0000U; + return reinterpret_cast(hr); +} + +// Returns 1.0000 in FP16 binary form +static __inline__ __device__ __host__ half1 hone() +{ + __half_raw hr; + hr.x = 0x3c00U; + return reinterpret_cast(hr); +} + +// Returns quiet NaN, the most significant fraction bit #9 is set +static __inline__ __device__ __host__ half1 hnan() +{ + __half_raw hr; + hr.x = 0x7e00U; + return reinterpret_cast(hr); +} + +// Largest positive FP16 value, corresponds to 6.5504e+04 +static __inline__ __device__ __host__ half1 hmax() +{ + // Exponent all ones except LSB (0x1e), mantissa is all ones (0x3ff) + __half_raw hr; + hr.x = 0x7bffU; + return reinterpret_cast(hr); +} + +// Smallest positive (normalized) FP16 value, corresponds to 6.1035e-05 +static __inline__ __device__ __host__ half1 hmin() +{ + // Exponent is 0x01 (5 bits), mantissa is all zeros (10 bits) + __half_raw hr; + hr.x = 0x0400U; + return reinterpret_cast(hr); +} + + +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +#endif // _FP16_EMU_H_ + diff --git a/python/jittor/extern/cuda/inc/helper_cuda.h b/python/jittor/extern/cuda/inc/helper_cuda.h new file mode 100644 index 00000000..fc2e525f --- /dev/null +++ b/python/jittor/extern/cuda/inc/helper_cuda.h @@ -0,0 +1,444 @@ +/** + * Copyright 1993-2017 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +//////////////////////////////////////////////////////////////////////////////// +// These are CUDA Helper functions for initialization and error checking + +#ifndef COMMON_HELPER_CUDA_H_ +#define COMMON_HELPER_CUDA_H_ + +#pragma once + +#include "utils/log.h" + +#include +#include +#include +#include +#include + +#ifdef IS_CUDA +#include +#endif + +#ifndef EXIT_WAIVED +#define EXIT_WAIVED 2 +#endif + +// Note, it is required that your SDK sample to include the proper header +// files, please refer the CUDA examples for examples of the needed CUDA +// headers, which may change depending on which CUDA functions are used. + +// CUDA Runtime error messages +#ifdef __DRIVER_TYPES_H__ +inline const char *_cudaGetErrorEnum(cudaError_t error) { + return cudaGetErrorName(error); +} +#endif + +// CUDA Driver API errors +#ifdef CUDA_DRIVER_API +inline const char *_cudaGetErrorEnum(CUresult error) { + const char *ret = NULL; + cuGetErrorName(error, &ret); + return ret ? ret : ""; +} +#endif + +#ifdef CUBLAS_API_H_ +// cuBLAS API errors +const char *_cudaGetErrorEnum(cublasStatus_t error); +#endif + +#ifdef CUDNN_H_ +// cudnn API errors +const char *_cudaGetErrorEnum(cudnnStatus_t error); +#endif + +#ifdef _CUFFT_H_ +// cuFFT API errors +const char *_cudaGetErrorEnum(cufftResult error); +#endif + +#ifdef CUSPARSEAPI +// cuSPARSE API errors +const char *_cudaGetErrorEnum(cusparseStatus_t error); +#endif + +#ifdef CUSOLVER_COMMON_H_ +// cuSOLVER API errors +const char *_cudaGetErrorEnum(cusolverStatus_t error); +#endif + +#ifdef CURAND_H_ +// cuRAND API errors +const char *_cudaGetErrorEnum(curandStatus_t error); +#endif + +#ifdef NCCL_H_ +// cuRAND API errors +const char *_cudaGetErrorEnum(ncclResult_t error); +#endif + +#ifdef NV_NPPIDEFS_H +// NPP API errors +const char *_cudaGetErrorEnum(NppStatus error); +#endif + +#ifdef __DRIVER_TYPES_H__ +#ifndef DEVICE_RESET +#define DEVICE_RESET cudaDeviceReset(); +#endif +#else +#ifndef DEVICE_RESET +#define DEVICE_RESET +#endif +#endif + +namespace jittor { +EXTERN_LIB bool peek_logged; +} + +template +void peek(T result, char const *const func, const char *const file, + int const line) { + if (result) { + // DEVICE_RESET + if (jittor::peek_logged) return; + jittor::peek_logged = 1; + LOGe << "Peek CUDA error at" << file >> ":" >> line << " code=" + >> static_cast(result) >> "(" << _cudaGetErrorEnum(result) << ")" + << func; + } +} + +template +void check(T result, char const *const func, const char *const file, + int const line) { + if (result) { + // DEVICE_RESET + LOGf << "CUDA error at" << file >> ":" >> line << " code=" + >> static_cast(result) >> "(" << _cudaGetErrorEnum(result) << ")" + << func; + } +} + +#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) +#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__) + +#ifdef __DRIVER_TYPES_H__ +// This will output the proper CUDA error strings in the event +// that a CUDA host call returns an error +#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) +#define peekCudaErrors(val) peek((val), #val, __FILE__, __LINE__) + +// This will output the proper error string when calling cudaGetLastError +#define getLastCudaError(msg) __getLastCudaError(msg, __FILE__, __LINE__) + +inline void __getLastCudaError(const char *errorMessage, const char *file, + const int line) { + cudaError_t err = cudaGetLastError(); + + if (cudaSuccess != err) { + // DEVICE_RESET + LOGf << "CUDA error at" << file >> ":" >> line << " code=" + >> static_cast(err) >> "(" << _cudaGetErrorEnum(err) << ")" + << errorMessage; + } +} + +// This will only print the proper error string when calling cudaGetLastError +// but not exit program incase error detected. +#define printLastCudaError(msg) __printLastCudaError(msg, __FILE__, __LINE__) + +inline void __printLastCudaError(const char *errorMessage, const char *file, + const int line) { + cudaError_t err = cudaGetLastError(); + + if (cudaSuccess != err) { + // DEVICE_RESET + LOGf << "CUDA error at" << file >> ":" >> line << " code=" + >> static_cast(err) >> "(" << _cudaGetErrorEnum(err) << ")" + << errorMessage; + } +} +#endif + +#ifndef MAX +#define MAX(a, b) (a > b ? a : b) +#endif + +// Float To Int conversion +inline int ftoi(float value) { + return (value >= 0 ? static_cast(value + 0.5) + : static_cast(value - 0.5)); +} + +// Beginning of GPU Architecture definitions +inline int _ConvertSMVer2Cores(int major, int minor) { + // Defines for GPU Architecture types (using the SM version to determine + // the # of cores per SM + typedef struct { + int SM; // 0xMm (hexidecimal notation), M = SM Major version, + // and m = SM minor version + int Cores; + } sSMtoCores; + + sSMtoCores nGpuArchCoresPerSM[] = { + {0x30, 192}, + {0x32, 192}, + {0x35, 192}, + {0x37, 192}, + {0x50, 128}, + {0x52, 128}, + {0x53, 128}, + {0x60, 64}, + {0x61, 128}, + {0x62, 128}, + {0x70, 64}, + {0x72, 64}, + {0x75, 64}, + {-1, -1}}; + + int index = 0; + + while (nGpuArchCoresPerSM[index].SM != -1) { + if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { + return nGpuArchCoresPerSM[index].Cores; + } + + index++; + } + + // If we don't find the values, we default use the previous one + // to run properly + printf( + "MapSMtoCores for SM %d.%d is undefined." + " Default to use %d Cores/SM\n", + major, minor, nGpuArchCoresPerSM[index - 1].Cores); + return nGpuArchCoresPerSM[index - 1].Cores; +} + // end of GPU Architecture definitions + +#ifdef __CUDA_RUNTIME_H__ +// General GPU Device CUDA Initialization +inline int gpuDeviceInit(int devID) { + int device_count; + checkCudaErrors(cudaGetDeviceCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, + "gpuDeviceInit() CUDA error: " + "no devices supporting CUDA.\n"); + exit(EXIT_FAILURE); + } + + if (devID < 0) { + devID = 0; + } + + if (devID > device_count - 1) { + fprintf(stderr, "\n"); + fprintf(stderr, ">> %d CUDA capable GPU device(s) detected. <<\n", + device_count); + fprintf(stderr, + ">> gpuDeviceInit (-device=%d) is not a valid" + " GPU device. <<\n", + devID); + fprintf(stderr, "\n"); + return -devID; + } + + cudaDeviceProp deviceProp; + checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID)); + + if (deviceProp.computeMode == cudaComputeModeProhibited) { + fprintf(stderr, + "Error: device is running in , no threads can use cudaSetDevice().\n"); + return -1; + } + + if (deviceProp.major < 1) { + fprintf(stderr, "gpuDeviceInit(): GPU device does not support CUDA.\n"); + exit(EXIT_FAILURE); + } + + checkCudaErrors(cudaSetDevice(devID)); + printf("gpuDeviceInit() CUDA Device [%d]: \"%s\n", devID, deviceProp.name); + + return devID; +} + +// This function returns the best GPU (with maximum GFLOPS) +inline int gpuGetMaxGflopsDeviceId() { + int current_device = 0, sm_per_multiproc = 0; + int max_perf_device = 0; + int device_count = 0; + int devices_prohibited = 0; + + uint64_t max_compute_perf = 0; + cudaDeviceProp deviceProp; + checkCudaErrors(cudaGetDeviceCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, + "gpuGetMaxGflopsDeviceId() CUDA error:" + " no devices supporting CUDA.\n"); + exit(EXIT_FAILURE); + } + + // Find the best CUDA capable GPU device + current_device = 0; + + while (current_device < device_count) { + cudaGetDeviceProperties(&deviceProp, current_device); + + // If this GPU is not running on Compute Mode prohibited, + // then we can add it to the list + if (deviceProp.computeMode != cudaComputeModeProhibited) { + if (deviceProp.major == 9999 && deviceProp.minor == 9999) { + sm_per_multiproc = 1; + } else { + sm_per_multiproc = + _ConvertSMVer2Cores(deviceProp.major, deviceProp.minor); + } + + uint64_t compute_perf = (uint64_t)deviceProp.multiProcessorCount * + sm_per_multiproc * deviceProp.clockRate; + + if (compute_perf > max_compute_perf) { + max_compute_perf = compute_perf; + max_perf_device = current_device; + } + } else { + devices_prohibited++; + } + + ++current_device; + } + + if (devices_prohibited == device_count) { + fprintf(stderr, + "gpuGetMaxGflopsDeviceId() CUDA error:" + " all devices have compute mode prohibited.\n"); + exit(EXIT_FAILURE); + } + + return max_perf_device; +} + +// Initialization code to find the best CUDA Device +inline int findCudaDevice(int argc, const char **argv) { + cudaDeviceProp deviceProp; + int devID = 0; + + // If the command-line has a device number specified, use it + if (checkCmdLineFlag(argc, argv, "device")) { + devID = getCmdLineArgumentInt(argc, argv, "device="); + + if (devID < 0) { + printf("Invalid command line parameter\n "); + exit(EXIT_FAILURE); + } else { + devID = gpuDeviceInit(devID); + + if (devID < 0) { + printf("exiting...\n"); + exit(EXIT_FAILURE); + } + } + } else { + // Otherwise pick the device with highest Gflops/s + devID = gpuGetMaxGflopsDeviceId(); + checkCudaErrors(cudaSetDevice(devID)); + checkCudaErrors(cudaGetDeviceProperties(&deviceProp, devID)); + printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", devID, + deviceProp.name, deviceProp.major, deviceProp.minor); + } + + return devID; +} + +inline int findIntegratedGPU() { + int current_device = 0; + int device_count = 0; + int devices_prohibited = 0; + + cudaDeviceProp deviceProp; + checkCudaErrors(cudaGetDeviceCount(&device_count)); + + if (device_count == 0) { + fprintf(stderr, "CUDA error: no devices supporting CUDA.\n"); + exit(EXIT_FAILURE); + } + + // Find the integrated GPU which is compute capable + while (current_device < device_count) { + cudaGetDeviceProperties(&deviceProp, current_device); + + // If GPU is integrated and is not running on Compute Mode prohibited, + // then cuda can map to GLES resource + if (deviceProp.integrated && + (deviceProp.computeMode != cudaComputeModeProhibited)) { + checkCudaErrors(cudaSetDevice(current_device)); + checkCudaErrors(cudaGetDeviceProperties(&deviceProp, current_device)); + printf("GPU Device %d: \"%s\" with compute capability %d.%d\n\n", + current_device, deviceProp.name, deviceProp.major, + deviceProp.minor); + + return current_device; + } else { + devices_prohibited++; + } + + current_device++; + } + + if (devices_prohibited == device_count) { + fprintf(stderr, + "CUDA error:" + " No GLES-CUDA Interop capable GPU found.\n"); + exit(EXIT_FAILURE); + } + + return -1; +} + +// General check for CUDA GPU SM Capabilities +inline bool checkCudaCapabilities(int major_version, int minor_version) { + cudaDeviceProp deviceProp; + deviceProp.major = 0; + deviceProp.minor = 0; + int dev; + + checkCudaErrors(cudaGetDevice(&dev)); + checkCudaErrors(cudaGetDeviceProperties(&deviceProp, dev)); + + if ((deviceProp.major > major_version) || + (deviceProp.major == major_version && + deviceProp.minor >= minor_version)) { + printf(" Device %d: <%16s >, Compute SM %d.%d detected\n", dev, + deviceProp.name, deviceProp.major, deviceProp.minor); + return true; + } else { + printf( + " No GPU device was found that can support " + "CUDA compute capability %d.%d.\n", + major_version, minor_version); + return false; + } +} +#endif + + // end of CUDA Helper Functions + +#endif // COMMON_HELPER_CUDA_H_ diff --git a/python/jittor/extern/cuda/inc/helper_functions.h b/python/jittor/extern/cuda/inc/helper_functions.h new file mode 100644 index 00000000..f157ab56 --- /dev/null +++ b/python/jittor/extern/cuda/inc/helper_functions.h @@ -0,0 +1,42 @@ +/** + * Copyright 1993-2013 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +// These are helper functions for the SDK samples (string parsing, +// timers, image helpers, etc) +#ifndef COMMON_HELPER_FUNCTIONS_H_ +#define COMMON_HELPER_FUNCTIONS_H_ + +#ifdef WIN32 +#pragma warning(disable : 4996) +#endif + +// includes, project +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// includes, timer, string parsing, image helpers +#include // helper functions for image compare, dump, data comparisons +#include // helper functions for string parsing +#include // helper functions for timers + +#ifndef EXIT_WAIVED +#define EXIT_WAIVED 2 +#endif + +#endif // COMMON_HELPER_FUNCTIONS_H_ diff --git a/python/jittor/extern/cuda/inc/helper_image.h b/python/jittor/extern/cuda/inc/helper_image.h new file mode 100644 index 00000000..190402dc --- /dev/null +++ b/python/jittor/extern/cuda/inc/helper_image.h @@ -0,0 +1,984 @@ +/** + * Copyright 1993-2013 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +// These are helper functions for the SDK samples (image,bitmap) +#ifndef COMMON_HELPER_IMAGE_H_ +#define COMMON_HELPER_IMAGE_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#ifndef MIN +#define MIN(a, b) ((a < b) ? a : b) +#endif +#ifndef MAX +#define MAX(a, b) ((a > b) ? a : b) +#endif + +#ifndef EXIT_WAIVED +#define EXIT_WAIVED 2 +#endif + +#include + +// namespace unnamed (internal) +namespace helper_image_internal { +//! size of PGM file header +const unsigned int PGMHeaderSize = 0x40; + +// types + +//! Data converter from unsigned char / unsigned byte to type T +template +struct ConverterFromUByte; + +//! Data converter from unsigned char / unsigned byte +template <> +struct ConverterFromUByte { + //! Conversion operator + //! @return converted value + //! @param val value to convert + float operator()(const unsigned char &val) { + return static_cast(val); + } +}; + +//! Data converter from unsigned char / unsigned byte to float +template <> +struct ConverterFromUByte { + //! Conversion operator + //! @return converted value + //! @param val value to convert + float operator()(const unsigned char &val) { + return static_cast(val) / 255.0f; + } +}; + +//! Data converter from unsigned char / unsigned byte to type T +template +struct ConverterToUByte; + +//! Data converter from unsigned char / unsigned byte to unsigned int +template <> +struct ConverterToUByte { + //! Conversion operator (essentially a passthru + //! @return converted value + //! @param val value to convert + unsigned char operator()(const unsigned char &val) { return val; } +}; + +//! Data converter from unsigned char / unsigned byte to unsigned int +template <> +struct ConverterToUByte { + //! Conversion operator + //! @return converted value + //! @param val value to convert + unsigned char operator()(const float &val) { + return static_cast(val * 255.0f); + } +}; +} // namespace helper_image_internal + +#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) +#ifndef FOPEN +#define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode) +#endif +#ifndef FOPEN_FAIL +#define FOPEN_FAIL(result) (result != 0) +#endif +#ifndef SSCANF +#define SSCANF sscanf_s +#endif +#else +#ifndef FOPEN +#define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode)) +#endif +#ifndef FOPEN_FAIL +#define FOPEN_FAIL(result) (result == NULL) +#endif +#ifndef SSCANF +#define SSCANF sscanf +#endif +#endif + +inline bool __loadPPM(const char *file, unsigned char **data, unsigned int *w, + unsigned int *h, unsigned int *channels) { + FILE *fp = NULL; + + if (FOPEN_FAIL(FOPEN(fp, file, "rb"))) { + std::cerr << "__LoadPPM() : Failed to open file: " << file << std::endl; + return false; + } + + // check header + char header[helper_image_internal::PGMHeaderSize]; + + if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) { + std::cerr << "__LoadPPM() : reading PGM header returned NULL" << std::endl; + return false; + } + + if (strncmp(header, "P5", 2) == 0) { + *channels = 1; + } else if (strncmp(header, "P6", 2) == 0) { + *channels = 3; + } else { + std::cerr << "__LoadPPM() : File is not a PPM or PGM image" << std::endl; + *channels = 0; + return false; + } + + // parse header, read maxval, width and height + unsigned int width = 0; + unsigned int height = 0; + unsigned int maxval = 0; + unsigned int i = 0; + + while (i < 3) { + if (fgets(header, helper_image_internal::PGMHeaderSize, fp) == NULL) { + std::cerr << "__LoadPPM() : reading PGM header returned NULL" + << std::endl; + return false; + } + + if (header[0] == '#') { + continue; + } + + if (i == 0) { + i += SSCANF(header, "%u %u %u", &width, &height, &maxval); + } else if (i == 1) { + i += SSCANF(header, "%u %u", &height, &maxval); + } else if (i == 2) { + i += SSCANF(header, "%u", &maxval); + } + } + + // check if given handle for the data is initialized + if (NULL != *data) { + if (*w != width || *h != height) { + std::cerr << "__LoadPPM() : Invalid image dimensions." << std::endl; + } + } else { + *data = (unsigned char *)malloc(sizeof(unsigned char) * width * height * + *channels); + *w = width; + *h = height; + } + + // read and close file + if (fread(*data, sizeof(unsigned char), width * height * *channels, fp) == + 0) { + std::cerr << "__LoadPPM() read data returned error." << std::endl; + } + + fclose(fp); + + return true; +} + +template +inline bool sdkLoadPGM(const char *file, T **data, unsigned int *w, + unsigned int *h) { + unsigned char *idata = NULL; + unsigned int channels; + + if (true != __loadPPM(file, &idata, w, h, &channels)) { + return false; + } + + unsigned int size = *w * *h * channels; + + // initialize mem if necessary + // the correct size is checked / set in loadPGMc() + if (NULL == *data) { + *data = reinterpret_cast(malloc(sizeof(T) * size)); + } + + // copy and cast data + std::transform(idata, idata + size, *data, + helper_image_internal::ConverterFromUByte()); + + free(idata); + + return true; +} + +template +inline bool sdkLoadPPM4(const char *file, T **data, unsigned int *w, + unsigned int *h) { + unsigned char *idata = 0; + unsigned int channels; + + if (__loadPPM(file, &idata, w, h, &channels)) { + // pad 4th component + int size = *w * *h; + // keep the original pointer + unsigned char *idata_orig = idata; + *data = reinterpret_cast(malloc(sizeof(T) * size * 4)); + unsigned char *ptr = *data; + + for (int i = 0; i < size; i++) { + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = 0; + } + + free(idata_orig); + return true; + } else { + free(idata); + return false; + } +} + +inline bool __savePPM(const char *file, unsigned char *data, unsigned int w, + unsigned int h, unsigned int channels) { + assert(NULL != data); + assert(w > 0); + assert(h > 0); + + std::fstream fh(file, std::fstream::out | std::fstream::binary); + + if (fh.bad()) { + std::cerr << "__savePPM() : Opening file failed." << std::endl; + return false; + } + + if (channels == 1) { + fh << "P5\n"; + } else if (channels == 3) { + fh << "P6\n"; + } else { + std::cerr << "__savePPM() : Invalid number of channels." << std::endl; + return false; + } + + fh << w << "\n" << h << "\n" << 0xff << std::endl; + + for (unsigned int i = 0; (i < (w * h * channels)) && fh.good(); ++i) { + fh << data[i]; + } + + fh.flush(); + + if (fh.bad()) { + std::cerr << "__savePPM() : Writing data failed." << std::endl; + return false; + } + + fh.close(); + + return true; +} + +template +inline bool sdkSavePGM(const char *file, T *data, unsigned int w, + unsigned int h) { + unsigned int size = w * h; + unsigned char *idata = (unsigned char *)malloc(sizeof(unsigned char) * size); + + std::transform(data, data + size, idata, + helper_image_internal::ConverterToUByte()); + + // write file + bool result = __savePPM(file, idata, w, h, 1); + + // cleanup + free(idata); + + return result; +} + +inline bool sdkSavePPM4ub(const char *file, unsigned char *data, unsigned int w, + unsigned int h) { + // strip 4th component + int size = w * h; + unsigned char *ndata = + (unsigned char *)malloc(sizeof(unsigned char) * size * 3); + unsigned char *ptr = ndata; + + for (int i = 0; i < size; i++) { + *ptr++ = *data++; + *ptr++ = *data++; + *ptr++ = *data++; + data++; + } + + bool result = __savePPM(file, ndata, w, h, 3); + free(ndata); + return result; +} + +////////////////////////////////////////////////////////////////////////////// +//! Read file \filename and return the data +//! @return bool if reading the file succeeded, otherwise false +//! @param filename name of the source file +//! @param data uninitialized pointer, returned initialized and pointing to +//! the data read +//! @param len number of data elements in data, -1 on error +////////////////////////////////////////////////////////////////////////////// +template +inline bool sdkReadFile(const char *filename, T **data, unsigned int *len, + bool verbose) { + // check input arguments + assert(NULL != filename); + assert(NULL != len); + + // intermediate storage for the data read + std::vector data_read; + + // open file for reading + FILE *fh = NULL; + + // check if filestream is valid + if (FOPEN_FAIL(FOPEN(fh, filename, "r"))) { + printf("Unable to open input file: %s\n", filename); + return false; + } + + // read all data elements + T token; + + while (!feof(fh)) { + fscanf(fh, "%f", &token); + data_read.push_back(token); + } + + // the last element is read twice + data_read.pop_back(); + fclose(fh); + + // check if the given handle is already initialized + if (NULL != *data) { + if (*len != data_read.size()) { + std::cerr << "sdkReadFile() : Initialized memory given but " + << "size mismatch with signal read " + << "(data read / data init = " << (unsigned int)data_read.size() + << " / " << *len << ")" << std::endl; + + return false; + } + } else { + // allocate storage for the data read + *data = reinterpret_cast(malloc(sizeof(T) * data_read.size())); + // store signal size + *len = static_cast(data_read.size()); + } + + // copy data + memcpy(*data, &data_read.front(), sizeof(T) * data_read.size()); + + return true; +} + +////////////////////////////////////////////////////////////////////////////// +//! Read file \filename and return the data +//! @return bool if reading the file succeeded, otherwise false +//! @param filename name of the source file +//! @param data uninitialized pointer, returned initialized and pointing to +//! the data read +//! @param len number of data elements in data, -1 on error +////////////////////////////////////////////////////////////////////////////// +template +inline bool sdkReadFileBlocks(const char *filename, T **data, unsigned int *len, + unsigned int block_num, unsigned int block_size, + bool verbose) { + // check input arguments + assert(NULL != filename); + assert(NULL != len); + + // open file for reading + FILE *fh = fopen(filename, "rb"); + + if (fh == NULL && verbose) { + std::cerr << "sdkReadFile() : Opening file failed." << std::endl; + return false; + } + + // check if the given handle is already initialized + // allocate storage for the data read + data[block_num] = reinterpret_cast(malloc(block_size)); + + // read all data elements + fseek(fh, block_num * block_size, SEEK_SET); + *len = fread(data[block_num], sizeof(T), block_size / sizeof(T), fh); + + fclose(fh); + + return true; +} + +////////////////////////////////////////////////////////////////////////////// +//! Write a data file \filename +//! @return true if writing the file succeeded, otherwise false +//! @param filename name of the source file +//! @param data data to write +//! @param len number of data elements in data, -1 on error +//! @param epsilon epsilon for comparison +////////////////////////////////////////////////////////////////////////////// +template +inline bool sdkWriteFile(const char *filename, const T *data, unsigned int len, + const S epsilon, bool verbose, bool append = false) { + assert(NULL != filename); + assert(NULL != data); + + // open file for writing + // if (append) { + std::fstream fh(filename, std::fstream::out | std::fstream::ate); + + if (verbose) { + std::cerr << "sdkWriteFile() : Open file " << filename + << " for write/append." << std::endl; + } + + /* } else { + std::fstream fh(filename, std::fstream::out); + if (verbose) { + std::cerr << "sdkWriteFile() : Open file " << filename << " for + write." << std::endl; + } + } + */ + + // check if filestream is valid + if (!fh.good()) { + if (verbose) { + std::cerr << "sdkWriteFile() : Opening file failed." << std::endl; + } + + return false; + } + + // first write epsilon + fh << "# " << epsilon << "\n"; + + // write data + for (unsigned int i = 0; (i < len) && (fh.good()); ++i) { + fh << data[i] << ' '; + } + + // Check if writing succeeded + if (!fh.good()) { + if (verbose) { + std::cerr << "sdkWriteFile() : Writing file failed." << std::endl; + } + + return false; + } + + // file ends with nl + fh << std::endl; + + return true; +} + +////////////////////////////////////////////////////////////////////////////// +//! Compare two arrays of arbitrary type +//! @return true if \a reference and \a data are identical, otherwise false +//! @param reference timer_interface to the reference data / gold image +//! @param data handle to the computed data +//! @param len number of elements in reference and data +//! @param epsilon epsilon to use for the comparison +////////////////////////////////////////////////////////////////////////////// +template +inline bool compareData(const T *reference, const T *data, + const unsigned int len, const S epsilon, + const float threshold) { + assert(epsilon >= 0); + + bool result = true; + unsigned int error_count = 0; + + for (unsigned int i = 0; i < len; ++i) { + float diff = static_cast(reference[i]) - static_cast(data[i]); + bool comp = (diff <= epsilon) && (diff >= -epsilon); + result &= comp; + + error_count += !comp; + +#if 0 + + if (!comp) { + std::cerr << "ERROR, i = " << i << ",\t " + << reference[i] << " / " + << data[i] + << " (reference / data)\n"; + } + +#endif + } + + if (threshold == 0.0f) { + return (result) ? true : false; + } else { + if (error_count) { + printf("%4.2f(%%) of bytes mismatched (count=%d)\n", + static_cast(error_count) * 100 / static_cast(len), + error_count); + } + + return (len * threshold > error_count) ? true : false; + } +} + +#ifndef __MIN_EPSILON_ERROR +#define __MIN_EPSILON_ERROR 1e-3f +#endif + +////////////////////////////////////////////////////////////////////////////// +//! Compare two arrays of arbitrary type +//! @return true if \a reference and \a data are identical, otherwise false +//! @param reference handle to the reference data / gold image +//! @param data handle to the computed data +//! @param len number of elements in reference and data +//! @param epsilon epsilon to use for the comparison +//! @param epsilon threshold % of (# of bytes) for pass/fail +////////////////////////////////////////////////////////////////////////////// +template +inline bool compareDataAsFloatThreshold(const T *reference, const T *data, + const unsigned int len, const S epsilon, + const float threshold) { + assert(epsilon >= 0); + + // If we set epsilon to be 0, let's set a minimum threshold + float max_error = MAX((float)epsilon, __MIN_EPSILON_ERROR); + int error_count = 0; + bool result = true; + + for (unsigned int i = 0; i < len; ++i) { + float diff = + fabs(static_cast(reference[i]) - static_cast(data[i])); + bool comp = (diff < max_error); + result &= comp; + + if (!comp) { + error_count++; + } + } + + if (threshold == 0.0f) { + if (error_count) { + printf("total # of errors = %d\n", error_count); + } + + return (error_count == 0) ? true : false; + } else { + if (error_count) { + printf("%4.2f(%%) of bytes mismatched (count=%d)\n", + static_cast(error_count) * 100 / static_cast(len), + error_count); + } + + return ((len * threshold > error_count) ? true : false); + } +} + +inline void sdkDumpBin(void *data, unsigned int bytes, const char *filename) { + printf("sdkDumpBin: <%s>\n", filename); + FILE *fp; + FOPEN(fp, filename, "wb"); + fwrite(data, bytes, 1, fp); + fflush(fp); + fclose(fp); +} + +inline bool sdkCompareBin2BinUint(const char *src_file, const char *ref_file, + unsigned int nelements, const float epsilon, + const float threshold, char *exec_path) { + unsigned int *src_buffer, *ref_buffer; + FILE *src_fp = NULL, *ref_fp = NULL; + + uint64_t error_count = 0; + size_t fsize = 0; + + if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) { + printf("compareBin2Bin unable to open src_file: %s\n", + src_file); + error_count++; + } + + char *ref_file_path = sdkFindFilePath(ref_file, exec_path); + + if (ref_file_path == NULL) { + printf("compareBin2Bin unable to find <%s> in <%s>\n", + ref_file, exec_path); + printf(">>> Check info.xml and [project//data] folder <%s> <<<\n", + ref_file); + printf("Aborting comparison!\n"); + printf(" FAILED\n"); + error_count++; + + if (src_fp) { + fclose(src_fp); + } + + if (ref_fp) { + fclose(ref_fp); + } + } else { + if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) { + printf( + "compareBin2Bin " + " unable to open ref_file: %s\n", + ref_file_path); + error_count++; + } + + if (src_fp && ref_fp) { + src_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int)); + ref_buffer = (unsigned int *)malloc(nelements * sizeof(unsigned int)); + + fsize = fread(src_buffer, nelements, sizeof(unsigned int), src_fp); + fsize = fread(ref_buffer, nelements, sizeof(unsigned int), ref_fp); + + printf( + "> compareBin2Bin nelements=%d," + " epsilon=%4.2f, threshold=%4.2f\n", + nelements, epsilon, threshold); + printf(" src_file <%s>, size=%d bytes\n", src_file, + static_cast(fsize)); + printf(" ref_file <%s>, size=%d bytes\n", ref_file_path, + static_cast(fsize)); + + if (!compareData(ref_buffer, src_buffer, nelements, + epsilon, threshold)) { + error_count++; + } + + fclose(src_fp); + fclose(ref_fp); + + free(src_buffer); + free(ref_buffer); + } else { + if (src_fp) { + fclose(src_fp); + } + + if (ref_fp) { + fclose(ref_fp); + } + } + } + + if (error_count == 0) { + printf(" OK\n"); + } else { + printf(" FAILURE: %d errors...\n", (unsigned int)error_count); + } + + return (error_count == 0); // returns true if all pixels pass +} + +inline bool sdkCompareBin2BinFloat(const char *src_file, const char *ref_file, + unsigned int nelements, const float epsilon, + const float threshold, char *exec_path) { + float *src_buffer = NULL, *ref_buffer = NULL; + FILE *src_fp = NULL, *ref_fp = NULL; + size_t fsize = 0; + + uint64_t error_count = 0; + + if (FOPEN_FAIL(FOPEN(src_fp, src_file, "rb"))) { + printf("compareBin2Bin unable to open src_file: %s\n", src_file); + error_count = 1; + } + + char *ref_file_path = sdkFindFilePath(ref_file, exec_path); + + if (ref_file_path == NULL) { + printf("compareBin2Bin unable to find <%s> in <%s>\n", ref_file, + exec_path); + printf(">>> Check info.xml and [project//data] folder <%s> <<<\n", + exec_path); + printf("Aborting comparison!\n"); + printf(" FAILED\n"); + error_count++; + + if (src_fp) { + fclose(src_fp); + } + + if (ref_fp) { + fclose(ref_fp); + } + } else { + if (FOPEN_FAIL(FOPEN(ref_fp, ref_file_path, "rb"))) { + printf("compareBin2Bin unable to open ref_file: %s\n", + ref_file_path); + error_count = 1; + } + + if (src_fp && ref_fp) { + src_buffer = reinterpret_cast(malloc(nelements * sizeof(float))); + ref_buffer = reinterpret_cast(malloc(nelements * sizeof(float))); + + printf( + "> compareBin2Bin nelements=%d, epsilon=%4.2f," + " threshold=%4.2f\n", + nelements, epsilon, threshold); + fsize = fread(src_buffer, sizeof(float), nelements, src_fp); + printf(" src_file <%s>, size=%d bytes\n", src_file, + static_cast(fsize * sizeof(float))); + fsize = fread(ref_buffer, sizeof(float), nelements, ref_fp); + printf(" ref_file <%s>, size=%d bytes\n", ref_file_path, + static_cast(fsize * sizeof(float))); + + if (!compareDataAsFloatThreshold( + ref_buffer, src_buffer, nelements, epsilon, threshold)) { + error_count++; + } + + fclose(src_fp); + fclose(ref_fp); + + free(src_buffer); + free(ref_buffer); + } else { + if (src_fp) { + fclose(src_fp); + } + + if (ref_fp) { + fclose(ref_fp); + } + } + } + + if (error_count == 0) { + printf(" OK\n"); + } else { + printf(" FAILURE: %d errors...\n", (unsigned int)error_count); + } + + return (error_count == 0); // returns true if all pixels pass +} + +inline bool sdkCompareL2fe(const float *reference, const float *data, + const unsigned int len, const float epsilon) { + assert(epsilon >= 0); + + float error = 0; + float ref = 0; + + for (unsigned int i = 0; i < len; ++i) { + float diff = reference[i] - data[i]; + error += diff * diff; + ref += reference[i] * reference[i]; + } + + float normRef = sqrtf(ref); + + if (fabs(ref) < 1e-7) { +#ifdef _DEBUG + std::cerr << "ERROR, reference l2-norm is 0\n"; +#endif + return false; + } + + float normError = sqrtf(error); + error = normError / normRef; + bool result = error < epsilon; +#ifdef _DEBUG + + if (!result) { + std::cerr << "ERROR, l2-norm error " << error << " is greater than epsilon " + << epsilon << "\n"; + } + +#endif + + return result; +} + +inline bool sdkLoadPPMub(const char *file, unsigned char **data, + unsigned int *w, unsigned int *h) { + unsigned int channels; + return __loadPPM(file, data, w, h, &channels); +} + +inline bool sdkLoadPPM4ub(const char *file, unsigned char **data, + unsigned int *w, unsigned int *h) { + unsigned char *idata = 0; + unsigned int channels; + + if (__loadPPM(file, &idata, w, h, &channels)) { + // pad 4th component + int size = *w * *h; + // keep the original pointer + unsigned char *idata_orig = idata; + *data = (unsigned char *)malloc(sizeof(unsigned char) * size * 4); + unsigned char *ptr = *data; + + for (int i = 0; i < size; i++) { + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = *idata++; + *ptr++ = 0; + } + + free(idata_orig); + return true; + } else { + free(idata); + return false; + } +} + +inline bool sdkComparePPM(const char *src_file, const char *ref_file, + const float epsilon, const float threshold, + bool verboseErrors) { + unsigned char *src_data, *ref_data; + uint64_t error_count = 0; + unsigned int ref_width, ref_height; + unsigned int src_width, src_height; + + if (src_file == NULL || ref_file == NULL) { + if (verboseErrors) { + std::cerr << "PPMvsPPM: src_file or ref_file is NULL." + " Aborting comparison\n"; + } + + return false; + } + + if (verboseErrors) { + std::cerr << "> Compare (a)rendered: <" << src_file << ">\n"; + std::cerr << "> (b)reference: <" << ref_file << ">\n"; + } + + if (sdkLoadPPM4ub(ref_file, &ref_data, &ref_width, &ref_height) != true) { + if (verboseErrors) { + std::cerr << "PPMvsPPM: unable to load ref image file: " << ref_file + << "\n"; + } + + return false; + } + + if (sdkLoadPPM4ub(src_file, &src_data, &src_width, &src_height) != true) { + std::cerr << "PPMvsPPM: unable to load src image file: " << src_file + << "\n"; + return false; + } + + if (src_height != ref_height || src_width != ref_width) { + if (verboseErrors) { + std::cerr << "PPMvsPPM: source and ref size mismatch (" << src_width + << "," << src_height << ")vs(" << ref_width << "," << ref_height + << ")\n"; + } + } + + if (verboseErrors) { + std::cerr << "PPMvsPPM: comparing images size (" << src_width << "," + << src_height << ") epsilon(" << epsilon << "), threshold(" + << threshold * 100 << "%)\n"; + } + + if (compareData(ref_data, src_data, src_width * src_height * 4, epsilon, + threshold) == false) { + error_count = 1; + } + + if (error_count == 0) { + if (verboseErrors) { + std::cerr << " OK\n\n"; + } + } else { + if (verboseErrors) { + std::cerr << " FAILURE! " << error_count << " errors...\n\n"; + } + } + + // returns true if all pixels pass + return (error_count == 0) ? true : false; +} + +inline bool sdkComparePGM(const char *src_file, const char *ref_file, + const float epsilon, const float threshold, + bool verboseErrors) { + unsigned char *src_data = 0, *ref_data = 0; + uint64_t error_count = 0; + unsigned int ref_width, ref_height; + unsigned int src_width, src_height; + + if (src_file == NULL || ref_file == NULL) { + if (verboseErrors) { + std::cerr << "PGMvsPGM: src_file or ref_file is NULL." + " Aborting comparison\n"; + } + + return false; + } + + if (verboseErrors) { + std::cerr << "> Compare (a)rendered: <" << src_file << ">\n"; + std::cerr << "> (b)reference: <" << ref_file << ">\n"; + } + + if (sdkLoadPPMub(ref_file, &ref_data, &ref_width, &ref_height) != true) { + if (verboseErrors) { + std::cerr << "PGMvsPGM: unable to load ref image file: " << ref_file + << "\n"; + } + + return false; + } + + if (sdkLoadPPMub(src_file, &src_data, &src_width, &src_height) != true) { + std::cerr << "PGMvsPGM: unable to load src image file: " << src_file + << "\n"; + return false; + } + + if (src_height != ref_height || src_width != ref_width) { + if (verboseErrors) { + std::cerr << "PGMvsPGM: source and ref size mismatch (" << src_width + << "," << src_height << ")vs(" << ref_width << "," << ref_height + << ")\n"; + } + } + + if (verboseErrors) + std::cerr << "PGMvsPGM: comparing images size (" << src_width << "," + << src_height << ") epsilon(" << epsilon << "), threshold(" + << threshold * 100 << "%)\n"; + + if (compareData(ref_data, src_data, src_width * src_height, epsilon, + threshold) == false) { + error_count = 1; + } + + if (error_count == 0) { + if (verboseErrors) { + std::cerr << " OK\n\n"; + } + } else { + if (verboseErrors) { + std::cerr << " FAILURE! " << error_count << " errors...\n\n"; + } + } + + // returns true if all pixels pass + return (error_count == 0) ? true : false; +} + +#endif // COMMON_HELPER_IMAGE_H_ diff --git a/python/jittor/extern/cuda/inc/helper_string.h b/python/jittor/extern/cuda/inc/helper_string.h new file mode 100644 index 00000000..77864b8f --- /dev/null +++ b/python/jittor/extern/cuda/inc/helper_string.h @@ -0,0 +1,683 @@ +/** + * Copyright 1993-2013 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +// These are helper functions for the SDK samples (string parsing, timers, etc) +#ifndef COMMON_HELPER_STRING_H_ +#define COMMON_HELPER_STRING_H_ + +#include +#include +#include +#include + +#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) +#ifndef _CRT_SECURE_NO_DEPRECATE +#define _CRT_SECURE_NO_DEPRECATE +#endif +#ifndef STRCASECMP +#define STRCASECMP _stricmp +#endif +#ifndef STRNCASECMP +#define STRNCASECMP _strnicmp +#endif +#ifndef STRCPY +#define STRCPY(sFilePath, nLength, sPath) strcpy_s(sFilePath, nLength, sPath) +#endif + +#ifndef FOPEN +#define FOPEN(fHandle, filename, mode) fopen_s(&fHandle, filename, mode) +#endif +#ifndef FOPEN_FAIL +#define FOPEN_FAIL(result) (result != 0) +#endif +#ifndef SSCANF +#define SSCANF sscanf_s +#endif +#ifndef SPRINTF +#define SPRINTF sprintf_s +#endif +#else // Linux Includes +#include +#include + +#ifndef STRCASECMP +#define STRCASECMP strcasecmp +#endif +#ifndef STRNCASECMP +#define STRNCASECMP strncasecmp +#endif +#ifndef STRCPY +#define STRCPY(sFilePath, nLength, sPath) strcpy(sFilePath, sPath) +#endif + +#ifndef FOPEN +#define FOPEN(fHandle, filename, mode) (fHandle = fopen(filename, mode)) +#endif +#ifndef FOPEN_FAIL +#define FOPEN_FAIL(result) (result == NULL) +#endif +#ifndef SSCANF +#define SSCANF sscanf +#endif +#ifndef SPRINTF +#define SPRINTF sprintf +#endif +#endif + +#ifndef EXIT_WAIVED +#define EXIT_WAIVED 2 +#endif + +// CUDA Utility Helper Functions +inline int stringRemoveDelimiter(char delimiter, const char *string) { + int string_start = 0; + + while (string[string_start] == delimiter) { + string_start++; + } + + if (string_start >= static_cast(strlen(string) - 1)) { + return 0; + } + + return string_start; +} + +inline int getFileExtension(char *filename, char **extension) { + int string_length = static_cast(strlen(filename)); + + while (filename[string_length--] != '.') { + if (string_length == 0) break; + } + + if (string_length > 0) string_length += 2; + + if (string_length == 0) + *extension = NULL; + else + *extension = &filename[string_length]; + + return string_length; +} + +inline bool checkCmdLineFlag(const int argc, const char **argv, + const char *string_ref) { + bool bFound = false; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; + + const char *equal_pos = strchr(string_argv, '='); + int argv_length = static_cast( + equal_pos == 0 ? strlen(string_argv) : equal_pos - string_argv); + + int length = static_cast(strlen(string_ref)); + + if (length == argv_length && + !STRNCASECMP(string_argv, string_ref, length)) { + bFound = true; + continue; + } + } + } + + return bFound; +} + +// This function wraps the CUDA Driver API into a template function +template +inline bool getCmdLineArgumentValue(const int argc, const char **argv, + const char *string_ref, T *value) { + bool bFound = false; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + if (length + 1 <= static_cast(strlen(string_argv))) { + int auto_inc = (string_argv[length] == '=') ? 1 : 0; + *value = (T)atoi(&string_argv[length + auto_inc]); + } + + bFound = true; + i = argc; + } + } + } + + return bFound; +} + +inline int getCmdLineArgumentInt(const int argc, const char **argv, + const char *string_ref) { + bool bFound = false; + int value = -1; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + if (length + 1 <= static_cast(strlen(string_argv))) { + int auto_inc = (string_argv[length] == '=') ? 1 : 0; + value = atoi(&string_argv[length + auto_inc]); + } else { + value = 0; + } + + bFound = true; + continue; + } + } + } + + if (bFound) { + return value; + } else { + return 0; + } +} + +inline float getCmdLineArgumentFloat(const int argc, const char **argv, + const char *string_ref) { + bool bFound = false; + float value = -1; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + const char *string_argv = &argv[i][string_start]; + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + if (length + 1 <= static_cast(strlen(string_argv))) { + int auto_inc = (string_argv[length] == '=') ? 1 : 0; + value = static_cast(atof(&string_argv[length + auto_inc])); + } else { + value = 0.f; + } + + bFound = true; + continue; + } + } + } + + if (bFound) { + return value; + } else { + return 0; + } +} + +inline bool getCmdLineArgumentString(const int argc, const char **argv, + const char *string_ref, + char **string_retval) { + bool bFound = false; + + if (argc >= 1) { + for (int i = 1; i < argc; i++) { + int string_start = stringRemoveDelimiter('-', argv[i]); + char *string_argv = const_cast(&argv[i][string_start]); + int length = static_cast(strlen(string_ref)); + + if (!STRNCASECMP(string_argv, string_ref, length)) { + *string_retval = &string_argv[length + 1]; + bFound = true; + continue; + } + } + } + + if (!bFound) { + *string_retval = NULL; + } + + return bFound; +} + +////////////////////////////////////////////////////////////////////////////// +//! Find the path for a file assuming that +//! files are found in the searchPath. +//! +//! @return the path if succeeded, otherwise 0 +//! @param filename name of the file +//! @param executable_path optional absolute path of the executable +////////////////////////////////////////////////////////////////////////////// +inline char *sdkFindFilePath(const char *filename, + const char *executable_path) { + // defines a variable that is replaced with the name of the + // executable + + // Typical relative search paths to locate needed companion files (e.g. sample + // input data, or JIT source files) The origin for the relative search may be + // the .exe file, a .bat file launching an .exe, a browser .exe launching the + // .exe or .bat, etc + const char *searchPath[] = { + "./", // same dir + "./_data_files/", + "./common/", // "/common/" subdir + "./common/data/", // "/common/data/" subdir + "./data/", // "/data/" subdir + "./src/", // "/src/" subdir + "./src//data/", // "/src//data/" subdir + "./inc/", // "/inc/" subdir + "./0_Simple/", // "/0_Simple/" subdir + "./1_Utilities/", // "/1_Utilities/" subdir + "./2_Graphics/", // "/2_Graphics/" subdir + "./3_Imaging/", // "/3_Imaging/" subdir + "./4_Finance/", // "/4_Finance/" subdir + "./5_Simulations/", // "/5_Simulations/" subdir + "./6_Advanced/", // "/6_Advanced/" subdir + "./7_CUDALibraries/", // "/7_CUDALibraries/" subdir + "./8_Android/", // "/8_Android/" subdir + "./samples/", // "/samples/" subdir + + "./0_Simple//data/", // "/0_Simple//data/" + // subdir + "./1_Utilities//data/", // "/1_Utilities//data/" + // subdir + "./2_Graphics//data/", // "/2_Graphics//data/" + // subdir + "./3_Imaging//data/", // "/3_Imaging//data/" + // subdir + "./4_Finance//data/", // "/4_Finance//data/" + // subdir + "./5_Simulations//data/", // "/5_Simulations//data/" + // subdir + "./6_Advanced//data/", // "/6_Advanced//data/" + // subdir + "./7_CUDALibraries//", // "/7_CUDALibraries//" + // subdir + "./7_CUDALibraries//data/", // "/7_CUDALibraries//data/" + // subdir + + "../", // up 1 in tree + "../common/", // up 1 in tree, "/common/" subdir + "../common/data/", // up 1 in tree, "/common/data/" subdir + "../data/", // up 1 in tree, "/data/" subdir + "../src/", // up 1 in tree, "/src/" subdir + "../inc/", // up 1 in tree, "/inc/" subdir + + "../0_Simple//data/", // up 1 in tree, + // "/0_Simple//" + // subdir + "../1_Utilities//data/", // up 1 in tree, + // "/1_Utilities//" + // subdir + "../2_Graphics//data/", // up 1 in tree, + // "/2_Graphics//" + // subdir + "../3_Imaging//data/", // up 1 in tree, + // "/3_Imaging//" + // subdir + "../4_Finance//data/", // up 1 in tree, + // "/4_Finance//" + // subdir + "../5_Simulations//data/", // up 1 in tree, + // "/5_Simulations//" + // subdir + "../6_Advanced//data/", // up 1 in tree, + // "/6_Advanced//" + // subdir + "../7_CUDALibraries//data/", // up 1 in tree, + // "/7_CUDALibraries//" + // subdir + "../8_Android//data/", // up 1 in tree, + // "/8_Android//" + // subdir + "../samples//data/", // up 1 in tree, + // "/samples//" + // subdir + "../../", // up 2 in tree + "../../common/", // up 2 in tree, "/common/" subdir + "../../common/data/", // up 2 in tree, "/common/data/" subdir + "../../data/", // up 2 in tree, "/data/" subdir + "../../src/", // up 2 in tree, "/src/" subdir + "../../inc/", // up 2 in tree, "/inc/" subdir + "../../sandbox//data/", // up 2 in tree, + // "/sandbox//" + // subdir + "../../0_Simple//data/", // up 2 in tree, + // "/0_Simple//" + // subdir + "../../1_Utilities//data/", // up 2 in tree, + // "/1_Utilities//" + // subdir + "../../2_Graphics//data/", // up 2 in tree, + // "/2_Graphics//" + // subdir + "../../3_Imaging//data/", // up 2 in tree, + // "/3_Imaging//" + // subdir + "../../4_Finance//data/", // up 2 in tree, + // "/4_Finance//" + // subdir + "../../5_Simulations//data/", // up 2 in tree, + // "/5_Simulations//" + // subdir + "../../6_Advanced//data/", // up 2 in tree, + // "/6_Advanced//" + // subdir + "../../7_CUDALibraries//data/", // up 2 in tree, + // "/7_CUDALibraries//" + // subdir + "../../8_Android//data/", // up 2 in tree, + // "/8_Android//" + // subdir + "../../samples//data/", // up 2 in tree, + // "/samples//" + // subdir + "../../../", // up 3 in tree + "../../../src//", // up 3 in tree, + // "/src//" subdir + "../../../src//data/", // up 3 in tree, + // "/src//data/" + // subdir + "../../../src//src/", // up 3 in tree, + // "/src//src/" + // subdir + "../../../src//inc/", // up 3 in tree, + // "/src//inc/" + // subdir + "../../../sandbox//", // up 3 in tree, + // "/sandbox//" + // subdir + "../../../sandbox//data/", // up 3 in tree, + // "/sandbox//data/" + // subdir + "../../../sandbox//src/", // up 3 in tree, + // "/sandbox//src/" + // subdir + "../../../sandbox//inc/", // up 3 in tree, + // "/sandbox//inc/" + // subdir + "../../../0_Simple//data/", // up 3 in tree, + // "/0_Simple//" + // subdir + "../../../1_Utilities//data/", // up 3 in tree, + // "/1_Utilities//" + // subdir + "../../../2_Graphics//data/", // up 3 in tree, + // "/2_Graphics//" + // subdir + "../../../3_Imaging//data/", // up 3 in tree, + // "/3_Imaging//" + // subdir + "../../../4_Finance//data/", // up 3 in tree, + // "/4_Finance//" + // subdir + "../../../5_Simulations//data/", // up 3 in tree, + // "/5_Simulations//" + // subdir + "../../../6_Advanced//data/", // up 3 in tree, + // "/6_Advanced//" + // subdir + "../../../7_CUDALibraries//data/", // up 3 in tree, + // "/7_CUDALibraries//" + // subdir + "../../../8_Android//data/", // up 3 in tree, + // "/8_Android//" + // subdir + "../../../0_Simple//", // up 3 in tree, + // "/0_Simple//" + // subdir + "../../../1_Utilities//", // up 3 in tree, + // "/1_Utilities//" + // subdir + "../../../2_Graphics//", // up 3 in tree, + // "/2_Graphics//" + // subdir + "../../../3_Imaging//", // up 3 in tree, + // "/3_Imaging//" + // subdir + "../../../4_Finance//", // up 3 in tree, + // "/4_Finance//" + // subdir + "../../../5_Simulations//", // up 3 in tree, + // "/5_Simulations//" + // subdir + "../../../6_Advanced//", // up 3 in tree, + // "/6_Advanced//" + // subdir + "../../../7_CUDALibraries//", // up 3 in tree, + // "/7_CUDALibraries//" + // subdir + "../../../8_Android//", // up 3 in tree, + // "/8_Android//" + // subdir + "../../../samples//data/", // up 3 in tree, + // "/samples//" + // subdir + "../../../common/", // up 3 in tree, "../../../common/" subdir + "../../../common/data/", // up 3 in tree, "../../../common/data/" subdir + "../../../data/", // up 3 in tree, "../../../data/" subdir + "../../../../", // up 4 in tree + "../../../../src//", // up 4 in tree, + // "/src//" subdir + "../../../../src//data/", // up 4 in tree, + // "/src//data/" + // subdir + "../../../../src//src/", // up 4 in tree, + // "/src//src/" + // subdir + "../../../../src//inc/", // up 4 in tree, + // "/src//inc/" + // subdir + "../../../../sandbox//", // up 4 in tree, + // "/sandbox//" + // subdir + "../../../../sandbox//data/", // up 4 in tree, + // "/sandbox//data/" + // subdir + "../../../../sandbox//src/", // up 4 in tree, + // "/sandbox//src/" + // subdir + "../../../../sandbox//inc/", // up 4 in tree, + // "/sandbox//inc/" + // subdir + "../../../../0_Simple//data/", // up 4 in tree, + // "/0_Simple//" + // subdir + "../../../../1_Utilities//data/", // up 4 in tree, + // "/1_Utilities//" + // subdir + "../../../../2_Graphics//data/", // up 4 in tree, + // "/2_Graphics//" + // subdir + "../../../../3_Imaging//data/", // up 4 in tree, + // "/3_Imaging//" + // subdir + "../../../../4_Finance//data/", // up 4 in tree, + // "/4_Finance//" + // subdir + "../../../../5_Simulations//data/", // up 4 in tree, + // "/5_Simulations//" + // subdir + "../../../../6_Advanced//data/", // up 4 in tree, + // "/6_Advanced//" + // subdir + "../../../../7_CUDALibraries//data/", // up 4 in tree, + // "/7_CUDALibraries//" + // subdir + "../../../../8_Android//data/", // up 4 in tree, + // "/8_Android//" + // subdir + "../../../../0_Simple//", // up 4 in tree, + // "/0_Simple//" + // subdir + "../../../../1_Utilities//", // up 4 in tree, + // "/1_Utilities//" + // subdir + "../../../../2_Graphics//", // up 4 in tree, + // "/2_Graphics//" + // subdir + "../../../../3_Imaging//", // up 4 in tree, + // "/3_Imaging//" + // subdir + "../../../../4_Finance//", // up 4 in tree, + // "/4_Finance//" + // subdir + "../../../../5_Simulations//", // up 4 in tree, + // "/5_Simulations//" + // subdir + "../../../../6_Advanced//", // up 4 in tree, + // "/6_Advanced//" + // subdir + "../../../../7_CUDALibraries//", // up 4 in tree, + // "/7_CUDALibraries//" + // subdir + "../../../../8_Android//", // up 4 in tree, + // "/8_Android//" + // subdir + "../../../../samples//data/", // up 4 in tree, + // "/samples//" + // subdir + "../../../../common/", // up 4 in tree, "../../../common/" subdir + "../../../../common/data/", // up 4 in tree, "../../../common/data/" + // subdir + "../../../../data/", // up 4 in tree, "../../../data/" subdir + "../../../../../", // up 5 in tree + "../../../../../src//", // up 5 in tree, + // "/src//" + // subdir + "../../../../../src//data/", // up 5 in tree, + // "/src//data/" + // subdir + "../../../../../src//src/", // up 5 in tree, + // "/src//src/" + // subdir + "../../../../../src//inc/", // up 5 in tree, + // "/src//inc/" + // subdir + "../../../../../sandbox//", // up 5 in tree, + // "/sandbox//" + // subdir + "../../../../../sandbox//data/", // up 5 in tree, + // "/sandbox//data/" + // subdir + "../../../../../sandbox//src/", // up 5 in tree, + // "/sandbox//src/" + // subdir + "../../../../../sandbox//inc/", // up 5 in tree, + // "/sandbox//inc/" + // subdir + "../../../../../0_Simple//data/", // up 5 in tree, + // "/0_Simple//" + // subdir + "../../../../../1_Utilities//data/", // up 5 in tree, + // "/1_Utilities//" + // subdir + "../../../../../2_Graphics//data/", // up 5 in tree, + // "/2_Graphics//" + // subdir + "../../../../../3_Imaging//data/", // up 5 in tree, + // "/3_Imaging//" + // subdir + "../../../../../4_Finance//data/", // up 5 in tree, + // "/4_Finance//" + // subdir + "../../../../../5_Simulations//data/", // up 5 in tree, + // "/5_Simulations//" + // subdir + "../../../../../6_Advanced//data/", // up 5 in tree, + // "/6_Advanced//" + // subdir + "../../../../../7_CUDALibraries//data/", // up 5 in + // tree, + // "/7_CUDALibraries//" + // subdir + "../../../../../8_Android//data/", // up 5 in tree, + // "/8_Android//" + // subdir + "../../../../../samples//data/", // up 5 in tree, + // "/samples//" + // subdir + "../../../../../common/", // up 5 in tree, "../../../common/" subdir + "../../../../../common/data/", // up 5 in tree, "../../../common/data/" + // subdir + }; + + // Extract the executable name + std::string executable_name; + + if (executable_path != 0) { + executable_name = std::string(executable_path); + +#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) + // Windows path delimiter + size_t delimiter_pos = executable_name.find_last_of('\\'); + executable_name.erase(0, delimiter_pos + 1); + + if (executable_name.rfind(".exe") != std::string::npos) { + // we strip .exe, only if the .exe is found + executable_name.resize(executable_name.size() - 4); + } + +#else + // Linux & OSX path delimiter + size_t delimiter_pos = executable_name.find_last_of('/'); + executable_name.erase(0, delimiter_pos + 1); +#endif + } + + // Loop over all search paths and return the first hit + for (unsigned int i = 0; i < sizeof(searchPath) / sizeof(char *); ++i) { + std::string path(searchPath[i]); + size_t executable_name_pos = path.find(""); + + // If there is executable_name variable in the searchPath + // replace it with the value + if (executable_name_pos != std::string::npos) { + if (executable_path != 0) { + path.replace(executable_name_pos, strlen(""), + executable_name); + } else { + // Skip this path entry if no executable argument is given + continue; + } + } + +#ifdef _DEBUG + printf("sdkFindFilePath <%s> in %s\n", filename, path.c_str()); +#endif + + // Test if the file exists + path.append(filename); + FILE *fp; + FOPEN(fp, path.c_str(), "rb"); + + if (fp != NULL) { + fclose(fp); + // File found + // returning an allocated array here for backwards compatibility reasons + char *file_path = reinterpret_cast(malloc(path.length() + 1)); + STRCPY(file_path, path.length() + 1, path.c_str()); + return file_path; + } + + if (fp) { + fclose(fp); + } + } + + // File not found + return 0; +} + +#endif // COMMON_HELPER_STRING_H_ diff --git a/python/jittor/extern/cuda/inc/helper_timer.h b/python/jittor/extern/cuda/inc/helper_timer.h new file mode 100644 index 00000000..fc3ee767 --- /dev/null +++ b/python/jittor/extern/cuda/inc/helper_timer.h @@ -0,0 +1,448 @@ +/** + * Copyright 1993-2013 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +// Helper Timing Functions +#ifndef COMMON_HELPER_TIMER_H_ +#define COMMON_HELPER_TIMER_H_ + +#ifndef EXIT_WAIVED +#define EXIT_WAIVED 2 +#endif + +// includes, system +#include + + +// Definition of the StopWatch Interface, this is used if we don't want to use +// the CUT functions But rather in a self contained class interface +class StopWatchInterface { + public: + StopWatchInterface() {} + virtual ~StopWatchInterface() {} + + public: + //! Start time measurement + virtual void start() = 0; + + //! Stop time measurement + virtual void stop() = 0; + + //! Reset time counters to zero + virtual void reset() = 0; + + //! Time in msec. after start. If the stop watch is still running (i.e. there + //! was no call to stop()) then the elapsed time is returned, otherwise the + //! time between the last start() and stop call is returned + virtual float getTime() = 0; + + //! Mean time to date based on the number of times the stopwatch has been + //! _stopped_ (ie finished sessions) and the current total time + virtual float getAverageTime() = 0; +}; + +////////////////////////////////////////////////////////////////// +// Begin Stopwatch timer class definitions for all OS platforms // +////////////////////////////////////////////////////////////////// +#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) +// includes, system +#define WINDOWS_LEAN_AND_MEAN +#include +#undef min +#undef max + +//! Windows specific implementation of StopWatch +class StopWatchWin : public StopWatchInterface { + public: + //! Constructor, default + StopWatchWin() + : start_time(), + end_time(), + diff_time(0.0f), + total_time(0.0f), + running(false), + clock_sessions(0), + freq(0), + freq_set(false) { + if (!freq_set) { + // helper variable + LARGE_INTEGER temp; + + // get the tick frequency from the OS + QueryPerformanceFrequency(reinterpret_cast(&temp)); + + // convert to type in which it is needed + freq = (static_cast(temp.QuadPart)) / 1000.0; + + // rememeber query + freq_set = true; + } + } + + // Destructor + ~StopWatchWin() {} + + public: + //! Start time measurement + inline void start(); + + //! Stop time measurement + inline void stop(); + + //! Reset time counters to zero + inline void reset(); + + //! Time in msec. after start. If the stop watch is still running (i.e. there + //! was no call to stop()) then the elapsed time is returned, otherwise the + //! time between the last start() and stop call is returned + inline float getTime(); + + //! Mean time to date based on the number of times the stopwatch has been + //! _stopped_ (ie finished sessions) and the current total time + inline float getAverageTime(); + + private: + // member variables + + //! Start of measurement + LARGE_INTEGER start_time; + //! End of measurement + LARGE_INTEGER end_time; + + //! Time difference between the last start and stop + float diff_time; + + //! TOTAL time difference between starts and stops + float total_time; + + //! flag if the stop watch is running + bool running; + + //! Number of times clock has been started + //! and stopped to allow averaging + int clock_sessions; + + //! tick frequency + double freq; + + //! flag if the frequency has been set + bool freq_set; +}; + +// functions, inlined + +//////////////////////////////////////////////////////////////////////////////// +//! Start time measurement +//////////////////////////////////////////////////////////////////////////////// +inline void StopWatchWin::start() { + QueryPerformanceCounter(reinterpret_cast(&start_time)); + running = true; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Stop time measurement and increment add to the current diff_time summation +//! variable. Also increment the number of times this clock has been run. +//////////////////////////////////////////////////////////////////////////////// +inline void StopWatchWin::stop() { + QueryPerformanceCounter(reinterpret_cast(&end_time)); + diff_time = static_cast(((static_cast(end_time.QuadPart) - + static_cast(start_time.QuadPart)) / + freq)); + + total_time += diff_time; + clock_sessions++; + running = false; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Reset the timer to 0. Does not change the timer running state but does +//! recapture this point in time as the current start time if it is running. +//////////////////////////////////////////////////////////////////////////////// +inline void StopWatchWin::reset() { + diff_time = 0; + total_time = 0; + clock_sessions = 0; + + if (running) { + QueryPerformanceCounter(reinterpret_cast(&start_time)); + } +} + +//////////////////////////////////////////////////////////////////////////////// +//! Time in msec. after start. If the stop watch is still running (i.e. there +//! was no call to stop()) then the elapsed time is returned added to the +//! current diff_time sum, otherwise the current summed time difference alone +//! is returned. +//////////////////////////////////////////////////////////////////////////////// +inline float StopWatchWin::getTime() { + // Return the TOTAL time to date + float retval = total_time; + + if (running) { + LARGE_INTEGER temp; + QueryPerformanceCounter(reinterpret_cast(&temp)); + retval += static_cast(((static_cast(temp.QuadPart) - + static_cast(start_time.QuadPart)) / + freq)); + } + + return retval; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Time in msec. for a single run based on the total number of COMPLETED runs +//! and the total time. +//////////////////////////////////////////////////////////////////////////////// +inline float StopWatchWin::getAverageTime() { + return (clock_sessions > 0) ? (total_time / clock_sessions) : 0.0f; +} +#else +// Declarations for Stopwatch on Linux and Mac OSX +// includes, system +#include +#include + +//! Windows specific implementation of StopWatch +class StopWatchLinux : public StopWatchInterface { + public: + //! Constructor, default + StopWatchLinux() + : start_time(), + diff_time(0.0), + total_time(0.0), + running(false), + clock_sessions(0) {} + + // Destructor + virtual ~StopWatchLinux() {} + + public: + //! Start time measurement + inline void start(); + + //! Stop time measurement + inline void stop(); + + //! Reset time counters to zero + inline void reset(); + + //! Time in msec. after start. If the stop watch is still running (i.e. there + //! was no call to stop()) then the elapsed time is returned, otherwise the + //! time between the last start() and stop call is returned + inline float getTime(); + + //! Mean time to date based on the number of times the stopwatch has been + //! _stopped_ (ie finished sessions) and the current total time + inline float getAverageTime(); + + private: + // helper functions + + //! Get difference between start time and current time + inline float getDiffTime(); + + private: + // member variables + + //! Start of measurement + struct timeval start_time; + + //! Time difference between the last start and stop + float diff_time; + + //! TOTAL time difference between starts and stops + float total_time; + + //! flag if the stop watch is running + bool running; + + //! Number of times clock has been started + //! and stopped to allow averaging + int clock_sessions; +}; + +// functions, inlined + +//////////////////////////////////////////////////////////////////////////////// +//! Start time measurement +//////////////////////////////////////////////////////////////////////////////// +inline void StopWatchLinux::start() { + gettimeofday(&start_time, 0); + running = true; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Stop time measurement and increment add to the current diff_time summation +//! variable. Also increment the number of times this clock has been run. +//////////////////////////////////////////////////////////////////////////////// +inline void StopWatchLinux::stop() { + diff_time = getDiffTime(); + total_time += diff_time; + running = false; + clock_sessions++; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Reset the timer to 0. Does not change the timer running state but does +//! recapture this point in time as the current start time if it is running. +//////////////////////////////////////////////////////////////////////////////// +inline void StopWatchLinux::reset() { + diff_time = 0; + total_time = 0; + clock_sessions = 0; + + if (running) { + gettimeofday(&start_time, 0); + } +} + +//////////////////////////////////////////////////////////////////////////////// +//! Time in msec. after start. If the stop watch is still running (i.e. there +//! was no call to stop()) then the elapsed time is returned added to the +//! current diff_time sum, otherwise the current summed time difference alone +//! is returned. +//////////////////////////////////////////////////////////////////////////////// +inline float StopWatchLinux::getTime() { + // Return the TOTAL time to date + float retval = total_time; + + if (running) { + retval += getDiffTime(); + } + + return retval; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Time in msec. for a single run based on the total number of COMPLETED runs +//! and the total time. +//////////////////////////////////////////////////////////////////////////////// +inline float StopWatchLinux::getAverageTime() { + return (clock_sessions > 0) ? (total_time / clock_sessions) : 0.0f; +} +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +inline float StopWatchLinux::getDiffTime() { + struct timeval t_time; + gettimeofday(&t_time, 0); + + // time difference in milli-seconds + return static_cast(1000.0 * (t_time.tv_sec - start_time.tv_sec) + + (0.001 * (t_time.tv_usec - start_time.tv_usec))); +} +#endif // WIN32 + +//////////////////////////////////////////////////////////////////////////////// +//! Timer functionality exported + +//////////////////////////////////////////////////////////////////////////////// +//! Create a new timer +//! @return true if a time has been created, otherwise false +//! @param name of the new timer, 0 if the creation failed +//////////////////////////////////////////////////////////////////////////////// +inline bool sdkCreateTimer(StopWatchInterface **timer_interface) { +// printf("sdkCreateTimer called object %08x\n", (void *)*timer_interface); +#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) + *timer_interface = reinterpret_cast(new StopWatchWin()); +#else + *timer_interface = + reinterpret_cast(new StopWatchLinux()); +#endif + return (*timer_interface != NULL) ? true : false; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Delete a timer +//! @return true if a time has been deleted, otherwise false +//! @param name of the timer to delete +//////////////////////////////////////////////////////////////////////////////// +inline bool sdkDeleteTimer(StopWatchInterface **timer_interface) { + // printf("sdkDeleteTimer called object %08x\n", (void *)*timer_interface); + if (*timer_interface) { + delete *timer_interface; + *timer_interface = NULL; + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Start the time with name \a name +//! @param name name of the timer to start +//////////////////////////////////////////////////////////////////////////////// +inline bool sdkStartTimer(StopWatchInterface **timer_interface) { + // printf("sdkStartTimer called object %08x\n", (void *)*timer_interface); + if (*timer_interface) { + (*timer_interface)->start(); + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Stop the time with name \a name. Does not reset. +//! @param name name of the timer to stop +//////////////////////////////////////////////////////////////////////////////// +inline bool sdkStopTimer(StopWatchInterface **timer_interface) { + // printf("sdkStopTimer called object %08x\n", (void *)*timer_interface); + if (*timer_interface) { + (*timer_interface)->stop(); + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Resets the timer's counter. +//! @param name name of the timer to reset. +//////////////////////////////////////////////////////////////////////////////// +inline bool sdkResetTimer(StopWatchInterface **timer_interface) { + // printf("sdkResetTimer called object %08x\n", (void *)*timer_interface); + if (*timer_interface) { + (*timer_interface)->reset(); + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// +//! Return the average time for timer execution as the total time +//! for the timer dividied by the number of completed (stopped) runs the timer +//! has made. +//! Excludes the current running time if the timer is currently running. +//! @param name name of the timer to return the time of +//////////////////////////////////////////////////////////////////////////////// +inline float sdkGetAverageTimerValue(StopWatchInterface **timer_interface) { + // printf("sdkGetAverageTimerValue called object %08x\n", (void + // *)*timer_interface); + if (*timer_interface) { + return (*timer_interface)->getAverageTime(); + } else { + return 0.0f; + } +} + +//////////////////////////////////////////////////////////////////////////////// +//! Total execution time for the timer over all runs since the last reset +//! or timer creation. +//! @param name name of the timer to obtain the value of. +//////////////////////////////////////////////////////////////////////////////// +inline float sdkGetTimerValue(StopWatchInterface **timer_interface) { + // printf("sdkGetTimerValue called object %08x\n", (void *)*timer_interface); + if (*timer_interface) { + return (*timer_interface)->getTime(); + } else { + return 0.0f; + } +} + +#endif // COMMON_HELPER_TIMER_H_ + diff --git a/python/jittor/extern/cuda/nccl/inc/nccl_wrapper.h b/python/jittor/extern/cuda/nccl/inc/nccl_wrapper.h new file mode 100644 index 00000000..d9e48b07 --- /dev/null +++ b/python/jittor/extern/cuda/nccl/inc/nccl_wrapper.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. +// All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "mpi_wrapper.h" + +#include +#include +#include "utils/log.h" +#include "helper_cuda.h" + +namespace jittor { + +EXTERN_LIB ncclComm_t comm; +EXTERN_LIB ncclUniqueId id; +EXTERN_LIB int nccl_device_id; + +} // jittor diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_all_gather_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_all_gather_op.cc new file mode 100644 index 00000000..00cb7c9e --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_all_gather_op.cc @@ -0,0 +1,69 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "nccl_all_gather_op.h" +#include "utils/str_utils.h" + +#include +#include +#include "helper_cuda.h" +#include "nccl_wrapper.h" +#include "ops/op_register.h" +namespace jittor { + +#ifndef JIT + +static auto nccl_all_gather = + get_op_info("nccl_all_gather").get_constructor(); + +NcclAllGatherOp::NcclAllGatherOp(Var* x) : x(x) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void NcclAllGatherOp::infer_shape() { + NanoVector yshape; + yshape.push_back(mpi_world_size * x->shape[0]); + for (int i=1; ishape.size(); i++) + yshape.push_back(x->shape[i]); + y->set_shape(yshape); +} + +VarPtr NcclAllGatherOp::grad(Var* out, Var* dout, Var* v, int v_index) { + LOGf << "not implemented"; + return nullptr; +} + +void NcclAllGatherOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); +} + +#else // JIT +#ifdef JIT_cuda + +void NcclAllGatherOp::jit_run() { + @define(T_NCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt) + @if(@strcmp(@Tx,float64)==0, ncclFloat64) + @if(@strcmp(@Tx,int64)==0, ncclInt64) + @if(@strcmp(@Tx,uint8)==0, ncclUint8) + @if(@strcmp(@Tx,float16)==0, ncclHalf) + @if(@strcmp(@Tx,bfloat16)==0, ncclBfloat16) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + checkCudaErrors(ncclAllGather(xp, yp, x->num, @T_NCCL, comm, 0)); +} + +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_all_gather_op.h b/python/jittor/extern/cuda/nccl/ops/nccl_all_gather_op.h new file mode 100644 index 00000000..308d1078 --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_all_gather_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct NcclAllGatherOp : Op { + Var* x, * y; + + NcclAllGatherOp(Var* x); + void infer_shape() override; + + const char* name() const override { return "nccl_all_gather"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.cc new file mode 100644 index 00000000..cb55653d --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.cc @@ -0,0 +1,63 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "nccl_all_reduce_op.h" +#include "utils/str_utils.h" + +#include +#include +#include "helper_cuda.h" +#include "nccl_wrapper.h" +#include "ops/op_register.h" +namespace jittor { + +#ifndef JIT + +static auto nccl_all_reduce = + get_op_info("nccl_all_reduce").get_constructor(); + +NcclAllReduceOp::NcclAllReduceOp(Var* x) : x(x) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void NcclAllReduceOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return nccl_all_reduce(dout); +} + +void NcclAllReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); +} + +#else // JIT +#ifdef JIT_cuda + +void NcclAllReduceOp::jit_run() { + @define(T_NCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt) + @if(@strcmp(@Tx,float64)==0, ncclFloat64) + @if(@strcmp(@Tx,int64)==0, ncclInt64) + @if(@strcmp(@Tx,uint8)==0, ncclUint8) + @if(@strcmp(@Tx,float16)==0, ncclHalf) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + checkCudaErrors(ncclAllReduce(xp, yp, y->num, @T_NCCL, ncclSum, comm, 0)); +} + +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.h b/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.h new file mode 100644 index 00000000..3bfb5dd4 --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_all_reduce_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct NcclAllReduceOp : Op { + Var* x, * y; + + NcclAllReduceOp(Var* x); + void infer_shape() override; + + const char* name() const override { return "nccl_all_reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.cc new file mode 100644 index 00000000..4cc28ded --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.cc @@ -0,0 +1,62 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "nccl_broadcast_op.h" +#include "utils/str_utils.h" + +#include +#include +#include "helper_cuda.h" +#include "nccl_wrapper.h" +#include "ops/op_register.h" +namespace jittor { + +#ifndef JIT +NcclBroadcastOp::NcclBroadcastOp(Var* x, int root) : x(x), root(root) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void NcclBroadcastOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) { + static auto nccl_reduce = + get_op_info("nccl_reduce").get_constructor(); + return nccl_reduce(dout,root); +} + +void NcclBroadcastOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); +} + +#else // JIT +#ifdef JIT_cuda + +void NcclBroadcastOp::jit_run() { + @define(T_NCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt) + @if(@strcmp(@Tx,float64)==0, ncclFloat64) + @if(@strcmp(@Tx,int64)==0, ncclInt64) + @if(@strcmp(@Tx,uint8)==0, ncclUint8) + @if(@strcmp(@Tx,float16)==0, ncclHalf) + @if(@strcmp(@Tx,bfloat16)==0, ncclBfloat16) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + checkCudaErrors(ncclBroadcast(xp, yp, y->num, @T_NCCL, root, comm, 0)); +} + +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.h b/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.h new file mode 100644 index 00000000..f9aac5f4 --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_broadcast_op.h @@ -0,0 +1,26 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct NcclBroadcastOp : Op { + Var* x, * y; + int root; + + NcclBroadcastOp(Var* x, int root=0); + void infer_shape() override; + + const char* name() const override { return "nccl_broadcast"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.cc new file mode 100644 index 00000000..da60f273 --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.cc @@ -0,0 +1,64 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "nccl_reduce_op.h" +#include "utils/str_utils.h" + +#include +#include +#include "helper_cuda.h" +#include "nccl_wrapper.h" +#include "ops/op_register.h" +namespace jittor { + +#ifndef JIT +NcclReduceOp::NcclReduceOp(Var* x, int root) : x(x), root(root) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + y = create_output(nullptr, x->dtype()); +} + +void NcclReduceOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + static auto nccl_broadcast = + get_op_info("nccl_broadcast").get_constructor(); + return nccl_broadcast(dout,root); +} + +void NcclReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); +} + +#else // JIT +#ifdef JIT_cuda + +void NcclReduceOp::jit_run() { + @define(T_NCCL, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt) + @if(@strcmp(@Tx,float64)==0, ncclFloat64) + @if(@strcmp(@Tx,int64)==0, ncclInt64) + @if(@strcmp(@Tx,uint8)==0, ncclUint8) + @if(@strcmp(@Tx,float16)==0, ncclHalf) + @if(@strcmp(@Tx,bfloat16)==0, ncclBfloat16) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + checkCudaErrors(ncclReduce(xp, yp, y->num, @T_NCCL, ncclSum, root, comm, 0)); + if (root != mpi_world_rank) + checkCudaErrors(cudaMemsetAsync(yp, 0, y->size)); +} + +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.h b/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.h new file mode 100644 index 00000000..7a663a66 --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_reduce_op.h @@ -0,0 +1,26 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guoye Yang <498731903@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct NcclReduceOp : Op { + Var* x, * y; + int root; + + NcclReduceOp(Var* x, int root=0); + void infer_shape() override; + + const char* name() const override { return "nccl_reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_test_op.cc b/python/jittor/extern/cuda/nccl/ops/nccl_test_op.cc new file mode 100644 index 00000000..fbb49daa --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_test_op.cc @@ -0,0 +1,127 @@ +// *************************************************************** +// Copyright (c) 2019 Dun Liang . All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "nccl_test_op.h" +#include "utils/str_utils.h" + +#include "nccl_wrapper.h" + + +namespace jittor { + +#ifndef JIT +NcclTestOp::NcclTestOp(string cmd) : cmd(cmd) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + output = create_output(1, ns_float32); +} + +void NcclTestOp::jit_prepare(JK& jk) { + jk << "«T:float32"; +} + +#else // JIT +#ifdef JIT_cuda + +static void test_with_mpi() { + int size = 32*1024*1024; + int myRank = mpi_world_rank; + int nRanks = mpi_world_size; + int localRank = mpi_local_rank; + + float *sendbuff, *recvbuff; + cudaStream_t s; + checkCudaErrors(cudaMalloc(&sendbuff, size * sizeof(float))); + checkCudaErrors(cudaMalloc(&recvbuff, size * sizeof(float))); + checkCudaErrors(cudaStreamCreate(&s)); + + //communicating using NCCL + checkCudaErrors(ncclAllReduce((const void*)sendbuff, (void*)recvbuff, size, ncclFloat, ncclSum, + comm, s)); + + //completing NCCL operation by synchronizing on the CUDA stream + checkCudaErrors(cudaStreamSynchronize(s)); + + //free device buffers + checkCudaErrors(cudaFree(sendbuff)); + checkCudaErrors(cudaFree(recvbuff)); + checkCudaErrors(cudaStreamDestroy(s)); + + LOGi << "MPI rank" << myRank << "Success"; +} + +void NcclTestOp::jit_run() { + output->ptr()[0] = 123; + if (cmd == "test_with_mpi") { + test_with_mpi(); + return; + } + + + //managing 4 devices + int nDev; + checkCudaErrors(cudaGetDeviceCount(&nDev)); + nDev = std::min(nDev, 2); + + ncclComm_t comms[nDev]; + int size = 32*1024*1024; + int devs[4] = { 0, 1, 2, 3 }; + + + //allocating and initializing device buffers + float** sendbuff = (float**)malloc(nDev * sizeof(float*)); + float** recvbuff = (float**)malloc(nDev * sizeof(float*)); + cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev); + + + for (int i = 0; i < nDev; ++i) { + checkCudaErrors(cudaSetDevice(i)); + checkCudaErrors(cudaMalloc(sendbuff + i, size * sizeof(float))); + checkCudaErrors(cudaMalloc(recvbuff + i, size * sizeof(float))); + checkCudaErrors(cudaMemset(sendbuff[i], 1, size * sizeof(float))); + checkCudaErrors(cudaMemset(recvbuff[i], 0, size * sizeof(float))); + checkCudaErrors(cudaStreamCreate(s+i)); + } + + + //initializing NCCL + checkCudaErrors(ncclCommInitAll(comms, nDev, devs)); + + + //calling NCCL communication API. Group API is required when using + //multiple devices per thread + checkCudaErrors(ncclGroupStart()); + for (int i = 0; i < nDev; ++i) + checkCudaErrors(ncclAllReduce((const void*)sendbuff[i], (void*)recvbuff[i], size, ncclFloat, ncclSum, + comms[i], s[i])); + checkCudaErrors(ncclGroupEnd()); + + + //synchronizing on CUDA streams to wait for completion of NCCL operation + for (int i = 0; i < nDev; ++i) { + checkCudaErrors(cudaSetDevice(i)); + checkCudaErrors(cudaStreamSynchronize(s[i])); + } + + + //free device buffers + for (int i = 0; i < nDev; ++i) { + checkCudaErrors(cudaSetDevice(i)); + checkCudaErrors(cudaFree(sendbuff[i])); + checkCudaErrors(cudaFree(recvbuff[i])); + } + + + //finalizing NCCL + for(int i = 0; i < nDev; ++i) + ncclCommDestroy(comms[i]); + checkCudaErrors(cudaSetDevice(0)); +} + +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/nccl/ops/nccl_test_op.h b/python/jittor/extern/cuda/nccl/ops/nccl_test_op.h new file mode 100644 index 00000000..26f19eac --- /dev/null +++ b/python/jittor/extern/cuda/nccl/ops/nccl_test_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. +// All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct NcclTestOp : Op { + Var* output; + string cmd; + + NcclTestOp(string cmd); + + const char* name() const override { return "nccl_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/nccl/src/nccl_wrapper.cc b/python/jittor/extern/cuda/nccl/src/nccl_wrapper.cc new file mode 100644 index 00000000..c8dfb1ba --- /dev/null +++ b/python/jittor/extern/cuda/nccl/src/nccl_wrapper.cc @@ -0,0 +1,65 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. +// All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "misc/cuda_flags.h" +#include "nccl_wrapper.h" +#include "event_queue.h" + +const char *_cudaGetErrorEnum(ncclResult_t error) { + return ncclGetErrorString(error); +} + +namespace jittor { + +ncclComm_t comm; +ncclUniqueId id; +int nccl_device_id = 0; + + +struct nccl_initer { + +nccl_initer() { + int device_count = get_device_count(); + if (!device_count) return; + if (!inside_mpi) return; + nccl_device_id = mpi_local_rank; + if (mpi_local_rank >= device_count) { + LOGw << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count(" + >>device_count>>")"; + nccl_device_id = nccl_device_id % device_count; + } + LOGv << "NCCL init in device" << nccl_device_id << "local_rank" << mpi_local_rank; + checkCudaErrors(cudaSetDevice(nccl_device_id)); + event_queue.run_sync([]() { + checkCudaErrors(cudaSetDevice(nccl_device_id)); + }); + if (mpi_local_size > device_count) { + // NCCL not support multiple process on one GPU, + // failback use MPI + return; + } + use_device_mpi = true; + if (mpi_world_rank == 0) + checkCudaErrors(ncclGetUniqueId(&id)); + MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank)); +} + +~nccl_initer() { + if (!get_device_count()) return; + if (!inside_mpi) return; + if (!use_device_mpi) return; + checkCudaErrors(ncclCommDestroy(comm)); +} + +}; + +static nccl_initer nccl_init; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/cuda/src/fp16_emu.cc b/python/jittor/extern/cuda/src/fp16_emu.cc new file mode 100644 index 00000000..f74c1ae8 --- /dev/null +++ b/python/jittor/extern/cuda/src/fp16_emu.cc @@ -0,0 +1,152 @@ +/* + * Copyright 1993-2014 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#include "fp16_emu.h" + + +#ifdef __GNUC__ +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#pragma GCC diagnostic ignored "-Wunused-local-typedefs" +#endif + +#define STATIC_ASSERT(cond) do { typedef char compile_time_assert[(cond) ? 1 : -1]; } while (0) + +// Host functions for converting between FP32 and FP16 formats +// Paulius Micikevicius (pauliusm@nvidia.com) + +half1 cpu_float2half_rn(float f) +{ + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + __half_raw hr; + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + hr.x = 0x7fffU; + return reinterpret_cast(hr); + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + hr.x = sign | 0x7c00U; + return reinterpret_cast(hr); + } + if (u < 0x33000001) { + hr.x = sign | 0x0000U; + return reinterpret_cast(hr); + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + hr.x = (sign | (exponent << 10) | mantissa); + + return reinterpret_cast(hr); +} + + +float cpu_half2float(half1 h) +{ + STATIC_ASSERT(sizeof(int) == sizeof(float)); + + __half_raw hr = reinterpret_cast<__half_raw&>(h); + + unsigned sign = ((hr.x >> 15) & 1); + unsigned exponent = ((hr.x >> 10) & 0x1f); + unsigned mantissa = ((hr.x & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + int temp = ((sign << 31) | (exponent << 23) | mantissa); + + return reinterpret_cast(temp); +} diff --git a/python/jittor/extern/cuda/src/helper_cuda.cc b/python/jittor/extern/cuda/src/helper_cuda.cc new file mode 100644 index 00000000..5d944ba2 --- /dev/null +++ b/python/jittor/extern/cuda/src/helper_cuda.cc @@ -0,0 +1,460 @@ +/** + * Copyright 1993-2017 NVIDIA Corporation. All rights reserved. + * + * Please refer to the NVIDIA end user license agreement (EULA) associated + * with this source code for terms and conditions that govern your use of + * this software. Any use, reproduction, disclosure, or distribution of + * this software and related documentation outside the terms of the EULA + * is strictly prohibited. + * + */ + +//////////////////////////////////////////////////////////////////////////////// +// These are CUDA Helper functions for initialization and error checking + +#include +#include "utils/log.h" +#include "helper_cuda.h" + +#ifdef _CUFFT_H_ +// cuFFT API errors +const char *_cudaGetErrorEnum(cufftResult error) { + switch (error) { + case CUFFT_SUCCESS: + return "CUFFT_SUCCESS"; + + case CUFFT_INVALID_PLAN: + return "CUFFT_INVALID_PLAN"; + + case CUFFT_ALLOC_FAILED: + return "CUFFT_ALLOC_FAILED"; + + case CUFFT_INVALID_TYPE: + return "CUFFT_INVALID_TYPE"; + + case CUFFT_INVALID_VALUE: + return "CUFFT_INVALID_VALUE"; + + case CUFFT_INTERNAL_ERROR: + return "CUFFT_INTERNAL_ERROR"; + + case CUFFT_EXEC_FAILED: + return "CUFFT_EXEC_FAILED"; + + case CUFFT_SETUP_FAILED: + return "CUFFT_SETUP_FAILED"; + + case CUFFT_INVALID_SIZE: + return "CUFFT_INVALID_SIZE"; + + case CUFFT_UNALIGNED_DATA: + return "CUFFT_UNALIGNED_DATA"; + + case CUFFT_INCOMPLETE_PARAMETER_LIST: + return "CUFFT_INCOMPLETE_PARAMETER_LIST"; + + case CUFFT_INVALID_DEVICE: + return "CUFFT_INVALID_DEVICE"; + + case CUFFT_PARSE_ERROR: + return "CUFFT_PARSE_ERROR"; + + case CUFFT_NO_WORKSPACE: + return "CUFFT_NO_WORKSPACE"; + + case CUFFT_NOT_IMPLEMENTED: + return "CUFFT_NOT_IMPLEMENTED"; + + case CUFFT_LICENSE_ERROR: + return "CUFFT_LICENSE_ERROR"; + + case CUFFT_NOT_SUPPORTED: + return "CUFFT_NOT_SUPPORTED"; + } + + return ""; +} +#endif + + +#ifdef CUSPARSEAPI +// cuSPARSE API errors +const char *_cudaGetErrorEnum(cusparseStatus_t error) { + switch (error) { + case CUSPARSE_STATUS_SUCCESS: + return "CUSPARSE_STATUS_SUCCESS"; + + case CUSPARSE_STATUS_NOT_INITIALIZED: + return "CUSPARSE_STATUS_NOT_INITIALIZED"; + + case CUSPARSE_STATUS_ALLOC_FAILED: + return "CUSPARSE_STATUS_ALLOC_FAILED"; + + case CUSPARSE_STATUS_INVALID_VALUE: + return "CUSPARSE_STATUS_INVALID_VALUE"; + + case CUSPARSE_STATUS_ARCH_MISMATCH: + return "CUSPARSE_STATUS_ARCH_MISMATCH"; + + case CUSPARSE_STATUS_MAPPING_ERROR: + return "CUSPARSE_STATUS_MAPPING_ERROR"; + + case CUSPARSE_STATUS_EXECUTION_FAILED: + return "CUSPARSE_STATUS_EXECUTION_FAILED"; + + case CUSPARSE_STATUS_INTERNAL_ERROR: + return "CUSPARSE_STATUS_INTERNAL_ERROR"; + + case CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + } + + return ""; +} +#endif + + +#ifdef CUSOLVER_COMMON_H_ +// cuSOLVER API errors +const char *_cudaGetErrorEnum(cusolverStatus_t error) { + switch (error) { + case CUSOLVER_STATUS_SUCCESS: + return "CUSOLVER_STATUS_SUCCESS"; + case CUSOLVER_STATUS_NOT_INITIALIZED: + return "CUSOLVER_STATUS_NOT_INITIALIZED"; + case CUSOLVER_STATUS_ALLOC_FAILED: + return "CUSOLVER_STATUS_ALLOC_FAILED"; + case CUSOLVER_STATUS_INVALID_VALUE: + return "CUSOLVER_STATUS_INVALID_VALUE"; + case CUSOLVER_STATUS_ARCH_MISMATCH: + return "CUSOLVER_STATUS_ARCH_MISMATCH"; + case CUSOLVER_STATUS_MAPPING_ERROR: + return "CUSOLVER_STATUS_MAPPING_ERROR"; + case CUSOLVER_STATUS_EXECUTION_FAILED: + return "CUSOLVER_STATUS_EXECUTION_FAILED"; + case CUSOLVER_STATUS_INTERNAL_ERROR: + return "CUSOLVER_STATUS_INTERNAL_ERROR"; + case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + case CUSOLVER_STATUS_NOT_SUPPORTED: + return "CUSOLVER_STATUS_NOT_SUPPORTED "; + case CUSOLVER_STATUS_ZERO_PIVOT: + return "CUSOLVER_STATUS_ZERO_PIVOT"; + case CUSOLVER_STATUS_INVALID_LICENSE: + return "CUSOLVER_STATUS_INVALID_LICENSE"; + } + + return ""; +} +#endif + + +#ifdef CURAND_H_ +// cuRAND API errors +const char *_cudaGetErrorEnum(curandStatus_t error) { + switch (error) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } + + return ""; +} +#endif + + +#ifdef NV_NPPIDEFS_H +// NPP API errors +const char *_cudaGetErrorEnum(NppStatus error) { + switch (error) { + case NPP_NOT_SUPPORTED_MODE_ERROR: + return "NPP_NOT_SUPPORTED_MODE_ERROR"; + + case NPP_ROUND_MODE_NOT_SUPPORTED_ERROR: + return "NPP_ROUND_MODE_NOT_SUPPORTED_ERROR"; + + case NPP_RESIZE_NO_OPERATION_ERROR: + return "NPP_RESIZE_NO_OPERATION_ERROR"; + + case NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY: + return "NPP_NOT_SUFFICIENT_COMPUTE_CAPABILITY"; + +#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000 + + case NPP_BAD_ARG_ERROR: + return "NPP_BAD_ARGUMENT_ERROR"; + + case NPP_COEFF_ERROR: + return "NPP_COEFFICIENT_ERROR"; + + case NPP_RECT_ERROR: + return "NPP_RECTANGLE_ERROR"; + + case NPP_QUAD_ERROR: + return "NPP_QUADRANGLE_ERROR"; + + case NPP_MEM_ALLOC_ERR: + return "NPP_MEMORY_ALLOCATION_ERROR"; + + case NPP_HISTO_NUMBER_OF_LEVELS_ERROR: + return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR"; + + case NPP_INVALID_INPUT: + return "NPP_INVALID_INPUT"; + + case NPP_POINTER_ERROR: + return "NPP_POINTER_ERROR"; + + case NPP_WARNING: + return "NPP_WARNING"; + + case NPP_ODD_ROI_WARNING: + return "NPP_ODD_ROI_WARNING"; +#else + + // These are for CUDA 5.5 or higher + case NPP_BAD_ARGUMENT_ERROR: + return "NPP_BAD_ARGUMENT_ERROR"; + + case NPP_COEFFICIENT_ERROR: + return "NPP_COEFFICIENT_ERROR"; + + case NPP_RECTANGLE_ERROR: + return "NPP_RECTANGLE_ERROR"; + + case NPP_QUADRANGLE_ERROR: + return "NPP_QUADRANGLE_ERROR"; + + case NPP_MEMORY_ALLOCATION_ERR: + return "NPP_MEMORY_ALLOCATION_ERROR"; + + case NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR: + return "NPP_HISTOGRAM_NUMBER_OF_LEVELS_ERROR"; + + case NPP_INVALID_HOST_POINTER_ERROR: + return "NPP_INVALID_HOST_POINTER_ERROR"; + + case NPP_INVALID_DEVICE_POINTER_ERROR: + return "NPP_INVALID_DEVICE_POINTER_ERROR"; +#endif + + case NPP_LUT_NUMBER_OF_LEVELS_ERROR: + return "NPP_LUT_NUMBER_OF_LEVELS_ERROR"; + + case NPP_TEXTURE_BIND_ERROR: + return "NPP_TEXTURE_BIND_ERROR"; + + case NPP_WRONG_INTERSECTION_ROI_ERROR: + return "NPP_WRONG_INTERSECTION_ROI_ERROR"; + + case NPP_NOT_EVEN_STEP_ERROR: + return "NPP_NOT_EVEN_STEP_ERROR"; + + case NPP_INTERPOLATION_ERROR: + return "NPP_INTERPOLATION_ERROR"; + + case NPP_RESIZE_FACTOR_ERROR: + return "NPP_RESIZE_FACTOR_ERROR"; + + case NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR: + return "NPP_HAAR_CLASSIFIER_PIXEL_MATCH_ERROR"; + +#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) <= 0x5000 + + case NPP_MEMFREE_ERR: + return "NPP_MEMFREE_ERR"; + + case NPP_MEMSET_ERR: + return "NPP_MEMSET_ERR"; + + case NPP_MEMCPY_ERR: + return "NPP_MEMCPY_ERROR"; + + case NPP_MIRROR_FLIP_ERR: + return "NPP_MIRROR_FLIP_ERR"; +#else + + case NPP_MEMFREE_ERROR: + return "NPP_MEMFREE_ERROR"; + + case NPP_MEMSET_ERROR: + return "NPP_MEMSET_ERROR"; + + case NPP_MEMCPY_ERROR: + return "NPP_MEMCPY_ERROR"; + + case NPP_MIRROR_FLIP_ERROR: + return "NPP_MIRROR_FLIP_ERROR"; +#endif + + case NPP_ALIGNMENT_ERROR: + return "NPP_ALIGNMENT_ERROR"; + + case NPP_STEP_ERROR: + return "NPP_STEP_ERROR"; + + case NPP_SIZE_ERROR: + return "NPP_SIZE_ERROR"; + + case NPP_NULL_POINTER_ERROR: + return "NPP_NULL_POINTER_ERROR"; + + case NPP_CUDA_KERNEL_EXECUTION_ERROR: + return "NPP_CUDA_KERNEL_EXECUTION_ERROR"; + + case NPP_NOT_IMPLEMENTED_ERROR: + return "NPP_NOT_IMPLEMENTED_ERROR"; + + case NPP_ERROR: + return "NPP_ERROR"; + + case NPP_SUCCESS: + return "NPP_SUCCESS"; + + case NPP_WRONG_INTERSECTION_QUAD_WARNING: + return "NPP_WRONG_INTERSECTION_QUAD_WARNING"; + + case NPP_MISALIGNED_DST_ROI_WARNING: + return "NPP_MISALIGNED_DST_ROI_WARNING"; + + case NPP_AFFINE_QUAD_INCORRECT_WARNING: + return "NPP_AFFINE_QUAD_INCORRECT_WARNING"; + + case NPP_DOUBLE_SIZE_WARNING: + return "NPP_DOUBLE_SIZE_WARNING"; + + case NPP_WRONG_INTERSECTION_ROI_WARNING: + return "NPP_WRONG_INTERSECTION_ROI_WARNING"; + +#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x6000 + /* These are 6.0 or higher */ + case NPP_LUT_PALETTE_BITSIZE_ERROR: + return "NPP_LUT_PALETTE_BITSIZE_ERROR"; + + case NPP_ZC_MODE_NOT_SUPPORTED_ERROR: + return "NPP_ZC_MODE_NOT_SUPPORTED_ERROR"; + + case NPP_QUALITY_INDEX_ERROR: + return "NPP_QUALITY_INDEX_ERROR"; + + case NPP_CHANNEL_ORDER_ERROR: + return "NPP_CHANNEL_ORDER_ERROR"; + + case NPP_ZERO_MASK_VALUE_ERROR: + return "NPP_ZERO_MASK_VALUE_ERROR"; + + case NPP_NUMBER_OF_CHANNELS_ERROR: + return "NPP_NUMBER_OF_CHANNELS_ERROR"; + + case NPP_COI_ERROR: + return "NPP_COI_ERROR"; + + case NPP_DIVISOR_ERROR: + return "NPP_DIVISOR_ERROR"; + + case NPP_CHANNEL_ERROR: + return "NPP_CHANNEL_ERROR"; + + case NPP_STRIDE_ERROR: + return "NPP_STRIDE_ERROR"; + + case NPP_ANCHOR_ERROR: + return "NPP_ANCHOR_ERROR"; + + case NPP_MASK_SIZE_ERROR: + return "NPP_MASK_SIZE_ERROR"; + + case NPP_MOMENT_00_ZERO_ERROR: + return "NPP_MOMENT_00_ZERO_ERROR"; + + case NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR: + return "NPP_THRESHOLD_NEGATIVE_LEVEL_ERROR"; + + case NPP_THRESHOLD_ERROR: + return "NPP_THRESHOLD_ERROR"; + + case NPP_CONTEXT_MATCH_ERROR: + return "NPP_CONTEXT_MATCH_ERROR"; + + case NPP_FFT_FLAG_ERROR: + return "NPP_FFT_FLAG_ERROR"; + + case NPP_FFT_ORDER_ERROR: + return "NPP_FFT_ORDER_ERROR"; + + case NPP_SCALE_RANGE_ERROR: + return "NPP_SCALE_RANGE_ERROR"; + + case NPP_DATA_TYPE_ERROR: + return "NPP_DATA_TYPE_ERROR"; + + case NPP_OUT_OFF_RANGE_ERROR: + return "NPP_OUT_OFF_RANGE_ERROR"; + + case NPP_DIVIDE_BY_ZERO_ERROR: + return "NPP_DIVIDE_BY_ZERO_ERROR"; + + case NPP_RANGE_ERROR: + return "NPP_RANGE_ERROR"; + + case NPP_NO_MEMORY_ERROR: + return "NPP_NO_MEMORY_ERROR"; + + case NPP_ERROR_RESERVED: + return "NPP_ERROR_RESERVED"; + + case NPP_NO_OPERATION_WARNING: + return "NPP_NO_OPERATION_WARNING"; + + case NPP_DIVIDE_BY_ZERO_WARNING: + return "NPP_DIVIDE_BY_ZERO_WARNING"; +#endif + +#if ((NPP_VERSION_MAJOR << 12) + (NPP_VERSION_MINOR << 4)) >= 0x7000 + /* These are 7.0 or higher */ + case NPP_OVERFLOW_ERROR: + return "NPP_OVERFLOW_ERROR"; + + case NPP_CORRUPTED_DATA_ERROR: + return "NPP_CORRUPTED_DATA_ERROR"; +#endif + } + + return ""; +} +#endif \ No newline at end of file diff --git a/python/jittor/extern/llvm/jt_alignment_from_assumptions.cc b/python/jittor/extern/llvm/jt_alignment_from_assumptions.cc new file mode 100644 index 00000000..304dcc8a --- /dev/null +++ b/python/jittor/extern/llvm/jt_alignment_from_assumptions.cc @@ -0,0 +1,416 @@ +//===----------------------- AlignmentFromAssumptions.cpp -----------------===// +// Set Load/Store Alignments From Assumptions +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a ScalarEvolution-based transformation to set +// the alignments of load, stores and memory intrinsics based on the truth +// expressions of assume intrinsics. The primary motivation is to handle +// complex alignment assumptions that apply to vector loads and stores that +// appear after vectorization and unrolling. +// +//===----------------------------------------------------------------------===// + +#define AA_NAME "jt-alignment-from-assumptions" +#define DEBUG_TYPE AA_NAME + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace llvm; + +STATISTIC(NumLoadAlignChanged, + "Number of loads changed by alignment assumptions"); +STATISTIC(NumStoreAlignChanged, + "Number of stores changed by alignment assumptions"); +STATISTIC(NumMemIntAlignChanged, + "Number of memory intrinsics changed by alignment assumptions"); + +namespace { + +struct AlignmentFromAssumptionsPass + : public PassInfoMixin { + + bool runImpl(Function &F, AssumptionCache &AC, ScalarEvolution *SE_, + DominatorTree *DT_); + + ScalarEvolution *SE = nullptr; + DominatorTree *DT = nullptr; + + bool extractAlignmentInfo(CallInst *I, Value *&AAPtr, const SCEV *&AlignSCEV, + const SCEV *&OffSCEV); + bool processAssumption(CallInst *I); +}; + +struct JittorAlignmentFromAssumptions : public FunctionPass { + static char ID; + JittorAlignmentFromAssumptions() : FunctionPass(ID) {} + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto &AC = getAnalysis().getAssumptionCache(F); + ScalarEvolution *SE = &getAnalysis().getSE(); + DominatorTree *DT = &getAnalysis().getDomTree(); + + return Impl.runImpl(F, AC, SE, DT); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); + AU.addRequired(); + + AU.setPreservesCFG(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + AU.addPreserved(); + } + + AlignmentFromAssumptionsPass Impl; +}; // end of struct JittorAlignmentFromAssumptions + +// Given an expression for the (constant) alignment, AlignSCEV, and an +// expression for the displacement between a pointer and the aligned address, +// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced +// to a constant. Using SCEV to compute alignment handles the case where +// DiffSCEV is a recurrence with constant start such that the aligned offset +// is constant. e.g. {16,+,32} % 32 -> 16. +static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV, + const SCEV *AlignSCEV, + ScalarEvolution *SE) { + // DiffUnits = Diff % int64_t(Alignment) + const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV); + + LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " + << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n"); + + if (const SCEVConstant *ConstDUSCEV = + dyn_cast(DiffUnitsSCEV)) { + int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue(); + + // If the displacement is an exact multiple of the alignment, then the + // displaced pointer has the same alignment as the aligned pointer, so + // return the alignment value. + if (!DiffUnits) + return (unsigned) + cast(AlignSCEV)->getValue()->getSExtValue(); + + // If the displacement is not an exact multiple, but the remainder is a + // constant, then return this remainder (but only if it is a power of 2). + uint64_t DiffUnitsAbs = std::abs(DiffUnits); + if (isPowerOf2_64(DiffUnitsAbs)) + return (unsigned) DiffUnitsAbs; + } + + return 0; +} + +// There is an address given by an offset OffSCEV from AASCEV which has an +// alignment AlignSCEV. Use that information, if possible, to compute a new +// alignment for Ptr. +static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, + const SCEV *OffSCEV, Value *Ptr, + ScalarEvolution *SE) { + const SCEV *PtrSCEV = SE->getSCEV(Ptr); + const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV); + + // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always + // sign-extended OffSCEV to i64, so make sure they agree again. + DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType()); + + // What we really want to know is the overall offset to the aligned + // address. This address is displaced by the provided offset. + DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV); + + LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " + << *AlignSCEV << " and offset " << *OffSCEV + << " using diff " << *DiffSCEV << "\n"); + + unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE); + LLVM_DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n"); + + if (NewAlignment) { + return NewAlignment; + } else if (const SCEVAddRecExpr *DiffARSCEV = + dyn_cast(DiffSCEV)) { + // The relative offset to the alignment assumption did not yield a constant, + // but we should try harder: if we assume that a is 32-byte aligned, then in + // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are + // 32-byte aligned, but instead alternate between 32 and 16-byte alignment. + // As a result, the new alignment will not be a constant, but can still + // be improved over the default (of 4) to 16. + + const SCEV *DiffStartSCEV = DiffARSCEV->getStart(); + const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE); + + LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start " + << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n"); + + // Now compute the new alignment using the displacement to the value in the + // first iteration, and also the alignment using the per-iteration delta. + // If these are the same, then use that answer. Otherwise, use the smaller + // one, but only if it divides the larger one. + NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE); + unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE); + + LLVM_DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n"); + LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n"); + + if (!NewAlignment || !NewIncAlignment) { + return 0; + } else if (NewAlignment > NewIncAlignment) { + if (NewAlignment % NewIncAlignment == 0) { + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewIncAlignment + << "\n"); + return NewIncAlignment; + } + } else if (NewIncAlignment > NewAlignment) { + if (NewIncAlignment % NewAlignment == 0) { + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment + << "\n"); + return NewAlignment; + } + } else if (NewIncAlignment == NewAlignment) { + LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << NewAlignment + << "\n"); + return NewAlignment; + } + } + + return 0; +} + +bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I, + Value *&AAPtr, + const SCEV *&AlignSCEV, + const SCEV *&OffSCEV) { + // An alignment assume must be a statement about the least-significant + // bits of the pointer being zero, possibly with some offset. + ICmpInst *ICI = dyn_cast(I->getArgOperand(0)); + if (!ICI) + return false; + + // This must be an expression of the form: x & m == 0. + if (ICI->getPredicate() != ICmpInst::ICMP_EQ) + return false; + + // Swap things around so that the RHS is 0. + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS); + const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS); + if (CmpLHSSCEV->isZero()) + std::swap(CmpLHS, CmpRHS); + else if (!CmpRHSSCEV->isZero()) + return false; + + BinaryOperator *CmpBO = dyn_cast(CmpLHS); + if (!CmpBO || CmpBO->getOpcode() != Instruction::And) + return false; + + // Swap things around so that the right operand of the and is a constant + // (the mask); we cannot deal with variable masks. + Value *AndLHS = CmpBO->getOperand(0); + Value *AndRHS = CmpBO->getOperand(1); + const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS); + const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS); + if (isa(AndLHSSCEV)) { + std::swap(AndLHS, AndRHS); + std::swap(AndLHSSCEV, AndRHSSCEV); + } + + const SCEVConstant *MaskSCEV = dyn_cast(AndRHSSCEV); + if (!MaskSCEV) + return false; + + // The mask must have some trailing ones (otherwise the condition is + // trivial and tells us nothing about the alignment of the left operand). + unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes(); + if (!TrailingOnes) + return false; + + // Cap the alignment at the maximum with which LLVM can deal (and make sure + // we don't overflow the shift). + uint64_t Alignment; + TrailingOnes = std::min(TrailingOnes, + unsigned(sizeof(unsigned) * CHAR_BIT - 1)); + Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment); + + Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext()); + AlignSCEV = SE->getConstant(Int64Ty, Alignment); + + // The LHS might be a ptrtoint instruction, or it might be the pointer + // with an offset. + AAPtr = nullptr; + OffSCEV = nullptr; + if (PtrToIntInst *PToI = dyn_cast(AndLHS)) { + AAPtr = PToI->getPointerOperand(); + OffSCEV = SE->getZero(Int64Ty); + } else if (const SCEVAddExpr* AndLHSAddSCEV = + dyn_cast(AndLHSSCEV)) { + // Try to find the ptrtoint; subtract it and the rest is the offset. + for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(), + JE = AndLHSAddSCEV->op_end(); J != JE; ++J) + if (const SCEVUnknown *OpUnk = dyn_cast(*J)) + if (PtrToIntInst *PToI = dyn_cast(OpUnk->getValue())) { + AAPtr = PToI->getPointerOperand(); + OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J); + break; + } + } + + if (!AAPtr) + return false; + + // Sign extend the offset to 64 bits (so that it is like all of the other + // expressions). + unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits(); + if (OffSCEVBits < 64) + OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty); + else if (OffSCEVBits > 64) + return false; + + AAPtr = AAPtr->stripPointerCasts(); + return true; +} + +bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { + Value *AAPtr; + const SCEV *AlignSCEV, *OffSCEV; + if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV)) + return false; + + // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't + // affect other users. + if (isa(AAPtr)) + return false; + + const SCEV *AASCEV = SE->getSCEV(AAPtr); + + // Apply the assumption to all other users of the specified pointer. + SmallPtrSet Visited; + SmallVector WorkList; + for (User *J : AAPtr->users()) { + if (J == ACall) + continue; + + if (Instruction *K = dyn_cast(J)) + if (isValidAssumeForContext(ACall, K, DT)) + WorkList.push_back(K); + } + + while (!WorkList.empty()) { + Instruction *J = WorkList.pop_back_val(); + + if (LoadInst *LI = dyn_cast(J)) { + unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + LI->getPointerOperand(), SE); + + if (NewAlignment > LI->getAlignment()) { + LI->setAlignment(NewAlignment); + ++NumLoadAlignChanged; + } + } else if (StoreInst *SI = dyn_cast(J)) { + unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + SI->getPointerOperand(), SE); + + if (NewAlignment > SI->getAlignment()) { + SI->setAlignment(NewAlignment); + ++NumStoreAlignChanged; + } + } else if (MemIntrinsic *MI = dyn_cast(J)) { + unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + MI->getDest(), SE); + + LLVM_DEBUG(dbgs() << "\tmem inst: " << NewDestAlignment << "\n";); + if (NewDestAlignment > MI->getDestAlignment()) { + MI->setDestAlignment(NewDestAlignment); + ++NumMemIntAlignChanged; + } + + // For memory transfers, there is also a source alignment that + // can be set. + if (MemTransferInst *MTI = dyn_cast(MI)) { + unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + MTI->getSource(), SE); + + LLVM_DEBUG(dbgs() << "\tmem trans: " << NewSrcAlignment << "\n";); + + if (NewSrcAlignment > MTI->getSourceAlignment()) { + MTI->setSourceAlignment(NewSrcAlignment); + ++NumMemIntAlignChanged; + } + } + } + + // Now that we've updated that use of the pointer, look for other uses of + // the pointer to update. + Visited.insert(J); + for (User *UJ : J->users()) { + Instruction *K = cast(UJ); + if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT)) + WorkList.push_back(K); + } + } + + return true; +} + +bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, + ScalarEvolution *SE_, + DominatorTree *DT_) { + SE = SE_; + DT = DT_; + + bool Changed = false; + for (auto &AssumeVH : AC.assumptions()) + if (AssumeVH) + Changed |= processAssumption(cast(AssumeVH)); + + return Changed; +} + +} // end of anonymous namespace + +char JittorAlignmentFromAssumptions::ID = 0; +static RegisterPass X( + "jt-alignment-from-assumptions", + "Jittor Alignment From Assumptions", + false /* Only looks at CFG */, + false /* Analysis Pass */); + +static RegisterStandardPasses Y( + PassManagerBuilder::EP_OptimizerLast, + [](const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { PM.add(new JittorAlignmentFromAssumptions()); }); \ No newline at end of file diff --git a/python/jittor/extern/mkl/ops/cpu_cnn_inference_f32.cpp b/python/jittor/extern/mkl/ops/cpu_cnn_inference_f32.cpp new file mode 100644 index 00000000..4bf1cba4 --- /dev/null +++ b/python/jittor/extern/mkl/ops/cpu_cnn_inference_f32.cpp @@ -0,0 +1,800 @@ +/******************************************************************************* +* Copyright 2016-2019 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/// @example cpu_cnn_inference_f32.cpp +/// @copybrief cpu_cnn_inference_f32_cpp +/// > Annotated version: @ref cpu_cnn_inference_f32_cpp + +/// @page cpu_cnn_inference_f32_cpp CNN f32 inference example +/// This C++ API example demonstrates how to build an AlexNet neural +/// network topology for forward-pass inference. +/// +/// > Example code: @ref cpu_cnn_inference_f32.cpp +/// +/// Some key take-aways include: +/// +/// * How tensors are implemented and submitted to primitives. +/// * How primitives are created. +/// * How primitives are sequentially submitted to the network, where the output +/// from primitives is passed as input to the next primitive. The latter +/// specifies a dependency between the primitive input and output data. +/// * Specific 'inference-only' configurations. +/// * Limiting the number of reorders performed that are detrimental +/// to performance. +/// +/// The example implements the AlexNet layers +/// as numbered primitives (for example, conv1, pool1, conv2). + +#include + +#include +#include +#include +#include +#include +#include + +#include + +using namespace dnnl; + +using namespace std; + +memory::dim product(const memory::dims &dims) { + return std::accumulate(dims.begin(), dims.end(), (memory::dim)1, + std::multiplies()); +} + +void simple_net(int times = 100) { + using tag = memory::format_tag; + using dt = memory::data_type; + +/// Initialize a CPU engine and stream. The last parameter in the call represents +/// the index of the engine. +/// @snippet cpu_cnn_inference_f32.cpp Initialize engine and stream +//[Initialize engine and stream] + engine eng(engine::kind::cpu, 0); + stream s(eng); +//[Initialize engine and stream] + +/// Create a vector for the primitives and a vector to hold memory +/// that will be used as arguments. +/// @snippet cpu_cnn_inference_f32.cpp Create network +//[Create network] + std::vector net; + std::vector> net_args; +//[Create network] + + const memory::dim batch = 1; + + // AlexNet: conv1 + // {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55} + // strides: {4, 4} + memory::dims conv1_src_tz = { batch, 3, 227, 227 }; + memory::dims conv1_weights_tz = { 96, 3, 11, 11 }; + memory::dims conv1_bias_tz = { 96 }; + memory::dims conv1_dst_tz = { batch, 96, 55, 55 }; + memory::dims conv1_strides = { 4, 4 }; + memory::dims conv1_padding = { 0, 0 }; + +/// Allocate buffers for input and output data, weights, and bias. +/// @snippet cpu_cnn_inference_f32.cpp Allocate buffers +//[Allocate buffers] + std::vector user_src(batch * 3 * 227 * 227); + std::vector user_dst(batch * 1000); + std::vector conv1_weights(product(conv1_weights_tz)); + std::vector conv1_bias(product(conv1_bias_tz)); +//[Allocate buffers] + +/// Create memory that describes data layout in the buffers. This example uses +/// tag::nchw (batch-channels-height-width) for input data and tag::oihw +/// for weights. +/// @snippet cpu_cnn_inference_f32.cpp Create user memory +//[Create user memory] + auto user_src_memory = memory( + { { conv1_src_tz }, dt::f32, tag::nchw }, eng, user_src.data()); + auto user_weights_memory + = memory({ { conv1_weights_tz }, dt::f32, tag::oihw }, eng, + conv1_weights.data()); + auto conv1_user_bias_memory = memory( + { { conv1_bias_tz }, dt::f32, tag::x }, eng, conv1_bias.data()); +//[Create user memory] + +/// Create memory descriptors with layout tag::any. The `any` format enables +/// the convolution primitive to choose the data format that will result in +/// best performance based on its input parameters (convolution kernel +/// sizes, strides, padding, and so on). If the resulting format is different +/// from `nchw`, the user data must be transformed to the format required for +/// the convolution (as explained below). +/// @snippet cpu_cnn_inference_f32.cpp Create convolution memory descriptors +//[Create convolution memory descriptors] + auto conv1_src_md = memory::desc({ conv1_src_tz }, dt::f32, tag::any); + auto conv1_bias_md = memory::desc({ conv1_bias_tz }, dt::f32, tag::any); + auto conv1_weights_md + = memory::desc({ conv1_weights_tz }, dt::f32, tag::any); + auto conv1_dst_md = memory::desc({ conv1_dst_tz }, dt::f32, tag::any); +//[Create convolution memory descriptors] + +/// Create a convolution descriptor by specifying propagation kind, +/// [convolution algorithm](@ref dev_guide_convolution), shapes of input, +/// weights, bias, output, convolution strides, padding, and kind of padding. +/// Propagation kind is set to prop_kind::forward_inference to optimize for +/// inference execution and omit computations that are necessary only for +/// backward propagation. +/// @snippet cpu_cnn_inference_f32.cpp Create convolution descriptor +//[Create convolution descriptor] + auto conv1_desc = convolution_forward::desc(prop_kind::forward_inference, + algorithm::convolution_direct, conv1_src_md, conv1_weights_md, conv1_bias_md, + conv1_dst_md, conv1_strides, conv1_padding, conv1_padding); +//[Create convolution descriptor] + +/// Create a convolution primitive descriptor. Once created, this +/// descriptor has specific formats instead of the `any` format specified +/// in the convolution descriptor. +/// @snippet cpu_cnn_inference_f32.cpp Create convolution primitive descriptor +//[Create convolution primitive descriptor] + auto conv1_prim_desc = convolution_forward::primitive_desc(conv1_desc, eng); +//[Create convolution primitive descriptor] + + +/// Check whether data and weights formats required by convolution is different +/// from the user format. In case it is different change the layout using +/// reorder primitive. +/// @snippet cpu_cnn_inference_f32.cpp Reorder data and weights +//[Reorder data and weights] + auto conv1_src_memory = user_src_memory; + if (conv1_prim_desc.src_desc() != user_src_memory.get_desc()) { + conv1_src_memory = memory(conv1_prim_desc.src_desc(), eng); + net.push_back(reorder(user_src_memory, conv1_src_memory)); + net_args.push_back({ { DNNL_ARG_FROM, user_src_memory }, + { DNNL_ARG_TO, conv1_src_memory } }); + } + + auto conv1_weights_memory = user_weights_memory; + if (conv1_prim_desc.weights_desc() != user_weights_memory.get_desc()) { + conv1_weights_memory = memory(conv1_prim_desc.weights_desc(), eng); + reorder(user_weights_memory, conv1_weights_memory) + .execute(s, user_weights_memory, conv1_weights_memory); + } +//[Reorder data and weights] + +/// Create a memory primitive for output. +/// @snippet cpu_cnn_inference_f32.cpp Create memory for output +//[Create memory for output] + auto conv1_dst_memory = memory(conv1_prim_desc.dst_desc(), eng); +//[Create memory for output] + +/// Create a convolution primitive and add it to the net. +/// @snippet cpu_cnn_inference_f32.cpp Create memory for output +//[Create convolution primitive] + net.push_back(convolution_forward(conv1_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv1_src_memory }, + { DNNL_ARG_WEIGHTS, conv1_weights_memory }, + { DNNL_ARG_BIAS, conv1_user_bias_memory }, + { DNNL_ARG_DST, conv1_dst_memory } }); +//[Create convolution primitive] + + // AlexNet: relu1 + // {batch, 96, 55, 55} -> {batch, 96, 55, 55} + const float negative1_slope = 1.0f; + + +/// Create the relu primitive. For better performance, keep the input data +/// format for ReLU (as well as for other operation primitives until another +/// convolution or inner product is encountered) the same as the one chosen +/// for convolution. Also note that ReLU is done in-place by using conv1 memory. +/// @snippet cpu_cnn_inference_f32.cpp Create relu primitive +//[Create relu primitive] + auto relu1_desc = eltwise_forward::desc(prop_kind::forward_inference, + algorithm::eltwise_relu, conv1_dst_memory.get_desc(), + negative1_slope); + auto relu1_prim_desc = eltwise_forward::primitive_desc(relu1_desc, eng); + + net.push_back(eltwise_forward(relu1_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv1_dst_memory }, + { DNNL_ARG_DST, conv1_dst_memory } }); +//[Create relu primitive] + + // AlexNet: lrn1 + // {batch, 96, 55, 55} -> {batch, 96, 55, 55} + // local size: 5 + // alpha1: 0.0001 + // beta1: 0.75 + const memory::dim local1_size = 5; + const float alpha1 = 0.0001f; + const float beta1 = 0.75f; + const float k1 = 1.0f; + + // create lrn primitive and add it to net + auto lrn1_desc = lrn_forward::desc(prop_kind::forward_inference, + algorithm::lrn_across_channels, conv1_dst_memory.get_desc(), local1_size, + alpha1, beta1, k1); + auto lrn1_prim_desc = lrn_forward::primitive_desc(lrn1_desc, eng); + auto lrn1_dst_memory = memory(lrn1_prim_desc.dst_desc(), eng); + + net.push_back(lrn_forward(lrn1_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv1_dst_memory }, + { DNNL_ARG_DST, lrn1_dst_memory } }); + + // AlexNet: pool1 + // {batch, 96, 55, 55} -> {batch, 96, 27, 27} + // kernel: {3, 3} + // strides: {2, 2} + memory::dims pool1_dst_tz = { batch, 96, 27, 27 }; + memory::dims pool1_kernel = { 3, 3 }; + memory::dims pool1_strides = { 2, 2 }; + memory::dims pool_padding = { 0, 0 }; + + auto pool1_dst_md = memory::desc({ pool1_dst_tz }, dt::f32, tag::any); + +/// For training execution, pooling requires a private workspace memory +/// to perform the backward pass. However, pooling should not use 'workspace' +/// for inference, because this is detrimental to performance. +/// @snippet cpu_cnn_inference_f32.cpp Create pooling primitive +/// +/// The example continues to create more layers according +/// to the AlexNet topology. +//[Create pooling primitive] + auto pool1_desc = pooling_forward::desc(prop_kind::forward_inference, + algorithm::pooling_max, lrn1_dst_memory.get_desc(), pool1_dst_md, + pool1_strides, pool1_kernel, pool_padding, pool_padding); + auto pool1_pd = pooling_forward::primitive_desc(pool1_desc, eng); + auto pool1_dst_memory = memory(pool1_pd.dst_desc(), eng); + + net.push_back(pooling_forward(pool1_pd)); + net_args.push_back({ { DNNL_ARG_SRC, lrn1_dst_memory }, + { DNNL_ARG_DST, pool1_dst_memory } }); +//[Create pooling primitive] + + // AlexNet: conv2 + // {batch, 96, 27, 27} (x) {2, 128, 48, 5, 5} -> {batch, 256, 27, 27} + // strides: {1, 1} + memory::dims conv2_src_tz = { batch, 96, 27, 27 }; + memory::dims conv2_weights_tz = { 2, 128, 48, 5, 5 }; + memory::dims conv2_bias_tz = { 256 }; + memory::dims conv2_dst_tz = { batch, 256, 27, 27 }; + memory::dims conv2_strides = { 1, 1 }; + memory::dims conv2_padding = { 2, 2 }; + + std::vector conv2_weights(product(conv2_weights_tz)); + std::vector conv2_bias(product(conv2_bias_tz)); + + // create memory for user data + auto conv2_user_weights_memory + = memory({ { conv2_weights_tz }, dt::f32, tag::goihw }, eng, + conv2_weights.data()); + auto conv2_user_bias_memory = memory( + { { conv2_bias_tz }, dt::f32, tag::x }, eng, conv2_bias.data()); + + // create memory descriptors for convolution data w/ no specified format + auto conv2_src_md = memory::desc({ conv2_src_tz }, dt::f32, tag::any); + auto conv2_bias_md = memory::desc({ conv2_bias_tz }, dt::f32, tag::any); + auto conv2_weights_md + = memory::desc({ conv2_weights_tz }, dt::f32, tag::any); + auto conv2_dst_md = memory::desc({ conv2_dst_tz }, dt::f32, tag::any); + + // create a convolution + auto conv2_desc = convolution_forward::desc(prop_kind::forward_inference, + algorithm::convolution_direct, conv2_src_md, conv2_weights_md, conv2_bias_md, + conv2_dst_md, conv2_strides, conv2_padding, conv2_padding); + auto conv2_prim_desc = convolution_forward::primitive_desc(conv2_desc, eng); + + auto conv2_src_memory = pool1_dst_memory; + if (conv2_prim_desc.src_desc() != conv2_src_memory.get_desc()) { + conv2_src_memory = memory(conv2_prim_desc.src_desc(), eng); + net.push_back(reorder(pool1_dst_memory, conv2_src_memory)); + net_args.push_back({ { DNNL_ARG_FROM, pool1_dst_memory }, + { DNNL_ARG_TO, conv2_src_memory } }); + } + + auto conv2_weights_memory = conv2_user_weights_memory; + if (conv2_prim_desc.weights_desc() + != conv2_user_weights_memory.get_desc()) { + conv2_weights_memory = memory(conv2_prim_desc.weights_desc(), eng); + reorder(conv2_user_weights_memory, conv2_weights_memory) + .execute(s, conv2_user_weights_memory, conv2_weights_memory); + } + + auto conv2_dst_memory = memory(conv2_prim_desc.dst_desc(), eng); + + // create convolution primitive and add it to net + net.push_back(convolution_forward(conv2_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv2_src_memory }, + { DNNL_ARG_WEIGHTS, conv2_weights_memory }, + { DNNL_ARG_BIAS, conv2_user_bias_memory }, + { DNNL_ARG_DST, conv2_dst_memory } }); + + // AlexNet: relu2 + // {batch, 256, 27, 27} -> {batch, 256, 27, 27} + const float negative2_slope = 1.0f; + + // create relu primitive and add it to net + auto relu2_desc = eltwise_forward::desc(prop_kind::forward_inference, + algorithm::eltwise_relu, conv2_dst_memory.get_desc(), + negative2_slope); + auto relu2_prim_desc = eltwise_forward::primitive_desc(relu2_desc, eng); + + net.push_back(eltwise_forward(relu2_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv2_dst_memory }, + { DNNL_ARG_DST, conv2_dst_memory } }); + + // AlexNet: lrn2 + // {batch, 256, 27, 27} -> {batch, 256, 27, 27} + // local size: 5 + // alpha2: 0.0001 + // beta2: 0.75 + const memory::dim local2_size = 5; + const float alpha2 = 0.0001f; + const float beta2 = 0.75f; + const float k2 = 1.0f; + + // create lrn primitive and add it to net + auto lrn2_desc = lrn_forward::desc(prop_kind::forward_inference, + algorithm::lrn_across_channels, conv2_prim_desc.dst_desc(), local2_size, + alpha2, beta2, k2); + auto lrn2_prim_desc = lrn_forward::primitive_desc(lrn2_desc, eng); + auto lrn2_dst_memory = memory(lrn2_prim_desc.dst_desc(), eng); + + net.push_back(lrn_forward(lrn2_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv2_dst_memory }, + { DNNL_ARG_DST, lrn2_dst_memory } }); + + // AlexNet: pool2 + // {batch, 256, 27, 27} -> {batch, 256, 13, 13} + // kernel: {3, 3} + // strides: {2, 2} + memory::dims pool2_dst_tz = { batch, 256, 13, 13 }; + memory::dims pool2_kernel = { 3, 3 }; + memory::dims pool2_strides = { 2, 2 }; + memory::dims pool2_padding = { 0, 0 }; + + auto pool2_dst_md = memory::desc({ pool2_dst_tz }, dt::f32, tag::any); + + // create a pooling + auto pool2_desc = pooling_forward::desc(prop_kind::forward_inference, + algorithm::pooling_max, lrn2_dst_memory.get_desc(), pool2_dst_md, + pool2_strides, pool2_kernel, pool2_padding, pool2_padding); + auto pool2_pd = pooling_forward::primitive_desc(pool2_desc, eng); + auto pool2_dst_memory = memory(pool2_pd.dst_desc(), eng); + + // create pooling primitive an add it to net + net.push_back(pooling_forward(pool2_pd)); + net_args.push_back({ { DNNL_ARG_SRC, lrn2_dst_memory }, + { DNNL_ARG_DST, pool2_dst_memory } }); + + // AlexNet: conv3 + // {batch, 256, 13, 13} (x) {384, 256, 3, 3}; -> {batch, 384, 13, 13}; + // strides: {1, 1} + memory::dims conv3_src_tz = { batch, 256, 13, 13 }; + memory::dims conv3_weights_tz = { 384, 256, 3, 3 }; + memory::dims conv3_bias_tz = { 384 }; + memory::dims conv3_dst_tz = { batch, 384, 13, 13 }; + memory::dims conv3_strides = { 1, 1 }; + memory::dims conv3_padding = { 1, 1 }; + + std::vector conv3_weights(product(conv3_weights_tz)); + std::vector conv3_bias(product(conv3_bias_tz)); + + // create memory for user data + auto conv3_user_weights_memory + = memory({ { conv3_weights_tz }, dt::f32, tag::oihw }, eng, + conv3_weights.data()); + auto conv3_user_bias_memory = memory( + { { conv3_bias_tz }, dt::f32, tag::x }, eng, conv3_bias.data()); + + // create memory descriptors for convolution data w/ no specified format + auto conv3_src_md = memory::desc({ conv3_src_tz }, dt::f32, tag::any); + auto conv3_bias_md = memory::desc({ conv3_bias_tz }, dt::f32, tag::any); + auto conv3_weights_md + = memory::desc({ conv3_weights_tz }, dt::f32, tag::any); + auto conv3_dst_md = memory::desc({ conv3_dst_tz }, dt::f32, tag::any); + + // create a convolution + auto conv3_desc = convolution_forward::desc(prop_kind::forward_inference, + algorithm::convolution_direct, conv3_src_md, conv3_weights_md, conv3_bias_md, + conv3_dst_md, conv3_strides, conv3_padding, conv3_padding); + auto conv3_prim_desc = convolution_forward::primitive_desc(conv3_desc, eng); + + auto conv3_src_memory = pool2_dst_memory; + if (conv3_prim_desc.src_desc() != conv3_src_memory.get_desc()) { + conv3_src_memory = memory(conv3_prim_desc.src_desc(), eng); + net.push_back(reorder(pool2_dst_memory, conv3_src_memory)); + net_args.push_back({ { DNNL_ARG_FROM, pool2_dst_memory }, + { DNNL_ARG_TO, conv3_src_memory } }); + } + + auto conv3_weights_memory = conv3_user_weights_memory; + if (conv3_prim_desc.weights_desc() + != conv3_user_weights_memory.get_desc()) { + conv3_weights_memory = memory(conv3_prim_desc.weights_desc(), eng); + reorder(conv3_user_weights_memory, conv3_weights_memory) + .execute(s, conv3_user_weights_memory, conv3_weights_memory); + } + + auto conv3_dst_memory = memory(conv3_prim_desc.dst_desc(), eng); + + // create convolution primitive and add it to net + net.push_back(convolution_forward(conv3_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv3_src_memory }, + { DNNL_ARG_WEIGHTS, conv3_weights_memory }, + { DNNL_ARG_BIAS, conv3_user_bias_memory }, + { DNNL_ARG_DST, conv3_dst_memory } }); + + // AlexNet: relu3 + // {batch, 384, 13, 13} -> {batch, 384, 13, 13} + const float negative3_slope = 1.0f; + + // create relu primitive and add it to net + auto relu3_desc = eltwise_forward::desc(prop_kind::forward_inference, + algorithm::eltwise_relu, conv3_dst_memory.get_desc(), + negative3_slope); + auto relu3_prim_desc = eltwise_forward::primitive_desc(relu3_desc, eng); + + net.push_back(eltwise_forward(relu3_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv3_dst_memory }, + { DNNL_ARG_DST, conv3_dst_memory } }); + + // AlexNet: conv4 + // {batch, 384, 13, 13} (x) {2, 192, 192, 3, 3}; -> + // {batch, 384, 13, 13}; + // strides: {1, 1} + memory::dims conv4_src_tz = { batch, 384, 13, 13 }; + memory::dims conv4_weights_tz = { 2, 192, 192, 3, 3 }; + memory::dims conv4_bias_tz = { 384 }; + memory::dims conv4_dst_tz = { batch, 384, 13, 13 }; + memory::dims conv4_strides = { 1, 1 }; + memory::dims conv4_padding = { 1, 1 }; + + std::vector conv4_weights(product(conv4_weights_tz)); + std::vector conv4_bias(product(conv4_bias_tz)); + + // create memory for user data + auto conv4_user_weights_memory + = memory({ { conv4_weights_tz }, dt::f32, tag::goihw }, eng, + conv4_weights.data()); + auto conv4_user_bias_memory = memory( + { { conv4_bias_tz }, dt::f32, tag::x }, eng, conv4_bias.data()); + + // create memory descriptors for convolution data w/ no specified format + auto conv4_src_md = memory::desc({ conv4_src_tz }, dt::f32, tag::any); + auto conv4_bias_md = memory::desc({ conv4_bias_tz }, dt::f32, tag::any); + auto conv4_weights_md + = memory::desc({ conv4_weights_tz }, dt::f32, tag::any); + auto conv4_dst_md = memory::desc({ conv4_dst_tz }, dt::f32, tag::any); + + // create a convolution + auto conv4_desc = convolution_forward::desc(prop_kind::forward_inference, + algorithm::convolution_direct, conv4_src_md, conv4_weights_md, conv4_bias_md, + conv4_dst_md, conv4_strides, conv4_padding, conv4_padding); + auto conv4_prim_desc = convolution_forward::primitive_desc(conv4_desc, eng); + + auto conv4_src_memory = conv3_dst_memory; + if (conv4_prim_desc.src_desc() != conv4_src_memory.get_desc()) { + conv4_src_memory = memory(conv4_prim_desc.src_desc(), eng); + net.push_back(reorder(conv3_dst_memory, conv4_src_memory)); + net_args.push_back({ { DNNL_ARG_FROM, conv3_dst_memory }, + { DNNL_ARG_TO, conv4_src_memory } }); + } + + auto conv4_weights_memory = conv4_user_weights_memory; + if (conv4_prim_desc.weights_desc() + != conv4_user_weights_memory.get_desc()) { + conv4_weights_memory = memory(conv4_prim_desc.weights_desc(), eng); + reorder(conv4_user_weights_memory, conv4_weights_memory) + .execute(s, conv4_user_weights_memory, conv4_weights_memory); + } + + auto conv4_dst_memory = memory(conv4_prim_desc.dst_desc(), eng); + + // create convolution primitive and add it to net + net.push_back(convolution_forward(conv4_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv4_src_memory }, + { DNNL_ARG_WEIGHTS, conv4_weights_memory }, + { DNNL_ARG_BIAS, conv4_user_bias_memory }, + { DNNL_ARG_DST, conv4_dst_memory } }); + + // AlexNet: relu4 + // {batch, 384, 13, 13} -> {batch, 384, 13, 13} + const float negative4_slope = 1.0f; + + // create relu primitive and add it to net + auto relu4_desc = eltwise_forward::desc(prop_kind::forward_inference, + algorithm::eltwise_relu, conv4_dst_memory.get_desc(), + negative4_slope); + auto relu4_prim_desc = eltwise_forward::primitive_desc(relu4_desc, eng); + + net.push_back(eltwise_forward(relu4_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv4_dst_memory }, + { DNNL_ARG_DST, conv4_dst_memory } }); + + // AlexNet: conv5 + // {batch, 384, 13, 13} (x) {2, 128, 192, 3, 3}; -> {batch, 256, 13, 13}; + // strides: {1, 1} + memory::dims conv5_src_tz = { batch, 384, 13, 13 }; + memory::dims conv5_weights_tz = { 2, 128, 192, 3, 3 }; + memory::dims conv5_bias_tz = { 256 }; + memory::dims conv5_dst_tz = { batch, 256, 13, 13 }; + memory::dims conv5_strides = { 1, 1 }; + memory::dims conv5_padding = { 1, 1 }; + + std::vector conv5_weights(product(conv5_weights_tz)); + std::vector conv5_bias(product(conv5_bias_tz)); + + // create memory for user data + auto conv5_user_weights_memory + = memory({ { conv5_weights_tz }, dt::f32, tag::goihw }, eng, + conv5_weights.data()); + auto conv5_user_bias_memory = memory( + { { conv5_bias_tz }, dt::f32, tag::x }, eng, conv5_bias.data()); + + // create memory descriptors for convolution data w/ no specified format + auto conv5_src_md = memory::desc({ conv5_src_tz }, dt::f32, tag::any); + auto conv5_weights_md + = memory::desc({ conv5_weights_tz }, dt::f32, tag::any); + auto conv5_bias_md = memory::desc({ conv5_bias_tz }, dt::f32, tag::any); + auto conv5_dst_md = memory::desc({ conv5_dst_tz }, dt::f32, tag::any); + + // create a convolution + auto conv5_desc = convolution_forward::desc(prop_kind::forward_inference, + algorithm::convolution_direct, conv5_src_md, conv5_weights_md, conv5_bias_md, + conv5_dst_md, conv5_strides, conv5_padding, conv5_padding); + auto conv5_prim_desc = convolution_forward::primitive_desc(conv5_desc, eng); + + auto conv5_src_memory = conv4_dst_memory; + if (conv5_prim_desc.src_desc() != conv5_src_memory.get_desc()) { + conv5_src_memory = memory(conv5_prim_desc.src_desc(), eng); + net.push_back(reorder(conv4_dst_memory, conv5_src_memory)); + net_args.push_back({ { DNNL_ARG_FROM, conv4_dst_memory }, + { DNNL_ARG_TO, conv5_src_memory } }); + } + + auto conv5_weights_memory = conv5_user_weights_memory; + if (conv5_prim_desc.weights_desc() + != conv5_user_weights_memory.get_desc()) { + conv5_weights_memory = memory(conv5_prim_desc.weights_desc(), eng); + reorder(conv5_user_weights_memory, conv5_weights_memory) + .execute(s, conv5_user_weights_memory, conv5_weights_memory); + } + + auto conv5_dst_memory = memory(conv5_prim_desc.dst_desc(), eng); + + // create convolution primitive and add it to net + net.push_back(convolution_forward(conv5_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv5_src_memory }, + { DNNL_ARG_WEIGHTS, conv5_weights_memory }, + { DNNL_ARG_BIAS, conv5_user_bias_memory }, + { DNNL_ARG_DST, conv5_dst_memory } }); + + // AlexNet: relu5 + // {batch, 256, 13, 13} -> {batch, 256, 13, 13} + const float negative5_slope = 1.0f; + + // create relu primitive and add it to net + auto relu5_desc = eltwise_forward::desc(prop_kind::forward_inference, + algorithm::eltwise_relu, conv5_dst_memory.get_desc(), + negative5_slope); + auto relu5_prim_desc = eltwise_forward::primitive_desc(relu5_desc, eng); + + net.push_back(eltwise_forward(relu5_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv5_dst_memory }, + { DNNL_ARG_DST, conv5_dst_memory } }); + + // AlexNet: pool5 + // {batch, 256, 13, 13} -> {batch, 256, 6, 6} + // kernel: {3, 3} + // strides: {2, 2} + memory::dims pool5_dst_tz = { batch, 256, 6, 6 }; + memory::dims pool5_kernel = { 3, 3 }; + memory::dims pool5_strides = { 2, 2 }; + memory::dims pool5_padding = { 0, 0 }; + + std::vector pool5_dst(product(pool5_dst_tz)); + + auto pool5_dst_md = memory::desc({ pool5_dst_tz }, dt::f32, tag::any); + + // create a pooling + auto pool5_desc = pooling_forward::desc(prop_kind::forward_inference, + algorithm::pooling_max, conv5_dst_memory.get_desc(), pool5_dst_md, + pool5_strides, pool5_kernel, pool5_padding, pool5_padding); + auto pool5_pd = pooling_forward::primitive_desc(pool5_desc, eng); + + auto pool5_dst_memory = memory(pool5_pd.dst_desc(), eng); + + // create pooling primitive an add it to net + net.push_back(pooling_forward(pool5_pd)); + net_args.push_back({ { DNNL_ARG_SRC, conv5_dst_memory }, + { DNNL_ARG_DST, pool5_dst_memory } }); + + + // fc6 inner product {batch, 256, 6, 6} (x) {4096, 256, 6, 6}-> {batch, + // 4096} + memory::dims fc6_src_tz = { batch, 256, 6, 6 }; + memory::dims fc6_weights_tz = { 4096, 256, 6, 6 }; + memory::dims fc6_bias_tz = { 4096 }; + memory::dims fc6_dst_tz = { batch, 4096 }; + + std::vector fc6_weights(product(fc6_weights_tz)); + std::vector fc6_bias(product(fc6_bias_tz)); + + // create memory for user data + auto fc6_user_weights_memory + = memory({ { fc6_weights_tz }, dt::f32, tag::oihw }, eng, + fc6_weights.data()); + auto fc6_user_bias_memory = memory( + { { fc6_bias_tz }, dt::f32, tag::x }, eng, fc6_bias.data()); + + // create memory descriptors for convolution data w/ no specified format + auto fc6_src_md = memory::desc({ fc6_src_tz }, dt::f32, tag::any); + auto fc6_bias_md = memory::desc({ fc6_bias_tz }, dt::f32, tag::any); + auto fc6_weights_md = memory::desc({ fc6_weights_tz }, dt::f32, tag::any); + auto fc6_dst_md = memory::desc({ fc6_dst_tz }, dt::f32, tag::any); + + // create a inner_product + auto fc6_desc = inner_product_forward::desc(prop_kind::forward_inference, + fc6_src_md, fc6_weights_md, fc6_bias_md, fc6_dst_md); + auto fc6_prim_desc = inner_product_forward::primitive_desc(fc6_desc, eng); + + auto fc6_src_memory = pool5_dst_memory; + if (fc6_prim_desc.src_desc() != fc6_src_memory.get_desc()) { + fc6_src_memory = memory(fc6_prim_desc.src_desc(), eng); + net.push_back(reorder(pool5_dst_memory, fc6_src_memory)); + net_args.push_back({ { DNNL_ARG_FROM, pool5_dst_memory }, + { DNNL_ARG_TO, fc6_src_memory } }); + } + + auto fc6_weights_memory = fc6_user_weights_memory; + if (fc6_prim_desc.weights_desc() != fc6_user_weights_memory.get_desc()) { + fc6_weights_memory = memory(fc6_prim_desc.weights_desc(), eng); + reorder(fc6_user_weights_memory, fc6_weights_memory) + .execute(s, fc6_user_weights_memory, fc6_weights_memory); + } + + auto fc6_dst_memory = memory(fc6_prim_desc.dst_desc(), eng); + + // create convolution primitive and add it to net + net.push_back(inner_product_forward(fc6_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, fc6_src_memory }, + { DNNL_ARG_WEIGHTS, fc6_weights_memory }, + { DNNL_ARG_BIAS, fc6_user_bias_memory }, + { DNNL_ARG_DST, fc6_dst_memory } }); + + + // fc7 inner product {batch, 4096} (x) {4096, 4096}-> {batch, 4096} + memory::dims fc7_weights_tz = { 4096, 4096 }; + memory::dims fc7_bias_tz = { 4096 }; + memory::dims fc7_dst_tz = { batch, 4096 }; + + std::vector fc7_weights(product(fc7_weights_tz)); + std::vector fc7_bias(product(fc7_bias_tz)); + + // create memory for user data + auto fc7_user_weights_memory = memory( + { { fc7_weights_tz }, dt::f32, tag::nc }, eng, fc7_weights.data()); + + auto fc7_user_bias_memory = memory( + { { fc7_bias_tz }, dt::f32, tag::x }, eng, fc7_bias.data()); + + // create memory descriptors for convolution data w/ no specified format + auto fc7_bias_md = memory::desc({ fc7_bias_tz }, dt::f32, tag::any); + auto fc7_weights_md = memory::desc({ fc7_weights_tz }, dt::f32, tag::any); + auto fc7_dst_md = memory::desc({ fc7_dst_tz }, dt::f32, tag::any); + + // create a inner_product + auto fc7_desc = inner_product_forward::desc(prop_kind::forward_inference, + fc6_dst_memory.get_desc(), fc7_weights_md, fc7_bias_md, fc7_dst_md); + auto fc7_prim_desc = inner_product_forward::primitive_desc(fc7_desc, eng); + + auto fc7_weights_memory = fc7_user_weights_memory; + if (fc7_prim_desc.weights_desc() != fc7_user_weights_memory.get_desc()) { + fc7_weights_memory = memory(fc7_prim_desc.weights_desc(), eng); + reorder(fc7_user_weights_memory, fc7_weights_memory) + .execute(s, fc7_user_weights_memory, fc7_weights_memory); + } + + auto fc7_dst_memory = memory(fc7_prim_desc.dst_desc(), eng); + + // create convolution primitive and add it to net + net.push_back(inner_product_forward(fc7_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, fc6_dst_memory }, + { DNNL_ARG_WEIGHTS, fc7_weights_memory }, + { DNNL_ARG_BIAS, fc7_user_bias_memory }, + { DNNL_ARG_DST, fc7_dst_memory } }); + + // fc8 inner product {batch, 4096} (x) {1000, 4096}-> {batch, 1000} + memory::dims fc8_weights_tz = { 1000, 4096 }; + memory::dims fc8_bias_tz = { 1000 }; + memory::dims fc8_dst_tz = { batch, 1000 }; + + std::vector fc8_weights(product(fc8_weights_tz)); + std::vector fc8_bias(product(fc8_bias_tz)); + + // create memory for user data + auto fc8_user_weights_memory = memory( + { { fc8_weights_tz }, dt::f32, tag::nc }, eng, fc8_weights.data()); + auto fc8_user_bias_memory = memory( + { { fc8_bias_tz }, dt::f32, tag::x }, eng, fc8_bias.data()); + auto user_dst_memory = memory( + { { fc8_dst_tz }, dt::f32, tag::nc }, eng, user_dst.data()); + + // create memory descriptors for convolution data w/ no specified format + auto fc8_bias_md = memory::desc({ fc8_bias_tz }, dt::f32, tag::any); + auto fc8_weights_md = memory::desc({ fc8_weights_tz }, dt::f32, tag::any); + auto fc8_dst_md = memory::desc({ fc8_dst_tz }, dt::f32, tag::any); + + // create a inner_product + auto fc8_desc = inner_product_forward::desc(prop_kind::forward_inference, + fc7_dst_memory.get_desc(), fc8_weights_md, fc8_bias_md, fc8_dst_md); + auto fc8_prim_desc = inner_product_forward::primitive_desc(fc8_desc, eng); + + auto fc8_weights_memory = fc8_user_weights_memory; + if (fc8_prim_desc.weights_desc() != fc8_user_weights_memory.get_desc()) { + fc8_weights_memory = memory(fc8_prim_desc.weights_desc(), eng); + reorder(fc8_user_weights_memory, fc8_weights_memory) + .execute(s, fc8_user_weights_memory, fc8_weights_memory); + } + + auto fc8_dst_memory = memory(fc8_prim_desc.dst_desc(), eng); + + // create convolution primitive and add it to net + net.push_back(inner_product_forward(fc8_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, fc7_dst_memory }, + { DNNL_ARG_WEIGHTS, fc8_weights_memory }, + { DNNL_ARG_BIAS, fc8_user_bias_memory }, + { DNNL_ARG_DST, fc8_dst_memory } }); + + // create reorder between internal and user data if it is needed and + // add it to net after pooling + if (fc8_dst_memory != user_dst_memory) { + net.push_back(reorder(fc8_dst_memory, user_dst_memory)); + net_args.push_back({ { DNNL_ARG_FROM, fc8_dst_memory }, + { DNNL_ARG_TO, user_dst_memory } }); + } + +/// @page cpu_cnn_inference_f32_cpp +/// Finally, execute the primitives. For this example, the net is executed +/// multiple times and each execution is timed individually. +/// @snippet cpu_cnn_inference_f32.cpp Execute model +//[Execute model] + for (int j = 0; j < times; ++j) { + assert(net.size() == net_args.size() && "something is missing"); + for (size_t i = 0; i < net.size(); ++i) + net.at(i).execute(s, net_args.at(i)); + } +//[Execute model] + + s.wait(); +} + +// extern int mkl_test_entry(); + +int mkl_test_entry() { + try { + auto begin = chrono::duration_cast( + chrono::steady_clock::now().time_since_epoch()) + .count(); + int times = 100; + simple_net(times); + auto end = chrono::duration_cast( + chrono::steady_clock::now().time_since_epoch()) + .count(); + cout << "Use time " << (end - begin) / (times + 0.0) << "\n"; + } catch (error &e) { + std::cerr << "status: " << e.status << std::endl; + std::cerr << "message: " << e.message << std::endl; + return 1; + } + return 0; +} diff --git a/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.cc b/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.cc new file mode 100644 index 00000000..3996662f --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.cc @@ -0,0 +1,215 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "mkl_conv_backward_w_op.h" + +#include + +using namespace dnnl; +using namespace std; + +namespace jittor { +static inline int findc(const string& format, const char& c) { + if (c==format[0]) return 0; + if (c==format[1]) return 1; + if (c==format[2]) return 2; + ASSERT(c==format[3]) << "Not a valid format" << format << c; + return 3; +} + +#ifndef JIT +static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) { + auto& shape = x->shape; + a = shape[findc(format, f[0])]; + b = shape[findc(format, f[1])]; + c = shape[findc(format, f[2])]; + d = shape[findc(format, f[3])]; +} + +static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) { + int64 shape[4]; + shape[findc(format, f[0])] = a; + shape[findc(format, f[1])] = b; + shape[findc(format, f[2])] = c; + shape[findc(format, f[3])] = d; + x->set_shape(NanoVector( + shape[0], shape[1], shape[2], shape[3])); +} + +MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), dy(dy), kh(kh), kw(kw), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { + dw = create_output(nullptr, dtype_infer(dy->ns, x->ns)); +} + +void MklConvBackwardWOp::infer_shape() { + ASSERTop(x->shape.size(),==,4); + ASSERTop(dy->shape.size(),==,4); + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + get_shape(x, "abcd", xformat, xn, xc, xh, xw); + get_shape(dy, "abcd", yformat, yn, yc, yh, yw); + wco = yc, wci = xc / groups; + wh = kh; + ww = kw; + set_shape(dw, "oihw", wformat, wco, wci, wh, ww); +} + +static const char* short_type(Var* x) { + if (x->is_float()) { + if (x->dsize()==4) return "f32"; + if (x->dsize()==8) return "f64"; + if (x->dsize()==2) return "f16"; + return "f8"; + } else { + if (x->dsize()==4) return "s32"; + if (x->dsize()==8) return "s64"; + if (x->dsize()==2) return "s16"; + return "s8"; + } +} + +void MklConvBackwardWOp::jit_prepare(JK& jk) { + jk << "«Txd:" << x->dtype(); + jk << "«Tyd:" << dy->dtype(); + jk << "«Twd:" << dw->dtype(); + jk << "«Tx:" << short_type(x); + jk << "«Tw:" << short_type(dw); + jk << "«Ty:" << short_type(dy); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; +} + +#else // JIT +#ifdef JIT_cpu +void MklConvBackwardWOp::jit_run() { + int batch = x->shape[findc("@XFORMAT",'a')]; + int ch_in = x->shape[findc("@XFORMAT",'b')]; + int height = x->shape[findc("@XFORMAT",'c')]; + int width = x->shape[findc("@XFORMAT",'d')]; + int ch_out = dw->shape[findc("@WFORMAT",'o')]; + int kh = dw->shape[findc("@WFORMAT",'h')]; + int kw = dw->shape[findc("@WFORMAT",'w')]; + + auto* __restrict__ net_src = x->ptr(); + auto* __restrict__ net_diff_dst = dy->ptr(); + auto* __restrict__ conv_user_diff_weights_buffer = dw->ptr(); + + using tag = memory::format_tag; + using dt = memory::data_type; + + auto eng = engine(engine::kind::cpu, 0); + stream s(eng); + + std::vector net_bwd; + std::vector> net_bwd_args; + + memory::dims conv_src_tz = {batch, ch_in, height, width}; + memory::dims conv_weights_tz = groups>1 + ? memory::dims{groups, ch_out/groups, ch_in/groups, kh, kw} + : memory::dims{ch_out, ch_in, kh, kw}; + memory::dims conv_dst_tz = {batch, ch_out, (height+paddingh*2-kh*dilationh+dilationh-1)/strideh+1, (width+paddingw*2-kw*dilationw+dilationw-1)/stridew+1}; + memory::dims conv_strides = {strideh, stridew}; + memory::dims conv_padding = {paddingh, paddingw}; + memory::dims conv_dilation = {dilationh-1, dilationw-1}; + + if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw); + + auto conv_user_src_memory + = memory({{conv_src_tz}, dt::@Tx, tag::@XFORMAT}, eng, net_src); + + auto conv_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any); + auto conv_weights_md = memory::desc({conv_weights_tz}, dt::@Tw, tag::any); + auto conv_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any); + + auto conv_desc = convolution_forward::desc(prop_kind::forward, + algorithm::convolution_direct, conv_src_md, conv_weights_md, + conv_dst_md, conv_strides, conv_dilation, conv_padding, + conv_padding); + auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng); + + auto conv_src_memory = conv_user_src_memory; + if (conv_pd.src_desc() != conv_user_src_memory.get_desc()) { + conv_src_memory = memory(conv_pd.src_desc(), eng); + net_bwd.push_back(reorder(conv_user_src_memory, conv_src_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_src_memory}, + {DNNL_ARG_TO, conv_src_memory}}); + } + + auto conv_user_diff_dst_memory + = memory({{conv_dst_tz}, dt::@Ty, tag::YFORMAT}, eng, net_diff_dst); + + auto conv_user_diff_weights_memory + = memory({{conv_weights_tz}, dt::@Tw, groups>1 ? tag::goihw : tag::@WFORMAT}, eng, conv_user_diff_weights_buffer); + + auto conv_bwd_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any); + auto conv_diff_weights_md + = memory::desc({conv_weights_tz}, dt::@Tw, tag::any); + auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any); + + auto conv_bwd_weights_desc + = convolution_backward_weights::desc(algorithm::convolution_direct, + conv_bwd_src_md, conv_diff_weights_md, + conv_diff_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding); + auto conv_bwd_weights_pd = convolution_backward_weights::primitive_desc( + conv_bwd_weights_desc, eng, conv_pd); + + auto conv_bwd_src_memory = conv_src_memory; + if (conv_bwd_weights_pd.src_desc() != conv_src_memory.get_desc()) { + conv_bwd_src_memory = memory(conv_bwd_weights_pd.src_desc(), eng); + net_bwd.push_back(reorder(conv_src_memory, conv_bwd_src_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_src_memory}, + {DNNL_ARG_TO, conv_bwd_src_memory}}); + } + + auto conv_diff_dst_memory = conv_user_diff_dst_memory; + if (conv_bwd_weights_pd.diff_dst_desc() + != conv_user_diff_dst_memory.get_desc()) { + conv_diff_dst_memory = memory(conv_bwd_weights_pd.diff_dst_desc(), eng); + net_bwd.push_back(reorder(conv_user_diff_dst_memory, conv_diff_dst_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_diff_dst_memory}, + {DNNL_ARG_TO, conv_diff_dst_memory}}); + } + + net_bwd.push_back(convolution_backward_weights(conv_bwd_weights_pd)); + net_bwd_args.push_back({{DNNL_ARG_SRC, conv_bwd_src_memory}, + {DNNL_ARG_DIFF_DST, conv_diff_dst_memory}}); + + auto conv_diff_weights_memory = conv_user_diff_weights_memory; + if (conv_bwd_weights_pd.diff_weights_desc() + != conv_user_diff_weights_memory.get_desc()) { + conv_diff_weights_memory + = memory(conv_bwd_weights_pd.diff_weights_desc(), eng); + net_bwd_args.back().insert( + {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory}); + + net_bwd.push_back(reorder( + conv_diff_weights_memory, conv_user_diff_weights_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_diff_weights_memory}, + {DNNL_ARG_TO, conv_user_diff_weights_memory}}); + } else { + net_bwd_args.back().insert( + {DNNL_ARG_DIFF_WEIGHTS, conv_diff_weights_memory}); + } + + ASSERTop(net_bwd.size(),==,net_bwd_args.size()); + + for (size_t i = 0; i < net_bwd.size(); ++i) + net_bwd.at(i).execute(s, net_bwd_args.at(i)); + + s.wait(); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.h b/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.h new file mode 100644 index 00000000..1b18cc56 --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_conv_backward_w_op.h @@ -0,0 +1,27 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MklConvBackwardWOp : Op { + Var* x, * dy, * dw; + int kh, kw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + string xformat, wformat, yformat; + + MklConvBackwardWOp(Var* x, Var* y, int kh, int kw, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + + const char* name() const override { return "mkl_conv_backward_w"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.cc b/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.cc new file mode 100644 index 00000000..5ddee140 --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.cc @@ -0,0 +1,211 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "mkl_conv_backward_x_op.h" + +#include + +using namespace dnnl; +using namespace std; + +namespace jittor { +static inline int findc(const string& format, const char& c) { + if (c==format[0]) return 0; + if (c==format[1]) return 1; + if (c==format[2]) return 2; + ASSERT(c==format[3]) << "Not a valid format" << format << c; + return 3; +} + +#ifndef JIT +static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) { + auto& shape = x->shape; + a = shape[findc(format, f[0])]; + b = shape[findc(format, f[1])]; + c = shape[findc(format, f[2])]; + d = shape[findc(format, f[3])]; +} + +static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) { + int64 shape[4]; + shape[findc(format, f[0])] = a; + shape[findc(format, f[1])] = b; + shape[findc(format, f[2])] = c; + shape[findc(format, f[3])] = d; + x->set_shape(NanoVector( + shape[0], shape[1], shape[2], shape[3])); +} + +MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : w(w), dy(dy), xh(height), xw(width), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { + dx = create_output(nullptr, dtype_infer(dy->ns, w->ns)); +} + +void MklConvBackwardXOp::infer_shape() { + ASSERTop(w->shape.size(),==,4); + ASSERTop(dy->shape.size(),==,4); + int xn, xc, wh, ww, wci, wco, yn, yc, yh, yw; + get_shape(w, "oihw", wformat, wco, wci, wh, ww); + get_shape(dy, "abcd", yformat, yn, yc, yh, yw); + xn = yn, xc = wci * groups; + set_shape(dx, "abcd", xformat, xn, xc, xh, xw); +} + +static const char* short_type(Var* x) { + if (x->is_float()) { + if (x->dsize()==4) return "f32"; + if (x->dsize()==8) return "f64"; + if (x->dsize()==2) return "f16"; + return "f8"; + } else { + if (x->dsize()==4) return "s32"; + if (x->dsize()==8) return "s64"; + if (x->dsize()==2) return "s16"; + return "s8"; + } +} + +void MklConvBackwardXOp::jit_prepare(JK& jk) { + jk << "«Tyd:" << dy->dtype(); + jk << "«Twd:" << w->dtype(); + jk << "«Txd:" << dx->dtype(); + jk << "«Tx:" << short_type(dx); + jk << "«Tw:" << short_type(w); + jk << "«Ty:" << short_type(dy); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; +} + +#else // JIT +#ifdef JIT_cpu +void MklConvBackwardXOp::jit_run() { + int batch = dx->shape[findc("@XFORMAT",'a')]; + int ch_in = dx->shape[findc("@XFORMAT",'b')]; + int height = dx->shape[findc("@XFORMAT",'c')]; + int width = dx->shape[findc("@XFORMAT",'d')]; + int ch_out = w->shape[findc("@WFORMAT",'o')]; + int kernel_sizeh = w->shape[findc("@WFORMAT",'h')]; + int kernel_sizew = w->shape[findc("@WFORMAT",'w')]; + + auto* __restrict__ conv_weights = w->ptr(); + auto* __restrict__ net_diff_dst = dy->ptr(); + auto* __restrict__ conv_user_diff_src_buffer = dx->ptr(); + + using tag = memory::format_tag; + using dt = memory::data_type; + + auto eng = engine(engine::kind::cpu, 0); + stream s(eng); + + std::vector net_bwd; + std::vector> net_bwd_args; + + memory::dims conv_src_tz = {batch, ch_in, height, width}; + memory::dims conv_weights_tz = groups>1 + ? memory::dims{groups, ch_out/groups, ch_in/groups, kernel_sizeh, kernel_sizew} + : memory::dims{ch_out, ch_in, kernel_sizeh, kernel_sizew}; + memory::dims conv_dst_tz = {batch, ch_out, (height+paddingh*2-kernel_sizeh*dilationh+dilationh-1)/strideh+1, (width+paddingw*2-kernel_sizew*dilationw+dilationw-1)/stridew+1}; + memory::dims conv_strides = {strideh, stridew}; + memory::dims conv_padding = {paddingh, paddingw}; + memory::dims conv_dilation = {dilationh-1, dilationw-1}; + + if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw); + + auto conv_user_weights_memory + = memory({{conv_weights_tz}, dt::@Tw, groups>1 ? tag::goihw : tag::@WFORMAT}, eng, conv_weights); + + auto conv_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any); + auto conv_weights_md = memory::desc({conv_weights_tz}, dt::@Tw, tag::any); + auto conv_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any); + + auto conv_desc = convolution_forward::desc(prop_kind::forward, + algorithm::convolution_direct, conv_src_md, conv_weights_md, + conv_dst_md, conv_strides, conv_dilation, conv_padding, + conv_padding); + auto conv_pd = convolution_forward::primitive_desc(conv_desc, eng); + + auto conv_weights_memory = conv_user_weights_memory; + if (conv_pd.weights_desc() != conv_user_weights_memory.get_desc()) { + conv_weights_memory = memory(conv_pd.weights_desc(), eng); + net_bwd.push_back( + reorder(conv_user_weights_memory, conv_weights_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_weights_memory}, + {DNNL_ARG_TO, conv_weights_memory}}); + } + + auto conv_user_diff_dst_memory + = memory({{conv_dst_tz}, dt::@Ty, tag::@YFORMAT}, eng, net_diff_dst); + + auto conv_user_diff_src_memory + = memory({{conv_src_tz}, dt::@Tx, tag::@XFORMAT}, eng, conv_user_diff_src_buffer); + + auto conv_bwd_weights_md + = memory::desc({conv_weights_tz}, dt::@Tw, tag::any); + auto conv_diff_src_md = memory::desc({conv_src_tz}, dt::@Tx, tag::any); + auto conv_diff_dst_md = memory::desc({conv_dst_tz}, dt::@Ty, tag::any); + + auto conv_bwd_data_desc + = convolution_backward_data::desc(algorithm::convolution_direct, + conv_diff_src_md, conv_bwd_weights_md, conv_diff_dst_md, conv_strides, conv_dilation, conv_padding, conv_padding); + auto conv_bwd_data_pd = convolution_backward_data::primitive_desc( + conv_bwd_data_desc, eng, conv_pd); + + auto conv_diff_dst_memory = conv_user_diff_dst_memory; + if (conv_bwd_data_pd.diff_dst_desc() + != conv_user_diff_dst_memory.get_desc()) { + conv_diff_dst_memory = memory(conv_bwd_data_pd.diff_dst_desc(), eng); + net_bwd.push_back(reorder(conv_user_diff_dst_memory, conv_diff_dst_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_user_diff_dst_memory}, + {DNNL_ARG_TO, conv_diff_dst_memory}}); + } + + auto conv_bwd_weights_memory = conv_weights_memory; + if (conv_bwd_data_pd.weights_desc() != conv_weights_memory.get_desc()) { + conv_bwd_weights_memory = memory(conv_bwd_data_pd.weights_desc(), eng); + net_bwd.push_back(reorder(conv_weights_memory, conv_bwd_weights_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_weights_memory}, + {DNNL_ARG_TO, conv_bwd_weights_memory}}); + } + + net_bwd.push_back(convolution_backward_data(conv_bwd_data_pd)); + net_bwd_args.push_back({{DNNL_ARG_WEIGHTS, conv_bwd_weights_memory}, + {DNNL_ARG_DIFF_DST, conv_diff_dst_memory}}); + + auto conv_diff_src_memory = conv_user_diff_src_memory; + if (conv_bwd_data_pd.diff_src_desc() + != conv_user_diff_src_memory.get_desc()) { + conv_diff_src_memory + = memory(conv_bwd_data_pd.diff_src_desc(), eng); + net_bwd_args.back().insert( + {DNNL_ARG_DIFF_SRC, conv_diff_src_memory}); + + net_bwd.push_back(reorder( + conv_diff_src_memory, conv_user_diff_src_memory)); + net_bwd_args.push_back({{DNNL_ARG_FROM, conv_diff_src_memory}, + {DNNL_ARG_TO, conv_user_diff_src_memory}}); + } else { + net_bwd_args.back().insert( + {DNNL_ARG_DIFF_SRC, conv_diff_src_memory}); + } + + ASSERTop(net_bwd.size(),==,net_bwd_args.size()); + + for (size_t i = 0; i < net_bwd.size(); ++i) + net_bwd.at(i).execute(s, net_bwd_args.at(i)); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.h b/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.h new file mode 100644 index 00000000..0c8a2adf --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_conv_backward_x_op.h @@ -0,0 +1,27 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MklConvBackwardXOp : Op { + Var* w, * dy, * dx; + int xh, xw, strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + string xformat, wformat, yformat; + + MklConvBackwardXOp(Var* w, Var* y, int height, int width, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups=1, string xformat="abcd", string wformat="oihw", string yformat="abcd"); + + const char* name() const override { return "mkl_conv_backward_x"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mkl/ops/mkl_conv_op.cc b/python/jittor/extern/mkl/ops/mkl_conv_op.cc new file mode 100644 index 00000000..7b450bd5 --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_conv_op.cc @@ -0,0 +1,194 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "mkl_conv_op.h" + +using namespace dnnl; +using namespace std; + +namespace jittor { + +static inline int findc(const string& format, const char& c) { + if (c==format[0]) return 0; + if (c==format[1]) return 1; + if (c==format[2]) return 2; + ASSERT(c==format[3]) << "Not a valid format" << format << c; + return 3; +} + +static inline void get_shape(Var* x, const char* f, const string& format, int& a, int& b, int &c, int& d) { + auto& shape = x->shape; + a = shape[findc(format, f[0])]; + b = shape[findc(format, f[1])]; + c = shape[findc(format, f[2])]; + d = shape[findc(format, f[3])]; +} + +#ifndef JIT + +static inline void set_shape(Var* x, const char* f, const string& format, int a, int b, int c, int d) { + int64 shape[4]; + shape[findc(format, f[0])] = a; + shape[findc(format, f[1])] = b; + shape[findc(format, f[2])] = c; + shape[findc(format, f[3])] = d; + x->set_shape(NanoVector( + shape[0], shape[1], shape[2], shape[3])); +} + +MklConvOp::MklConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh, int dilationw, int groups, string xformat, string wformat, string yformat) + : x(x), w(w), strideh(strideh), stridew(stridew), paddingh(paddingh), paddingw(paddingw), dilationh(dilationh), dilationw(dilationw), groups(groups), + xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) { + y = create_output(nullptr, dtype_infer(x->ns, w->ns)); + if (!this->yformat.size()) + this->yformat = this->xformat; +} + +void MklConvOp::infer_shape() { + ASSERTop(x->shape.size(),==,4); + ASSERTop(w->shape.size(),==,4); + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + get_shape(x, "abcd", xformat, xn, xc, xh, xw); + get_shape(w, "oihw", wformat, wco, wci, wh, ww); + ASSERTop(wci * groups,==,xc); + yn = xn, yc = wco; + yh = (xh+paddingh*2-wh*dilationh+dilationh-1)/strideh+1; + yw = (xw+paddingw*2-ww*dilationw+dilationw-1)/stridew+1; + set_shape(y, "abcd", yformat, yn, yc, yh, yw); +} + +static const char* short_type(Var* x) { + if (x->is_float()) { + if (x->dsize()==4) return "f32"; + if (x->dsize()==8) return "f64"; + if (x->dsize()==2) return "f16"; + return "f8"; + } else { + if (x->dsize()==4) return "s32"; + if (x->dsize()==8) return "s64"; + if (x->dsize()==2) return "s16"; + return "s8"; + } +} + +void MklConvOp::jit_prepare(JK& jk) { + jk << "«Txd:" << x->dtype(); + jk << "«Tyd:" << y->dtype(); + jk << "«Twd:" << w->dtype(); + jk << "«Tx:" << short_type(x); + jk << "«Tw:" << short_type(w); + jk << "«Ty:" << short_type(y); + jk << "«XFORMAT:" << xformat; + jk << "«WFORMAT:" << wformat; + jk << "«YFORMAT:" << yformat; +} + +#else // JIT +#ifdef JIT_cpu +#pragma clang diagnostic ignored "-Wtautological-compare" +void MklConvOp::jit_run() { + const auto& xs = x->shape; + const auto& ws = w->shape; + + using tag = memory::format_tag; + using dt = memory::data_type; + + if (tag::@XFORMAT==tag::nhwc && tag::@YFORMAT==tag::nhwc && tag::@WFORMAT==tag::hwio + && strideh==1 && stridew==1 && paddingh==0 && paddingw==0 && dilationh==1 && dilationw==1 && ws[0]==1 && ws[1]==1 + && dt::@Tx==dt::f32 && dt::@Ty==dt::f32 && dt::@Tw==dt::f32) { + auto m = xs[0]*xs[1]*xs[2]; + auto n = ws[3]; + auto k = xs[3]; + // x: [m,k], w: [k,n], y: [m,n] + ASSERTop(0,==,dnnl_sgemm('N', 'N', m, n, k, + 1.f, x->ptr(), k, + w->ptr(), n, + 0.f, y->ptr(), n)); + return; + } + + engine eng(engine::kind::cpu, 0); + stream s(eng); + + std::vector net; + std::vector> net_args; + + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + get_shape(x, "abcd", xformat, xn, xc, xh, xw); + get_shape(w, "oihw", wformat, wco, wci, wh, ww); + get_shape(y, "abcd", yformat, yn, yc, yh, yw); + + memory::dims conv1_src_tz = {xn, xc, xh, xw}; + memory::dims conv1_weights_tz = groups>1 + ? memory::dims{groups, wco/groups, wci, wh, ww} + : memory::dims{wco, wci, wh, ww}; + memory::dims conv1_dst_tz = {yn, yc, yh, yw}; + memory::dims conv1_strides = { strideh, stridew }; + memory::dims conv1_padding = { paddingh, paddingw }; + memory::dims conv1_dilation = { dilationh-1, dilationw-1 }; + + if (groups>1) ASSERT(tag::@WFORMAT == tag::oihw); + + auto user_src_memory = memory( + { { conv1_src_tz }, dt::@Tx, tag::@XFORMAT }, eng, x->mem_ptr); + auto user_dst_memory = memory( + { { conv1_dst_tz }, dt::@Ty, tag::@YFORMAT }, eng, y->mem_ptr); + auto user_weights_memory = memory( + { { conv1_weights_tz }, dt::@Tw, groups>1 ? tag::goihw : tag::@WFORMAT }, eng, w->mem_ptr); + + auto conv1_src_md = memory::desc({ conv1_src_tz }, dt::@Tx, tag::any); + auto conv1_weights_md + = memory::desc({ conv1_weights_tz }, dt::@Tw, tag::any); + auto conv1_dst_md = memory::desc({ conv1_dst_tz }, dt::@Ty, tag::any); + + auto conv1_desc = convolution_forward::desc(prop_kind::forward_inference, + algorithm::convolution_auto, conv1_src_md, conv1_weights_md, conv1_dst_md, conv1_strides, conv1_dilation, conv1_padding, conv1_padding); + + auto conv1_prim_desc = convolution_forward::primitive_desc(conv1_desc, eng); + + net.clear(); + net_args.clear(); + auto conv1_src_memory = user_src_memory; + if (conv1_prim_desc.src_desc() != user_src_memory.get_desc()) { + conv1_src_memory = memory(conv1_prim_desc.src_desc(), eng); + net.push_back(reorder(user_src_memory, conv1_src_memory)); + net_args.push_back({ { DNNL_ARG_FROM, user_src_memory }, + { DNNL_ARG_TO, conv1_src_memory } }); + } + + auto conv1_weights_memory = user_weights_memory; + if (conv1_prim_desc.weights_desc() != user_weights_memory.get_desc()) { + conv1_weights_memory = memory(conv1_prim_desc.weights_desc(), eng); + net.push_back(reorder(user_weights_memory, conv1_weights_memory)); + net_args.push_back({ { DNNL_ARG_FROM, user_weights_memory }, { DNNL_ARG_TO, conv1_weights_memory } }); + } + + auto conv1_dst_memory = memory(conv1_prim_desc.dst_desc(), eng); + + net.push_back(convolution_forward(conv1_prim_desc)); + net_args.push_back({ { DNNL_ARG_SRC, conv1_src_memory }, + { DNNL_ARG_WEIGHTS, conv1_weights_memory }, + { DNNL_ARG_DST, conv1_dst_memory } }); + + if (conv1_dst_memory != user_dst_memory) { + net.push_back(reorder(conv1_dst_memory, user_dst_memory)); + net_args.push_back({ { DNNL_ARG_FROM, conv1_dst_memory },{ DNNL_ARG_TO, user_dst_memory } }); + } + + ASSERTop(net.size(),==,net_args.size()); + for (size_t i = 0; i < net.size(); ++i) + net.at(i).execute(s, net_args.at(i)); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mkl/ops/mkl_conv_op.h b/python/jittor/extern/mkl/ops/mkl_conv_op.h new file mode 100644 index 00000000..28a1ee3b --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_conv_op.h @@ -0,0 +1,27 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MklConvOp : Op { + Var* x, * w, * y; + int strideh, stridew, paddingh, paddingw, dilationh, dilationw, groups; + string xformat, wformat, yformat; + /* MklConvOp: xformat abcd represents nchw */ + MklConvOp(Var* x, Var* w, int strideh, int stridew, int paddingh, int paddingw, int dilationh=1, int dilationw=1, int groups=1, string xformat="abcd", string wformat="oihw", string yformat=""); + + const char* name() const override { return "mkl_conv"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mkl/ops/mkl_matmul_op.cc b/python/jittor/extern/mkl/ops/mkl_matmul_op.cc new file mode 100644 index 00000000..dbd059e3 --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_matmul_op.cc @@ -0,0 +1,77 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "mkl_matmul_op.h" + +using namespace dnnl; +using namespace std; + +namespace jittor { + +#ifndef JIT + +MklMatmulOp::MklMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b) + : a(a), b(b), trans_a(trans_a), trans_b(trans_b) { + // TODO: support int8 * int8 + ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; + // TODO: support diffrent input type + ASSERT(a->dtype().dsize() == 4 && b->dtype().dsize() == 4) << "support float32 only now."; + c = create_output(nullptr, a->dtype()); +} + +void MklMatmulOp::infer_shape() { + ASSERTop(a->shape.size(),==,2); + ASSERTop(b->shape.size(),==,2); + int n = a->shape[0], m = a->shape[1]; + int m_ = b->shape[0], k = b->shape[1]; + if (trans_a) { + swap(n, m); + } + if (trans_b) { + swap(m_, k); + } + ASSERTop(m,==,m_); + c->set_shape({n, k}); +} + +void MklMatmulOp::jit_prepare(JK& jk) { + jk << "«T:" << a->dtype(); + jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); + jk << "«Trans_b:" << (trans_b ? 'T' : 'N'); +} + +#else // JIT +#ifdef JIT_cpu +#pragma clang diagnostic ignored "-Wtautological-compare" +void MklMatmulOp::jit_run() { + const auto& as = a->shape; + const auto& bs = b->shape; + auto n = as[0]; + auto m = as[1]; + auto k = bs[1]; + if ('@Trans_a'=='T') { + n = as[1]; + m = as[0]; + } + if ('@Trans_b'=='T') { + k = bs[0]; + } + // a: [n,m], b: [m,k], c: [n,k] + ASSERTop(0,==,dnnl_sgemm('@Trans_a', '@Trans_b', n, k, m, + 1.f, a->ptr(), '@Trans_a'=='N'? m : n, + b->ptr(), '@Trans_b' == 'N' ? k : m, + 0.f, c->ptr(), k)); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mkl/ops/mkl_matmul_op.h b/python/jittor/extern/mkl/ops/mkl_matmul_op.h new file mode 100644 index 00000000..ce854d29 --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_matmul_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MklMatmulOp : Op { + Var* a, * b, * c; + bool trans_a, trans_b; + MklMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); + + const char* name() const override { return "mkl_matmul"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mkl/ops/mkl_test_op.cc b/python/jittor/extern/mkl/ops/mkl_test_op.cc new file mode 100644 index 00000000..2ea68395 --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_test_op.cc @@ -0,0 +1,34 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "mkl_test_op.h" + +int mkl_test_entry(); + +namespace jittor { + +#ifndef JIT +MklTestOp::MklTestOp() { + output = create_output(1, ns_float32); +} + +void MklTestOp::jit_prepare(JK& jk) { + jk << "«T:float32"; +} + +#else // JIT +#ifdef JIT_cpu +void MklTestOp::jit_run() { + ASSERT(mkl_test_entry()==0); + output->ptr()[0] = 123; +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mkl/ops/mkl_test_op.h b/python/jittor/extern/mkl/ops/mkl_test_op.h new file mode 100644 index 00000000..e1879c0e --- /dev/null +++ b/python/jittor/extern/mkl/ops/mkl_test_op.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MklTestOp : Op { + Var* output; + MklTestOp(); + + const char* name() const override { return "mkl_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mpi/inc/mpi_wrapper.h b/python/jittor/extern/mpi/inc/mpi_wrapper.h new file mode 100644 index 00000000..dbc3d4de --- /dev/null +++ b/python/jittor/extern/mpi/inc/mpi_wrapper.h @@ -0,0 +1,92 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. +// All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#define OMPI_SKIP_MPICXX +#include +#include +#include "var_holder.h" + +extern void throw_mpi_error(int result, + char const *const func, const char *const file, int const line); + +static inline void mpi_check(int result, + char const *const func, const char *const file, int const line) { + if (result != MPI_SUCCESS) { + throw_mpi_error(result, func, file, line); + } +} + +#define MPI_CHECK(val) mpi_check((val), #val, __FILE__, __LINE__) + +namespace jittor { + +EXTERN_LIB int mpi_world_size; +EXTERN_LIB int mpi_world_rank; +EXTERN_LIB int mpi_local_size; +EXTERN_LIB int mpi_local_rank; +EXTERN_LIB bool inside_mpi; +EXTERN_LIB bool mpi_enabled; +EXTERN_LIB bool use_device_mpi; + +/** +Return number of MPI nodes. +*/ +// @pyjt(world_size) +int _mpi_world_size(); + +/** +Return global ID of this MPI node. +*/ +// @pyjt(world_rank) +int _mpi_world_rank(); + +/** +Return local ID of this MPI node. +*/ +// @pyjt(local_rank) +int _mpi_local_rank(); + +/** + Set MPI state, enable or disable, if disabled, all mpi operators + have no affect. +*/ +// @pyjt(set_state) +inline void _mpi_set_state(bool enable) { mpi_enabled = enable; } + +/** + Get MPI state, enable or disable. +*/ +// @pyjt(get_state) +inline int _mpi_get_state() { return mpi_enabled; } + +struct ArrayArgs; + +/** + +Use jt.Module.mpi_param_broadcast(root=0) to broadcast all moudule parameters of this module in [root] MPI node to all MPI nodes. + +This operation has no gradient, and the input parameter type is numpy array. +*/ +// @pyjt(broadcast) +void _mpi_broadcast(ArrayArgs&& args, int root); + +// @pyjt(var_broadcast) +void var_broadcast(VarHolder* x, int root=0); + +// @pyjt(var_reduce) +void var_reduce(VarHolder* x, int root=0); + +// @pyjt(var_all_reduce) +void var_all_reduce(VarHolder* x); + +// @pyjt(mpi_barrier) +void mpi_barrier(); + +} // jittor diff --git a/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc new file mode 100644 index 00000000..2ffc3bf8 --- /dev/null +++ b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.cc @@ -0,0 +1,89 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guowei Yang <471184555@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mpi_wrapper.h" +#include "var.h" +#include "mpi_all_reduce_op.h" +#include "ops/op_register.h" +#include "utils/str_utils.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +#ifndef JIT + +static auto make_array = get_op_info("array") + .get_constructor(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_mpi_all_reduce = get_op_info("mpi_all_reduce") + .get_constructor(); + +MpiAllReduceOp::MpiAllReduceOp(Var* x, NanoString op) : x(x), op(op) { + if (!mpi_enabled) { + forward(x); + return; + } + if (op == ns_mean) { + auto var = make_mpi_all_reduce(x, ns_add); + var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide); + forward(var); + return; + } + ASSERT(op == ns_add) << "Not supported MPI op" << op; + #ifdef HAS_CUDA + if (use_device_mpi && use_cuda) { + static auto nccl_all_reduce = has_op("nccl_all_reduce") + ? get_op_info("nccl_all_reduce").get_constructor() + : nullptr; + if (nccl_all_reduce) { + auto var = nccl_all_reduce(x); + forward(var); + return; + } + } + #endif + y = create_output(nullptr, x->dtype()); +} + +void MpiAllReduceOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr MpiAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + static auto mpi_all_reduce = + get_op_info("mpi_all_reduce").get_constructor(); + return mpi_all_reduce(dout, ns_add); +} + +void MpiAllReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«OP:" << op; +} + +#else // JIT +#ifdef JIT_cpu +void MpiAllReduceOp::jit_run() { + @define(T_MPI, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT) + @if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE) + @if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT) + ) + @define(OP_MPI, + @if(@strcmp(@OP,add)==0, MPI_SUM) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + index_t num = y->num; + MPI_Allreduce(xp, yp, num, T_MPI, OP_MPI, MPI_COMM_WORLD); +} +#endif // JIT_cpu +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mpi/ops/mpi_all_reduce_op.h b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.h new file mode 100644 index 00000000..60d5c62b --- /dev/null +++ b/python/jittor/extern/mpi/ops/mpi_all_reduce_op.h @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guowei Yang <471184555@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MpiAllReduceOp : Op { + Var* x, * y; + NanoString op; + + /** + + Mpi All Reduce Operator uses the operator [op] to reduce variable [x] in all MPI nodes and broadcast to all MPI nodes. + + Args: + + * x: variable to be all reduced. + * op: 'sum' or 'add' means sum all [x], 'mean' means average all [x]. Default: 'add'. + */ + MpiAllReduceOp(Var* x, NanoString op=ns_add); + void infer_shape() override; + + const char* name() const override { return "mpi_all_reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc b/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc new file mode 100644 index 00000000..af8c1895 --- /dev/null +++ b/python/jittor/extern/mpi/ops/mpi_broadcast_op.cc @@ -0,0 +1,71 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guowei Yang <471184555@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mpi_wrapper.h" +#include "var.h" +#include "mpi_broadcast_op.h" +#include "ops/op_register.h" +#include "utils/str_utils.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +#ifndef JIT +MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) { + if (!mpi_enabled) { + forward(x); + return; + } + #ifdef HAS_CUDA + if (use_device_mpi && use_cuda) { + static auto nccl_broadcast = has_op("nccl_broadcast") + ? get_op_info("nccl_broadcast").get_constructor() + : nullptr; + if (nccl_broadcast) { + auto var = nccl_broadcast(x, root); + forward(var); + return; + } + } + #endif + y = create_output(nullptr, x->dtype()); +} + +void MpiBroadcastOp::infer_shape() { + y->set_shape(x->shape); + if (root == mpi_world_rank) + y->share_with(x); +} + +VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) { + static auto mpi_reduce = + get_op_info("mpi_reduce").get_constructor(); + return mpi_reduce(dout, ns_add, root); +} + +void MpiBroadcastOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); +} + +#else // JIT +#ifdef JIT_cpu +void MpiBroadcastOp::jit_run() { + @define(T_MPI, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT) + @if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE) + @if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT) + @if(@strcmp(@Tx,uint8)==0, MPI_CHAR) + ) + auto* __restrict__ yp = y->ptr(); + MPI_Bcast(yp, y->num, T_MPI, root, MPI_COMM_WORLD); +} +#endif // JIT_cpu +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mpi/ops/mpi_broadcast_op.h b/python/jittor/extern/mpi/ops/mpi_broadcast_op.h new file mode 100644 index 00000000..02e04e1c --- /dev/null +++ b/python/jittor/extern/mpi/ops/mpi_broadcast_op.h @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guowei Yang <471184555@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MpiBroadcastOp : Op { + Var* x, * y; + int root; + + /** + + Mpi Broadcast Operator broadcasts variable [x] in [root] MPI nodes to all MPI nodes. + + Args: + + * x: variable to be broadcasted. + * root: ID of MPI node to be broadcasted. Default: 0. + */ + MpiBroadcastOp(Var* x, int root=0); + void infer_shape() override; + + const char* name() const override { return "mpi_broadcast"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mpi/ops/mpi_reduce_op.cc b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc new file mode 100644 index 00000000..77d86a82 --- /dev/null +++ b/python/jittor/extern/mpi/ops/mpi_reduce_op.cc @@ -0,0 +1,91 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Guowei Yang <471184555@qq.com>. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mpi_wrapper.h" +#include "var.h" +#include "mpi_reduce_op.h" +#include "ops/op_register.h" +#include "utils/str_utils.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +#ifndef JIT + +static auto make_array = get_op_info("array") + .get_constructor(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_mpi_reduce = get_op_info("mpi_reduce") + .get_constructor(); + +MpiReduceOp::MpiReduceOp(Var* x, NanoString op, int root) : x(x), op(op), root(root) { + if (!mpi_enabled) { + forward(x); + return; + } + if (op == ns_mean) { + auto var = make_mpi_reduce(x, ns_add, root); + var = make_binary(var, make_array(&mpi_world_size, 1, ns_int32), ns_divide); + forward(var); + return; + } + ASSERT(op == ns_add) << "Not supported MPI op" << op; + #ifdef HAS_CUDA + if (use_device_mpi && use_cuda) { + static auto nccl_reduce = has_op("nccl_reduce") + ? get_op_info("nccl_reduce").get_constructor() + : nullptr; + if (nccl_reduce) { + auto var = nccl_reduce(x, root); + forward(var); + return; + } + } + #endif + y = create_output(nullptr, x->dtype()); +} + +void MpiReduceOp::infer_shape() { + y->set_shape(x->shape); +} + +VarPtr MpiReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + static VarPtr(*mpi_broadcast)(Var*, int) = + get_op_info("mpi_broadcast").get_constructor(); + return mpi_broadcast(dout,root); +} + +void MpiReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«OP:" << op; +} + +#else // JIT +#ifdef JIT_cpu +void MpiReduceOp::jit_run() { + @define(T_MPI, + @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, MPI_FLOAT) + @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, MPI_INT) + @if(@strcmp(@Tx,float64)==0 || @strcmp(@Tx,double)==0, MPI_DOUBLE) + @if(@strcmp(@Tx,int64)==0, MPI_DOUBLE_INT) + ) + @define(OP_MPI, + @if(@strcmp(@OP,add)==0, MPI_SUM) + ) + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + index_t num = y->num; + MPI_CHECK(MPI_Reduce(xp, yp, num, T_MPI, OP_MPI, root, MPI_COMM_WORLD)); + if (root != mpi_world_rank) + for (index_t i=0; i. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MpiReduceOp : Op { + Var* x, * y; + NanoString op; + int root; + + /** + + Mpi Reduce Operator uses the operator [op] to reduce variable [x] in all MPI nodes and send to the [root] MPI node. + + Args: + + * x: variable to be reduced. + * op: 'sum' or 'add' means sum all [x], 'mean' means average all [x]. Default: 'add'. + * root: ID of MPI node to output. Default: 0. + */ + MpiReduceOp(Var* x, NanoString op=ns_add, int root=0); + void infer_shape() override; + + const char* name() const override { return "mpi_reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mpi/ops/mpi_test_op.cc b/python/jittor/extern/mpi/ops/mpi_test_op.cc new file mode 100644 index 00000000..54e5ecb5 --- /dev/null +++ b/python/jittor/extern/mpi/ops/mpi_test_op.cc @@ -0,0 +1,42 @@ +// *************************************************************** +// Copyright (c) 2019 Dun Liang . All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mpi_wrapper.h" + +#include "var.h" +#include "mpi_test_op.h" +#include "utils/str_utils.h" + +namespace jittor { + +#ifndef JIT +MpiTestOp::MpiTestOp(string cmd) : cmd(cmd) { + output = create_output(1, ns_float32); +} + +void MpiTestOp::jit_prepare(JK& jk) { + jk << "«T:float32"; +} + +#else // JIT + +void MpiTestOp::jit_run() { + output->ptr()[0] = 123; + + int world_size = mpi_world_size; + + int world_rank = mpi_world_rank; + + char processor_name[MPI_MAX_PROCESSOR_NAME]; + int name_len; + MPI_CHECK(MPI_Get_processor_name(processor_name, &name_len)); + + printf("Hello world from processor %s, rank %d out of %d processors\\n",processor_name, world_rank, world_size); + +} + +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/mpi/ops/mpi_test_op.h b/python/jittor/extern/mpi/ops/mpi_test_op.h new file mode 100644 index 00000000..b2e0df21 --- /dev/null +++ b/python/jittor/extern/mpi/ops/mpi_test_op.h @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Dun Liang . +// All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct MpiTestOp : Op { + Var* output; + string cmd; + + MpiTestOp(string cmd); + + const char* name() const override { return "mpi_test"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/mpi/src/mpi_wrapper.cc b/python/jittor/extern/mpi/src/mpi_wrapper.cc new file mode 100644 index 00000000..498633a1 --- /dev/null +++ b/python/jittor/extern/mpi/src/mpi_wrapper.cc @@ -0,0 +1,246 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. +// All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include + +#include "mpi_wrapper.h" +#include "common.h" +#include "ops/array_op.h" + +char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING]; + +void throw_mpi_error(int result, + char const *const func, const char *const file, int const line) { + int resultlen; + MPI_Error_string(result, jt_mpi_err_buffer, &resultlen); + LOGf << "MPI error at " >> file >> ":" >> line << "code=" + >> result >> '(' >> jt_mpi_err_buffer >> ')' << func; +} + +namespace jittor { + +MPI_Datatype MPI_HALF; +MPI_Op MPI_HALF_ADD; + +void HalfAdd(void* invec, void* inoutvec, int* len, MPI_Datatype* type) { + // return; + short* in = (short*)invec; + short* inout = (short*)inoutvec; + + int i = 0; + int total = *len; + for (; i+8 <= total; i += 8) { + // 将半精度浮点数转换为单精度浮点数 + __m256 in1 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(in + i))); + __m256 in2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(inout + i))); + + // 执行向量加法 + __m256 out = _mm256_add_ps(in1, in2); + + // 将单精度浮点数转换回半精度浮点数,并存储结果 + _mm_storeu_si128((__m128i*)(inout + i), _mm256_cvtps_ph(out, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + + // 处理剩余的半精度浮点数 + for (; i < total; i++) { + // 将半精度浮点数转换为单精度浮点数 + __m128 in1 = _mm_cvtph_ps(_mm_set1_epi16(*(in + i))); + __m128 in2 = _mm_cvtph_ps(_mm_set1_epi16(*(inout + i))); + + // 执行向量加法 + __m128 out = _mm_add_ps(in1, in2); + + // 将单精度浮点数转换回半精度浮点数,并存储结果 + *(inout + i) = _mm_cvtps_ph(out, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)[0]; + } +} + + +int mpi_world_size = 1; +int mpi_world_rank = 0; +int mpi_local_size = 1; +int mpi_local_rank = 0; +bool inside_mpi = false; +bool mpi_enabled = false; +bool use_device_mpi = false; + +int _mpi_world_size() { + return mpi_enabled ? mpi_world_size : 1; +} + +int _mpi_world_rank() { + return mpi_enabled ? mpi_world_rank : 0; +} + +int _mpi_local_rank() { + return mpi_enabled ? mpi_local_rank : 0; +} + +void _mpi_broadcast(ArrayArgs&& args, int root) { + if (!mpi_enabled) return; + int64 size = args.dtype.dsize(); + for (auto j : args.shape) + size *= j; + MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, root, MPI_COMM_WORLD)); +} + +static uint64_t getHostHash(const char* string) { + // Based on DJB2, result = result * 33 + char + uint64_t result = 5381; + for (int c = 0; string[c] != '\0'; c++){ + result = ((result << 5) + result) + string[c]; + } + return result; +} + + +static void getHostName(char* hostname, int maxlen) { + gethostname(hostname, maxlen); + for (int i=0; i< maxlen; i++) { + if (hostname[i] == '.') { + hostname[i] = '\0'; + return; + } + } +} + +struct mpi_initer { + +mpi_initer() { + inside_mpi = !!getenv("OMPI_COMM_WORLD_SIZE"); + if (!inside_mpi) return; + mpi_enabled = true; + LOGvv << "MPI init..."; + MPI_CHECK(MPI_Init(NULL, NULL)); + MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size)); + MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank)); + + //calculating localRank based on hostname which is used in selecting a GPU + uint64_t hostHashs[mpi_world_rank]; + char hostname[1024]; + getHostName(hostname, 1024); + hostHashs[mpi_world_rank] = getHostHash(hostname); + MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD)); + mpi_local_rank = 0; + for (int p=0; pvar; + ASSERT(v->mem_ptr && !v->allocator->is_cuda()); + int64 MPI_MAX_SIZE = 1ll<<30; + for (int64 i=0; isize; i+=MPI_MAX_SIZE) { + int64 size = std::min(v->size-i, MPI_MAX_SIZE); + MPI_Bcast(v->ptr()+i, size, MPI_BYTE, 0, MPI_COMM_WORLD); + } +} + +void var_reduce(VarHolder* x, int root) { + if (!inside_mpi) return; + Var* v = x->var; + ASSERT(v->mem_ptr && !v->allocator->is_cuda()); + MPI_Datatype dtype; + MPI_Op op; + if (v->dtype() == ns_float16) + dtype = MPI_HALF, op = MPI_HALF_ADD; + else if (v->dtype() == ns_int16) + dtype = MPI_SHORT, op = MPI_SUM; + else if (v->dtype() == ns_float32) + dtype = MPI_FLOAT, op = MPI_SUM; + else if (v->dtype() == ns_float64) + dtype = MPI_DOUBLE, op = MPI_SUM; + else if (v->dtype() == ns_int32) + dtype = MPI_INT, op = MPI_SUM; + else if (v->dtype() == ns_int64) + dtype = MPI_LONG_LONG_INT, op = MPI_SUM; + else if (v->dtype() == ns_uint8) + dtype = MPI_UNSIGNED_CHAR, op = MPI_SUM; + else + LOGf << "Not supported dtype" << v->dtype(); + // mpi reduce performace magically reduce from 4194304 + int64 MPI_MAX_SIZE = (4194304) / v->dtype().dsize(); + for (int64 i=0; inum; i+=MPI_MAX_SIZE) { + int64 size = std::min(v->num-i, MPI_MAX_SIZE); + auto mem_ptr = v->ptr()+i*v->dtype().dsize(); + if (mpi_world_rank == root) + MPI_Reduce(MPI_IN_PLACE, mem_ptr, size, dtype, op, root, MPI_COMM_WORLD); + else + MPI_Reduce(mem_ptr, nullptr, size, dtype, op, root, MPI_COMM_WORLD); + } +} + +void var_all_reduce(VarHolder* x) { + if (!inside_mpi) return; + Var* v = x->var; + ASSERT(v->mem_ptr && !v->allocator->is_cuda()); + MPI_Datatype dtype; + MPI_Op op; + if (v->dtype() == ns_float16) + dtype = MPI_HALF, op = MPI_HALF_ADD; + else if (v->dtype() == ns_int16) + dtype = MPI_SHORT, op = MPI_SUM; + else if (v->dtype() == ns_float32) + dtype = MPI_FLOAT, op = MPI_SUM; + else if (v->dtype() == ns_float64) + dtype = MPI_DOUBLE, op = MPI_SUM; + else if (v->dtype() == ns_int32) + dtype = MPI_INT, op = MPI_SUM; + else if (v->dtype() == ns_int64) + dtype = MPI_LONG_LONG_INT, op = MPI_SUM; + else if (v->dtype() == ns_uint8) + dtype = MPI_UNSIGNED_CHAR, op = MPI_SUM; + else + LOGf << "Not supported dtype" << v->dtype(); + int64 MPI_MAX_SIZE = (1<<30) / v->dtype().dsize(); + for (int64 i=0; inum; i+=MPI_MAX_SIZE) { + int64 size = std::min(v->num-i, MPI_MAX_SIZE); + auto mem_ptr = v->ptr()+i*v->dtype().dsize(); + MPI_Allreduce(MPI_IN_PLACE, mem_ptr, size, dtype, op, MPI_COMM_WORLD); + } +} + +void mpi_barrier() { + if (!inside_mpi) return; + MPI_CHECK(MPI_Barrier(MPI_COMM_WORLD)); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/extern/rocm/rocm_cache.tar.gz b/python/jittor/extern/rocm/rocm_cache.tar.gz new file mode 100644 index 00000000..3c9677ee Binary files /dev/null and b/python/jittor/extern/rocm/rocm_cache.tar.gz differ diff --git a/python/jittor/extern/rocm/rocm_compiler.py b/python/jittor/extern/rocm/rocm_compiler.py new file mode 100644 index 00000000..26d44d45 --- /dev/null +++ b/python/jittor/extern/rocm/rocm_compiler.py @@ -0,0 +1,154 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: Zheng-Ning Liu . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +import ctypes +import glob +import tarfile + +import jittor_utils +from jittor_utils import env_or_try_find, run_cmd, cache_path, LOG +import jittor.compiler as compiler + + +has_rocm = 0 +cc_flags = "" +hipcc_path = env_or_try_find('hipcc_path', 'hipcc') +rocm_home = "" +dlopen_flags = os.RTLD_NOW | os.RTLD_GLOBAL +compiler.has_rocm = has_rocm + + +def check_gcc_use_cxx11_abi(): + gcc_info = run_cmd("gcc -v") + if "--with-default-libstdcxx-abi=new" in gcc_info: + return True + elif "--with-default-libstdcxx-abi=gcc4-compatible" in gcc_info: + return False + else: + LOG.d("unknown cxx abi, defaults to gcc4-compatible") + return False + + +def install_rocm_jittor_core(): + import jittor.compiler as compiler + global has_rocm, cc_flags, rocm_home + rocm_home = run_cmd("hipconfig -R") + rocm_version = run_cmd("hipconfig -v") + + rocm_compiler_home = os.path.dirname(__file__) + rocm_cache_gz_path = os.path.join(rocm_compiler_home, "rocm_cache.tar.gz") + if os.path.exists(rocm_cache_gz_path): + for o_file in glob.glob(rocm_compiler_home + "/**/*.o", recursive=True): + os.remove(o_file) + with tarfile.open(rocm_cache_gz_path, "r:gz") as tar: + if (check_gcc_use_cxx11_abi()): + tar.extractall(rocm_compiler_home, members=[tar.getmember("rocm_cache_cxx11.o")]) + o_files = [ os.path.join(rocm_compiler_home, "rocm_cache_cxx11.o") ] + else: + tar.extractall(rocm_compiler_home, members=[tar.getmember("rocm_cache.o")]) + o_files = [ os.path.join(rocm_compiler_home, "rocm_cache.o") ] + + cc_files = sorted(glob.glob(rocm_compiler_home + "/**/*.cc", recursive=True)) + cc_flags += f" -DHAS_CUDA -DIS_ROCM -I{rocm_compiler_home} " + cc_flags += " " + run_cmd("hipconfig -C") + " " + cc_flags += ' -L"' + os.path.join(rocm_home, "lib") + '" -lamdhip64 ' + LOG.i(f"ROCm ({rocm_version}) detected in {rocm_home}") + + mod = jittor_utils.compile_module(''' +#include "common.h" +namespace jittor { +// @pyjt(process) +string process_rocm(const string& src, const string& name, const map& kargs); +}''', compiler.cc_flags + " " + " ".join(cc_files + o_files) + cc_flags) + jittor_utils.process_jittor_source("rocm", mod.process) + + # preload hip driver to ensure the correct initialization of hip context + hip_driver = ctypes.CDLL(os.path.join(rocm_home, 'lib', 'libamdhip64.so'), os.RTLD_GLOBAL | os.RTLD_NOW) + r = hip_driver.hipDeviceSynchronize() + + has_rocm = 1 + + +def install_hip(): + import jittor.compiler as compiler + + LOG.vv("setup rocm extern...") + cache_path_cuda = os.path.join(cache_path, "cuda") + cuda_include = os.path.join(compiler.jittor_path, "extern", "cuda", "inc") + compiler.make_cache_dir(cache_path_cuda) + cuda_extern_src = os.path.join(compiler.jittor_path, "extern", "cuda", "src") + cuda_extern_files = [os.path.join(cuda_extern_src, name) for name in os.listdir(cuda_extern_src)] + so_name = os.path.join(cache_path_cuda, "libcuda_extern" + compiler.so) + compiler.compile(compiler.cc_path, compiler.cc_flags+f" -I\"{cuda_include}\" ", cuda_extern_files, so_name) + ctypes.CDLL(so_name, dlopen_flags) + + +def install_rocm_library(lib_name, cuda_name, link=True): + import jittor.compiler as compiler + import jittor.compile_extern as compile_extern + + LOG.vv(f"setup {lib_name}...") + rocmlib_include_path = os.path.join(rocm_home, lib_name.lower(), "include") + + jt_cuda_include = os.path.join(compiler.jittor_path, "extern", "cuda", "inc") + jt_culib_include = os.path.join(compiler.jittor_path, "extern", "cuda", cuda_name, "inc") + + culib_src_dir = os.path.join(compiler.jittor_path, "extern", "cuda", cuda_name) + culib_src_files = [] + for r, _, f in os.walk(culib_src_dir): + for fname in f: + culib_src_files.append(os.path.join(r, fname)) + + extra_flags = f" -I\"{jt_cuda_include}\" -I\"{jt_culib_include}\" -I\"{rocmlib_include_path}\" " + extra_flags += f" -L\"{os.path.join(cache_path, 'cuda')}\" -llibcuda_extern " + if lib_name == "rccl": + extra_flags += compile_extern.mpi_compile_flags + + if link: + rocmlib_lib_path = os.path.join(rocm_home, lib_name.lower(), "lib") + if os.path.exists(os.path.join(rocmlib_lib_path, f"lib{lib_name}.so")): + jittor_utils.LOG.i(f"Found {os.path.join(rocmlib_lib_path, 'lib' + lib_name + '.so')}") + extra_flags += f" -L{rocmlib_lib_path} -l{lib_name} " + + rocmlib = compiler.compile_custom_ops(culib_src_files, return_module=True, extra_flags=extra_flags) + setattr(compile_extern, cuda_name, rocmlib) + setattr(compile_extern, cuda_name + "_ops", rocmlib.ops) + + +def install_extern(): + if has_rocm: + install_hip() + install_rocm_library("MIOpen", "cudnn") + install_rocm_library("rocblas", "cublas") + install_rocm_library("rocprim", "cub", link=False) + install_rocm_library("rccl", "nccl") + return True + else: + return False + +def convert_nvcc_flags(nvcc_flags): + return nvcc_flags + +def check(): + import jittor.compiler as compiler + global has_rocm, cc_flags + if hipcc_path: + try: + install_rocm_jittor_core() + except Exception as e: + jittor_utils.LOG.w(f"load ROCm failed, exception: {e}") + has_rocm = 0 + compiler.has_rocm = has_rocm + compiler.hipcc_path = hipcc_path + if not has_rocm: + return False + + compiler.cc_flags += cc_flags + compiler.nvcc_path = hipcc_path + compiler.nvcc_flags = compiler.cc_flags.replace("-std=c++14", "-std=c++17") + compiler.convert_nvcc_flags = convert_nvcc_flags + return True diff --git a/python/jittor/extern/rocm/rocm_config.cc b/python/jittor/extern/rocm/rocm_config.cc new file mode 100644 index 00000000..097aea5b --- /dev/null +++ b/python/jittor/extern/rocm/rocm_config.cc @@ -0,0 +1,40 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: Zheng-Ning Liu . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "rocm_config.h" +#include "utils/str_utils.h" + +namespace jittor +{ + void rocm_config(const string &name, string &src) + { + int pos; + string error_token_substring = "benchmark = false;"; + string error_token_0 = "if (fwd_algo_cache.size()>=max_cache_size) " + error_token_substring; + string error_token_1 = "if (bwdw_algo_cache.size()>=max_cache_size) " + error_token_substring; + string error_token_2 = "if (bwdx_algo_cache.size()>=max_cache_size) " + error_token_substring; + if ((pos = src.find(error_token_0)) != string::npos) { + src.erase(pos, error_token_0.size()); + } + if ((pos = src.find(error_token_1)) != string::npos) { + src.erase(pos, error_token_1.size()); + } + if ((pos = src.find(error_token_2)) != string::npos) { + src.erase(pos, error_token_2.size()); + } + + string use_cub_where = "cub_where && (ndim>1 || std::abs(cond->num)>4096)"; + if ((pos = src.find(use_cub_where)) != string::npos) { + src.replace(pos, use_cub_where.size(), "cub_where"); + } + + string enable_rocm = "HIP enabled"; + if ((pos = src.find(enable_rocm)) != string::npos) { + src.replace(pos, enable_rocm.size(), "ROCm enabled"); + } + } +} + diff --git a/python/jittor/extern/rocm/rocm_config.h b/python/jittor/extern/rocm/rocm_config.h new file mode 100644 index 00000000..7ecea4fb --- /dev/null +++ b/python/jittor/extern/rocm/rocm_config.h @@ -0,0 +1,16 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: Zheng-Ning Liu . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +void rocm_config(const string& name, string& src); + +} + + diff --git a/python/jittor/extern/rocm/rocm_jittor.h b/python/jittor/extern/rocm/rocm_jittor.h new file mode 100644 index 00000000..ea4f2642 --- /dev/null +++ b/python/jittor/extern/rocm/rocm_jittor.h @@ -0,0 +1,14 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: Zheng-Ning Liu . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +void rocm_jittor_op_compiler(string& filename, string& src, bool is_rocm, string& extra_flags); + +} diff --git a/python/jittor/extern/rocm/rocm_wrapper.h b/python/jittor/extern/rocm/rocm_wrapper.h new file mode 100644 index 00000000..036625a7 --- /dev/null +++ b/python/jittor/extern/rocm/rocm_wrapper.h @@ -0,0 +1,150 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: Zheng-Ning Liu . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +#include +#include +#include + +namespace jittor { + +struct RocprimArgMax +{ + template< + class Key, + class Value + > + __host__ __device__ inline + constexpr rocprim::key_value_pair + operator()(const rocprim::key_value_pair& a, + const rocprim::key_value_pair& b) const + { + return ((b.value > a.value) || ((a.value == b.value) && (b.key > a.key))) ? b : a; + } +}; + +struct RocprimArgMin +{ + template< + class Key, + class Value + > + __host__ __device__ inline + constexpr rocprim::key_value_pair + operator()(const rocprim::key_value_pair& a, + const rocprim::key_value_pair& b) const + { + return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a; + } +}; + +template +static __global__ void global_index_to_segment_index(rocprim::key_value_pair* d_out, Key* offsetsp, int n) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + for (int i = tid; i < n; i += tnum) { + d_out[i].key -= offsetsp[i]; + } +} + +template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT + > +static hipError_t rocprim_argmax(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + hipStream_t stream = 0, + bool debug_synchronous = false) { + using OffsetT = int; + using T = typename std::iterator_traits::value_type; + using O = typename std::iterator_traits::value_type; + using OutputTupleT = typename std::conditional< + std::is_same::value, + rocprim::key_value_pair, + O + >::type; + + using OutputValueT = typename OutputTupleT::Value; + using IteratorT = rocprim::arg_index_iterator; + + IteratorT d_indexed_in(d_in); + const OutputTupleT init(1, std::numeric_limits::lowest()); + + auto ret = rocprim::segmented_reduce(d_temp_storage, + temp_storage_bytes, + d_indexed_in, + d_out, + num_segments, + d_begin_offsets, + d_end_offsets, + RocprimArgMax(), + init, + stream, + debug_synchronous); + if (d_temp_storage != NULL) { + global_index_to_segment_index<<>>(d_out, d_begin_offsets, num_segments); + } + + return ret; +} + + +template< + typename InputIteratorT, + typename OutputIteratorT, + typename OffsetIteratorT + > +static hipError_t rocprim_argmin(void * d_temp_storage, + size_t& temp_storage_bytes, + InputIteratorT d_in, + OutputIteratorT d_out, + int num_segments, + OffsetIteratorT d_begin_offsets, + OffsetIteratorT d_end_offsets, + hipStream_t stream = 0, + bool debug_synchronous = false) { + using OffsetT = int; + using T = typename std::iterator_traits::value_type; + using O = typename std::iterator_traits::value_type; + using OutputTupleT = typename std::conditional< + std::is_same::value, + rocprim::key_value_pair, + O + >::type; + + using OutputValueT = typename OutputTupleT::Value; + using IteratorT = rocprim::arg_index_iterator; + + IteratorT d_indexed_in(d_in); + const OutputTupleT init(1, std::numeric_limits::max()); + + auto ret = rocprim::segmented_reduce(d_temp_storage, + temp_storage_bytes, + d_indexed_in, + d_out, + num_segments, + d_begin_offsets, + d_end_offsets, + RocprimArgMin(), + init, + stream, + debug_synchronous); + if (d_temp_storage != NULL) { + global_index_to_segment_index<<>>(d_out, d_begin_offsets, num_segments); + } + + return ret; +} + +} diff --git a/python/jittor/init.py b/python/jittor/init.py new file mode 100644 index 00000000..6917062c --- /dev/null +++ b/python/jittor/init.py @@ -0,0 +1,738 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import NanoVector, Var +import numpy as np +import math +import warnings + +def eye(shape, dtype="float32"): + ''' Generate 2-D identity matrix. + + Args: + shape (int or tuple of int): + shape of the output matrix + dtype (string): + dtype of the output matrix, default float32 + + Return: + A Jittor Var of identity matrix. + + Example:: + + from jittor import init + print(init.eye(2)) + # output: [[1.,0.],[0.,1.]] + print(init.eye((2,3), "float32")) + # output: [[1.,0.,0.],[0.,1.,0.]] + + ''' + if isinstance(shape, int): + shape = (shape,shape) + assert len(shape)==2, f"len of shape should be 2, but got {shape}" + index = jt.index(shape) + return (index[0]==index[1]).unary(dtype) + +def eye_(var): + ''' Inplace initialize variable with identity matrix. + + Args: + var (Jittor Var): + Var to initialize with identity matrix. + + Return: + var itself. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.eye_(linear.weight) + print(linear.weight) + # output: [[1.,0.],[0.,1.]] + linear.weight.eye_() # This is ok too + + ''' + return var.assign(eye(var.shape, var.dtype)) +Var.eye_ = eye_ + +def constant(shape, dtype="float32", value=0.0): + '''Generate constant Jittor Var. + + Args: + shape (int or tuple of int): + shape of the output Var + dtype (string): + dtype of the output Var, default float32 + value (int or float): + value to be filled in output Var + + Return: + A Jittor Var which filled by constant value. + + Example:: + + from jittor import init + print(init.constant(2)) + # output: [0.,0.] + print(init.constant((2,3), value=1.)) + # output: [[1.,1.,1.],[1.,1.,1.]] + + ''' + return jt.array(value).unary(dtype).broadcast(NanoVector(shape)) + +def constant_(var, value=0.0): + ''' Inplace initialize variable with constant value. + + Args: + var (Jittor Var): + Var to initialize with constant value. + + Return: + var itself. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.constant_(linear.weight) + print(linear.weight) + # output: [[0.,0.],[0.,0.]] + linear.weight.constant_() # This is ok too + + ''' + return var.assign(constant(var.shape, var.dtype, value)) +Var.constant_ = constant_ +fill = Var.fill_ = constant_ + +def zero(shape, dtype="float32"): + '''Generate zero Jittor Var. + + Args: + shape (int or tuple of int): + shape of the output Var + dtype (string): + dtype of the output Var, default float32 + + Return: + A Jittor Var which filled by constant value. + + Example:: + + from jittor import init + print(init.zero(2)) + # output: [0.,0.] + print(init.zero((2,3))) + # output: [[0.,0.,0.],[0.,0.,0.]] + + ''' + return constant(shape, dtype, 0) +def zero_(var): + ''' Inplace initialize variable with zero. + + Args: + var (Jittor Var): + Var to initialize with zero. + + Return: + var itself. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.zero_(linear.weight) + print(linear.weight) + # output: [[0.,0.],[0.,0.]] + linear.weight.zero_() # This is ok too + + ''' + return var.assign(zero(var.shape, var.dtype)) +Var.zero_ = zero_ + +def random_(var): + return var.assign(jt.rand(var.shape, var.dtype)) +Var.random_ = random_ + +def one(shape, dtype="float32"): + '''Generate Jittor Var filled by one. + + Args: + shape (int or tuple of int): + shape of the output Var + dtype (string): + dtype of the output Var, default float32 + + Return: + A Jittor Var which filled by one. + + Example:: + + from jittor import init + print(init.one(2)) + # output: [1.,1.] + print(init.one((2,3))) + # output: [[1.,1.,1.],[1.,1.,1.]] + + ''' + return constant(shape, dtype, 1) +def one_(var): + ''' Inplace initialize variable with one. + + Args: + var (Jittor Var): + Var to initialize with one. + + Return: + var itself. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.one_(linear.weight) + print(linear.weight) + # output: [[1.,1.],[1.,1.]] + linear.weight.one_() # This is ok too + + ''' + return var.assign(one(var.shape, var.dtype)) +Var.one_ = one_ + +def uniform(shape, dtype="float32", low=0, high=1): + '''Generate random uniform Jittor Var. + + Args: + shape (int or tuple of int): + shape of the output Var + dtype (string): + dtype of the output Var, default float32 + low (int or float or Var): + lower bound value of the random uniform + high (int or float or Var): + upper bound value of the random uniform + + Return: + A Jittor Var which filled by random uniform. + + Example:: + + from jittor import init + print(init.uniform(5)) + # output: [0.202268, 0.518688, 0.595274, 0.777354, 0.981979] + print(init.uniform((2,3), low=-1, high=1)) + # output: [[ 0.6647397 0.2801202 -0.01981187] + # [-0.9779438 -0.30149996 0.69056886]] + + ''' + return jt.random(NanoVector(shape), dtype) * (low - high) + high + +def uniform_(var, low=0, high=1): + ''' Inplace initialize Jittor Var by random uniform. + + Args: + var (Jittor Var): + Var to be initialized by random uniform + low (int or float or Var): + lower bound value of the random uniform + high (int or float or Var): + upper bound value of the random uniform + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.uniform_(linear.weight, -1.0, 1.0) + print(linear.weight) + # output: [[ 0.6647397 0.2801202], [-0.9779438 -0.30149996]] + linear.weight.uniform_(-1.0, 1.0) # This is ok too + + ''' + return var.assign(uniform(var.shape, var.dtype, low, high)) +Var.uniform_ = uniform_ + +def gauss(shape, dtype="float32", mean=0.0, std=1.0): + ''' Return Jittor Var initialize by random gauss. + + Args: + shape (int or tuple of int): + shape of the output Var + dtype (string): + dtype of the output Var, default float32 + mean (int or float or Var): + mean value of the random gauss + std (int or float or Var): + std value of the random gauss + + Example:: + + from jittor import init + from jittor import nn + a = init.gauss((2,2), "float32", 0.0, 1.0) + print(a) + + ''' + return jt.random(NanoVector(shape), dtype, "normal") * std + mean + +def gauss_(var, mean=0.0, std=1.0): + ''' Inplace initialize Jittor Var by random gauss. + + Args: + var (Jittor Var): + Var to be initialized by random gauss + mean (int or float or Var): + mean value of the random gauss + std (int or float or Var): + std value of the random gauss + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.gauss_(linear.weight, 0.0, 1.0) + print(linear.weight) + linear.weight.gauss_(0.0, 1.0) # This is ok too + + ''' + return var.assign(gauss(var.shape, var.dtype, mean, std)) +Var.gauss_ = gauss_ +Var.normal_ = gauss_ + +def invariant_uniform(shape, dtype="float32", mode="fan_in"): + ''' Return Jittor initialized Var by invariant_uniform. + + Args: + shape (int or tuple of int): + shape of the output Var + dtype (string): + dtype of the output Var, default float32 + mode (string): + mode selection, should be fan_in or fan_out. + Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass. + + Example:: + + from jittor import init + from jittor import nn + a = init.invariant_uniform_((2,2)) + print(a) + + ''' + assert len(shape)>1 + assert mode=="fan_in" or mode=="fan_out", \ + f"mode not supported, should be fan_in or fan_out, but got {mode}" + + matsize=1 + for i in shape[2:]: + matsize *= i + fan = (shape[1] * matsize) if mode=="fan_in" else (shape[0] * matsize) + bound = math.sqrt(1.0/fan) + return uniform(shape, dtype, -bound, bound) + +def invariant_uniform_(var, mode="fan_in"): + ''' Inplace initialize Jittor Var by invariant_uniform. + + Args: + var (Jittor Var): + Var to be initialized by random invariant_uniform + mode (string): + mode selection, should be fan_in or fan_out. + Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.invariant_uniform_(linear.weight) + print(linear.weight) + linear.weight.invariant_uniform_() # This is ok too + + ''' + var.assign(invariant_uniform(tuple(var.shape), var.dtype, mode)) +Var.invariant_uniform_ = invariant_uniform_ + +def relu_invariant_gauss(shape, dtype="float32", mode="fan_in"): + ''' Return Jittor Var initialized by relu_invariant_gauss. + + Args: + shape (int or tuple of int): + shape of the output Var + dtype (string): + dtype of the output Var, default float32 + mode (string): + mode selection, should be fan_in or fan_out. + Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass. + + Example:: + + from jittor import init + from jittor import nn + a = init.relu_invariant_gauss((2,2)) + print(a) + + ''' + assert len(shape)>1 + assert mode=="fan_in" or mode=="fan_out" + + matsize=1 + for i in shape[2:]: + matsize *= i + fan = (shape[1] * matsize) if mode=="fan_in" else (shape[0] * matsize) + std = math.sqrt(2.0/fan) + return gauss(shape, dtype, 0, std) + +def relu_invariant_gauss_(var, mode="fan_in"): + ''' Inplace initialize Jittor Var by relu_invariant_gauss. + + Args: + var (Jittor Var): + Var to be initialized by random relu_invariant_gauss + mode (string): + mode selection, should be fan_in or fan_out. + Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.relu_invariant_gauss_(linear.weight) + print(linear.weight) + linear.weight.relu_invariant_gauss_() # This is ok too + + ''' + return var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode)) +Var.relu_invariant_gauss_ = relu_invariant_gauss_ + +def calculate_std(var, mode, nonlinearity, param=0.01): + mode = mode.lower() + assert isinstance(param,(int,float)) + assert var.ndim>=2 + assert mode in ['fan_in', 'fan_out'] + + fan = var.shape[1] if mode == 'fan_in' else var.shape[0] + fan *= var[0][0].numel() + + gains = { + 'linear':1, + 'conv1d':1, + 'conv2d':1, + 'conv3d':1, + 'conv_transpose1d':1, + 'conv_transpose2d':1, + 'conv_transpose3d':1, + 'sigmoid':1, + 'tanh':5.0/3, + 'relu':math.sqrt(2.0), + 'leaky_relu':math.sqrt(2.0 / (1 + param ** 2)), + } + gain = gains[nonlinearity] + std = gain/math.sqrt(fan) + return std + + +def kaiming_uniform_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'): + ''' Inplace initialize Jittor Var by kaiming_uniform. + + Args: + var (Jittor Var): + Var to be initialized by random kaiming_uniform + a (float): + the negative slope of the rectifier used after this layer (only used with 'leaky_relu') + mode (string): + mode selection, should be fan_in or fan_out. + Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass. + nonlinearity (string): + nonlinearity used after this layer. + It can be one of [linear, conv*, sigmoid, tanh, relu, leaky_relu]. + leaky_relu is used by default. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.kaiming_uniform_(linear.weight) + print(linear.weight) + linear.weight.kaiming_uniform_() # This is ok too + + ''' + std = calculate_std(var,mode,nonlinearity,a) + bound = math.sqrt(3.0) * std + return uniform_(var,-bound, bound) +Var.kaiming_uniform_ = kaiming_uniform_ + +def kaiming_normal_(var, a=0, mode='fan_in', nonlinearity='leaky_relu'): + ''' Inplace initialize Jittor Var by kaiming_normal. + + Args: + var (Jittor Var): + Var to be initialized by random kaiming_normal + a (float): + the negative slope of the rectifier used after this layer (only used with 'leaky_relu') + mode (string): + mode selection, should be fan_in or fan_out. + Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass. Choosing 'fan_out' preserves the magnitudes in the backwards pass. + nonlinearity (string): + nonlinearity used after this layer. + It can be one of [linear, conv*, sigmoid, tanh, relu, leaky_relu]. + leaky_relu is used by default. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.kaiming_normal_(linear.weight) + print(linear.weight) + linear.weight.kaiming_normal_() # This is ok too + + ''' + std = calculate_std(var,mode,nonlinearity,a) + return gauss_(var,0, std) +Var.kaiming_normal_ = kaiming_normal_ + + +def xavier_uniform(shape, dtype="float32", gain=1.0): + ''' Inplace initialize Jittor Var by xavier_uniform. + The resulting var will have values sampled from + :math:`uniform(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} + + Args: + shape (int or tuple of int): + shape of the return Var. + dtype (string): + dtype of the return Var, default float32. + gain (float): + an optional scaling factor. + + Example:: + + from jittor import init + from jittor import nn + a = init.xavier_uniform((2,2), gain=init.calculate_gain('relu')) + print(a) + ''' + assert len(shape)>1 + + matsize=1 + for i in shape[2:]: + matsize *= i + fan = (shape[1] * matsize) + (shape[0] * matsize) + bound = gain * math.sqrt(6.0/fan) + return uniform(shape, dtype, -bound, bound) + +def xavier_uniform_(var, gain=1.0): + ''' Inplace initialize Jittor Var by xavier_uniform. + The resulting var will have values sampled from + :math:`uniform(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}} + + Args: + var (Jittor Var): + Var to be initialized by random xavier_uniform + gain (float): + an optional scaling factor. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.xavier_uniform_(linear.weight, init.calculate_gain('relu')) + print(linear.weight) + linear.weight.xavier_uniform_() # This is ok too + + ''' + return var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain)) +Var.xavier_uniform_ = xavier_uniform_ + +def xavier_gauss(shape, dtype="float32", gain=1.0): + ''' Return Jittor Var initialized by xavier_gauss, a.k.a xavier_normal. + The resulting var will have values sampled from + :math:`gauss(-a, a)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} + + Args: + shape (int or tuple of int): + shape of the return Var. + dtype (string): + dtype of the return Var, default float32. + gain (float): + an optional scaling factor. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.xavier_gauss_(linear.weight, init.calculate_gain('relu')) + print(linear.weight) + linear.weight.xavier_gauss_() # This is ok too + + ''' + assert len(shape)>1 + + matsize=1 + for i in shape[2:]: + matsize *= i + fan = (shape[1] * matsize) + (shape[0] * matsize) + std = gain * math.sqrt(2.0/fan) + return gauss(shape, dtype, 0, std) + +def xavier_gauss_(var, gain=1.0): + ''' Inplace initialize Jittor Var by xavier_gauss, a.k.a xavier_normal. + The resulting var will have values sampled from + :math:`gauss(-a, a)` where + + .. math:: + \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}} + + Args: + var (Jittor Var): + Var to be initialized by random xavier_gauss + gain (float): + an optional scaling factor. + + Example:: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.xavier_gauss_(linear.weight, init.calculate_gain('relu')) + print(linear.weight) + linear.weight.xavier_gauss_() # This is ok too + + ''' + return var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain)) +Var.xavier_gauss_ = xavier_gauss_ + +def calculate_gain(nonlinearity, param=None): + r"""Return the recommended gain value for the given nonlinearity function. + The values are as follows: + + ================= ==================================================== + nonlinearity gain + ================= ==================================================== + Linear / Identity :math:`1` + Conv{1,2,3}D :math:`1` + Sigmoid :math:`1` + Tanh :math:`\frac{5}{3}` + ReLU :math:`\sqrt{2}` + Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}` + SELU :math:`\frac{3}{4}` + ================= ==================================================== + + Args: + nonlinearity: the non-linear function (`nn.functional` name) + param: optional parameter for the non-linear function + + Examples: + >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2 + + .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html + """ + linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d'] + if nonlinearity in linear_fns or nonlinearity == 'sigmoid': + return 1 + elif nonlinearity == 'tanh': + return 5.0 / 3 + elif nonlinearity == 'relu': + return math.sqrt(2.0) + elif nonlinearity == 'leaky_relu': + if param is None: + negative_slope = 0.01 + elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float): + # True/False are instances of int, hence check above + negative_slope = param + else: + raise ValueError("negative_slope {} not a valid number".format(param)) + return math.sqrt(2.0 / (1 + negative_slope ** 2)) + elif nonlinearity == 'selu': + return 3.0 / 4 + else: + raise ValueError("Unsupported nonlinearity {}".format(nonlinearity)) + + +def trunc_normal_(var, mean=0., std=1., a=-2., b=2.): + # type: (jt.jittor_core.Var, float, float, float, float) -> jt.jittor_core.Var + r"""Fills the input jt.jittor_core.Var with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + var: an n-dimensional `jt.jittor_core.Var` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + + from jittor import init + from jittor import nn + linear = nn.Linear(2,2) + init.trunc_normal_(linear.weight, std=.02) + print(linear.weight) + linear.weight.trunc_normal_(std=.02) # This is ok too + """ + return var.assign(_no_grad_trunc_normal_(var, mean, std, a, b)) +Var.trunc_normal_ = trunc_normal_ + +def _no_grad_trunc_normal_(var, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + # var.uniform(2 * l - 1, 2 * u - 1) + var.uniform_(low=2 * l - 1, high=2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + var = var.erfinv() + + # Transform to proper mean, std + var = var.multiply(std * math.sqrt(2.)) + var = var.add(mean) + + # Clamp to ensure it's in the proper range + var = var.clamp(min_v=a, max_v=b) + return var \ No newline at end of file diff --git a/python/jittor/init_cupy.py b/python/jittor/init_cupy.py new file mode 100644 index 00000000..11e6bc7f --- /dev/null +++ b/python/jittor/init_cupy.py @@ -0,0 +1,52 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +has_cupy = 0 +try: + import cupy as cp + has_cupy = 1 +except: + pass +if has_cupy: + import jittor as jt + import os + import ctypes + device_num = 0 + if jt.mpi: + device_num = jt.mpi.local_rank() + device_num = device_num % cp.cuda.runtime.getDeviceCount() + cupy_device = cp.cuda.Device(device_num) + cupy_device.__enter__() + + def cvt(a): + a_pointer, read_only_flag = a.__array_interface__['data'] + aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a, device_num),0) + a = cp.ndarray(a.shape,a.dtype,aptr) + return a + + def numpy2cupy(snp, data): + for key in data: + if isinstance(data[key], list): + for i in range(len(data[key])): + data[key][i]=cvt(data[key][i]) + elif isinstance(data[key], int): + pass + else: + data[key]=cvt(data[key]) + + jt_allocator = ctypes.CDLL(os.path.join( + jt.compiler.cache_path, + "jittor_core"+jt.compiler.extension_suffix), + os.RTLD_NOW | os.RTLD_GLOBAL) + malloc = jt_allocator.get_jittor_cuda_malloc() + free = jt_allocator.get_jittor_cuda_free() +else: + def numpy2cupy(snp, data): + pass \ No newline at end of file diff --git a/python/jittor/linalg.py b/python/jittor/linalg.py new file mode 100644 index 00000000..2053d86c --- /dev/null +++ b/python/jittor/linalg.py @@ -0,0 +1,751 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from functools import partial +from .nn import ComplexNumber + +def complex_inv(x:ComplexNumber): + r""" + calculate the inverse of x. + :param x (...,M,M): + :return:x^-1 (...,M,M). + + TODO: Faster Implementation; Check backward. + """ + assert isinstance(x, ComplexNumber), "complex_inv is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_inv" + + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + + a = _stack_to_complex(data["inputs"][0]) + m_a = data["outputs"][0] + t_a = np.linalg.inv(a) + np.copyto(m_a, _complex_to_stack(t_a)) + + + def backward_code(np, data): + def T(x): + return np.conj(np.swapaxes(x, -1, -2)) + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = _stack_to_complex(data["dout"]) + out = data["outputs"][0] + mx = _stack_to_complex(data["f_outputs"][0]) + t = -_dot(_dot(T(mx), dout), T(mx)) + np.copyto(out, _complex_to_stack(t)) + + lmx = jt.numpy_code( + x.value.shape, + x.value.dtype, + [x.value], + forward_code, + [backward_code], + ) + + return ComplexNumber(lmx, is_concat_value=True) + +def complex_eig(x:ComplexNumber): + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return:w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. + """ + assert isinstance(x, ComplexNumber), "complex_eig is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for complex_eig" + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + w, v = data["outputs"] + tw, tv = np.linalg.eig(a) + np.copyto(w, _complex_to_stack(tw)) + np.copyto(v, _complex_to_stack(tv)) + + def backward_code(np, data): + raise NotImplementedError + + sw = x.shape[:-2] + x.shape[-1:] + (2,) + sv = x.value.shape + w, v = jt.numpy_code( + [sw, sv], + [x.value.dtype, x.value.dtype], + [x.value], + forward_code, + [backward_code], + ) + return ComplexNumber(w, is_concat_value=True), ComplexNumber(v, is_concat_value=True) + +def complex_qr(x): + r""" + do the qr factorization of x in the below formula: + x = QR where Q is orthogonal matrix and R is upper-triangle matrix. + :param x (...,M,M): + :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). + """ + assert isinstance(x, ComplexNumber), "linalg_qr is implemented for nn.ComplexNumber" + assert x.real.dtype == jt.float32 and x.imag.dtype == jt.float32, "real and imag in ComplexNumber should be jt.float32" + assert x.shape[-2] == x.shape[-1], "only square matrix is supported for linalg_qr" + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + qr = data["outputs"][0] + Q, R = np.linalg.qr(a) + QR = np.stack([Q, R], axis=0) + np.copyto(qr, _complex_to_stack(QR)) + + def backward_code(np, data): + # reference: https://github.com/tencent-quantum-lab/tensorcircuit/blob/master/tensorcircuit/backends/pytorch_ops.py + def H(x): + return np.conj(np.swapaxes(x, -1, -2)) + def _TriangularSolve(x, r): + return H(np.linalg.solve(r, H(x))) + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + _dot = partial(np.einsum, '...ij,...jk->...ik') + _diag = partial(np.einsum, '...ii->...i') + + dout = data["dout"] + out = data["outputs"][0] + qr = data["f_outputs"][0] + dout = _stack_to_complex(dout) + dq, dr = dout[0], dout[1] + qr = _stack_to_complex(qr) + q, r = qr[0], qr[1] + + + qdq = _dot(H(q), dq) + qdq_ = qdq - H(qdq) + rdr = _dot(r, H(dr)) + rdr_ = rdr - H(rdr) + tril = np.tril(qdq_ + rdr_) + + grad_a = _dot(q, dr + _TriangularSolve(tril, r)) + grad_b = _TriangularSolve(dq - _dot(q, qdq), r) + ret = grad_a + grad_b + + m = rdr - H(qdq) + eyem = np.zeros_like(m) + _diag(eyem)[:] = _diag(m) + correction = eyem - np.real(eyem) + ret = ret + _TriangularSolve(_dot(q, H(correction)), r) + + ret = _complex_to_stack(ret) + np.copyto(out,ret) + + qr = jt.numpy_code( + (2,) + x.value.shape, + x.value.dtype, + [x.value], + forward_code, + [backward_code], + ) + q, r = qr[0], qr[1] + return ComplexNumber(q, is_concat_value=True), ComplexNumber(r, is_concat_value=True) + +def complex_svd(x:ComplexNumber): + r''' + calculate the Singular Value Decomposition of x.It follows the below fomula: + x = usv* + only support full matrices == False ver now, which means: + x's shape (...,M,K) + u's shape (...,M,K) + s's shape (...,K) + v's shape (...,K,N) + where K is min(M,N). + :param x: + :return:u,s,v. + ''' + def forward_code(np, data): + def _stack_to_complex(x): + return x[..., 0] + 1j * x[..., 1] + def _complex_to_stack(x): + return np.stack([np.real(x), np.imag(x)], axis=-1) + a = _stack_to_complex(data["inputs"][0]) + u, s, v = data["outputs"] + #TODO:remove copyto + tu, ts, tv = np.linalg.svd(a, full_matrices=0) + np.copyto(u, _complex_to_stack(tu)) + np.copyto(s, _complex_to_stack(ts)) + np.copyto(v, _complex_to_stack(tv)) + + def backward_code(np, data): + raise NotImplementedError + + m, n = x.shape[-2:] + k = min(m, n) + s1 = list(x.shape) + s1[-1] = k + s2 = list(x.shape) + s2[-2] = k + s3 = list(x.shape)[:-2] + s3.append(k) + s1.append(2) + s2.append(2) + s3.append(2) + u, s, v = jt.numpy_code( + [s1, s3, s2], + [x.value.dtype, x.value.dtype, x.value.dtype], + [x.value], + forward_code, + [backward_code], + ) + return ComplexNumber(u, is_concat_value=True), \ + ComplexNumber(s, is_concat_value=True), \ + ComplexNumber(v, is_concat_value=True) + +#TODO:full_matrices=1 +def svd(x): + r''' + calculate the Singular Value Decomposition of x.It follows the below fomula: + x = usv* + only support full matrices == False ver now, which means: + x's shape (...,M,K) + u's shape (...,M,K) + s's shape (...,K) + v's shape (...,K,N) + where K is min(M,N). + :param x: + :return:u,s,v. + ''' + if isinstance(x, ComplexNumber): + return complex_svd(x) + def forward_code(np, data): + a = data["inputs"][0] + u, s, v = data["outputs"] + #TODO:remove copyto + tu, ts, tv = np.linalg.svd(a, full_matrices=0) + np.copyto(u, tu) + np.copyto(s, ts) + np.copyto(v, tv) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + inp = data["inputs"][0] + out_index = data["out_index"] + u, s, v = data["f_outputs"] + v = T(v) + m, n = inp.shape[-2:] + k = np.min((m, n)) + i = np.reshape(np.eye(k), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k)))) + if out_index == 0: + f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i) + gu = dout + utgu = _dot(T(u), gu) + t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :] + t = _dot(_dot(u, t), T(v)) + if m > n: + i_minus_uut = (np.reshape(np.eye(m), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (m, m)))) - + _dot(u, np.conj(T(u)))) + t = t + T(_dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut)) + np.copyto(out, t) + elif out_index == 1: + gs = dout + t = i * gs[..., :, np.newaxis] + t = _dot(_dot(u, t), T(v)) + np.copyto(out, t) + elif out_index == 2: + f = 1 / (s[..., np.newaxis, :] ** 2 - s[..., :, np.newaxis] ** 2 + i) + gv = dout + vtgv = _dot(T(v), gv) + t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv))) + t = _dot(_dot(u, t), T(v)) + if m < n: + i_minus_vvt = (np.reshape(np.eye(n), np.concatenate((np.ones(inp.ndim - 2, dtype=int), (n, n)))) - + _dot(v, np.conj(T(v)))) + t = t + T(_dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt)) + np.copyto(out, t) + + m, n = x.shape[-2:] + k = min(m, n) + s1 = list(x.shape) + s1[-1] = k + s2 = list(x.shape) + s2[-2] = k + s3 = list(x.shape)[:-2] + s3.append(k) + u, s, v = jt.numpy_code( + [s1, s3, s2], + [x.dtype, x.dtype, x.dtype], + [x], + forward_code, + [backward_code], + ) + return u, s, v + +def eig(x): + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return (ComplexNumber):w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. + """ + if isinstance(x, ComplexNumber): + return complex_eig(x) + return complex_eig(ComplexNumber(x)) + +def eigh(x): + r""" + calculate the eigenvalues and eigenvectors of x. + :param x (...,M,M): + :return:w, v. + w (...,M) : the eigenvalues. + v (...,M,M) : normalized eigenvectors. + """ + def forward_code(np, data): + a = data["inputs"][0] + w, v = data["outputs"] + tw, tv = np.linalg.eigh(a, UPLO='L') + np.copyto(w, tw) + np.copyto(v, tv) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + inp = data["inputs"][0] + out_index = data["out_index"] + w, v = data["f_outputs"] + k = int(inp.shape[-1]) + w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1) + if out_index == 0: + t = _dot(v * dout[..., np.newaxis, :], T(v)) + np.copyto(out, t) + elif out_index == 1: + if np.any(dout): + off_diag = np.ones((k, k)) - np.eye(k) + F = off_diag / (T(w_repeated) - w_repeated + np.eye(k)) + t = _dot(_dot(v, F * _dot(T(v), dout)), T(v)) + np.copyto(out, t) + + sw = x.shape[:-2] + x.shape[-1:] + sv = x.shape + w, v = jt.numpy_code( + [sw, sv], + [x.dtype, x.dtype], + [x], + forward_code, + [backward_code], + ) + return w, v + + +def inv(x): + r""" + calculate the inverse of x. + :param x (...,M,M): + :return:x^-1 (...,M,M). + """ + if isinstance(x, ComplexNumber): + return complex_inv(x) + def forward_code(np, data): + a = data["inputs"][0] + m_a = data["outputs"][0] + t_a = np.linalg.inv(a) + np.copyto(m_a, t_a) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + lmx = data["f_outputs"] + mx = lmx[0] + t = -_dot(_dot(T(mx), dout), T(mx)) + np.copyto(out, t) + + lmx = jt.numpy_code( + [x.shape], + [x.dtype], + [x], + forward_code, + [backward_code], + ) + mx = lmx[0] + return mx + + +def pinv(x): + r""" + calculate the pseudo-inverse of a x. + :param x (...,M,N) + :return: x's pinv (...N,M) + """ + def forward_code(np, data): + a = data["inputs"][0] + m_a = data["outputs"][0] + t_a = np.linalg.pinv(a) + np.copyto(m_a, t_a) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + inp = data["inputs"][0] + lmx = data["f_outputs"] + mx = lmx[0] + t = T( + -_dot(_dot(mx, T(dout)), mx) + + _dot(_dot(_dot(mx, T(mx)), dout), np.eye(inp.shape[-2]) - _dot(inp, mx)) + + _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout), T(mx)), mx) + ) + np.copyto(out, t) + sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]] + lmx = jt.numpy_code( + [sw], + [x.dtype], + [x], + forward_code, + [backward_code], + ) + mx = lmx[0] + return mx + + +def det(x): + r""" + calculate the determinant of x. + :param x (...,M,M): + :return:|x| (...,1) + """ + def forward_code(np, data): + a = data["inputs"][0] + L = data["outputs"][0] + tL = np.linalg.det(a) + np.copyto(L, tL) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + f_out = data["f_outputs"][0] + inp = data["inputs"][0] + n_d = np.reshape(dout, np.shape(dout) + (1, 1)) + n_o = np.reshape(f_out, np.shape(f_out) + (1, 1)) + s = n_d * n_o * T(np.linalg.inv(inp)) + np.copyto(out, s) + + s = x.shape + x_s = s[:-2] + if len(s) == 2: + x_s.append(1) + l_det = jt.numpy_code( + [x_s], + [x.dtype], + [x], + forward_code, + [backward_code], + ) + det = l_det[0] + return det + + +def slogdet(x): + r""" + calculate the sign and log of the determinant of x. + :param x (...,M,M): + :return sign, x's logdet. + sign array decides the sign of determinant and their values can be -1,0,1.Only Real number now.0 means det is 0 and logdet is -inf. + logdet in shape (...,1). + """ + def forward_code(np, data): + a = data["inputs"][0] + sign, m_a = data["outputs"] + sign_, t_a = np.linalg.slogdet(a) + np.copyto(m_a, t_a) + np.copyto(sign, sign_) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + inp = data["inputs"][0] + out_index = data["out_index"] + if out_index == 0: + np.copyto(out, 0) + if out_index == 1: + t = np.reshape(dout, np.shape(dout) + (1, 1)) + t = t * T(np.linalg.inv(inp)) + np.copyto(out, t) + + s = x.shape + det_s = s[:-2] + if len(det_s) == 0: + det_s.append(1) + sign, mx = jt.numpy_code( + [det_s, det_s], + [x.dtype, x.dtype], + [x], + forward_code, + [backward_code], + ) + return sign, mx + + +def cholesky(x): + r""" + do Cholesky decomposition of x in the form of below formula: + x = LL^T + x must be a Hermite and positive-definite matrix. L is a lower-triangular matrix. + :param x (...,M,M): + :return: L (...,M,M). + """ + def forward_code(np, data): + a = data["inputs"][0] + L = data["outputs"][0] + tL = np.linalg.cholesky(a) + np.copyto(L, tL) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + f_out = data["f_outputs"][0] + solve_trans = lambda a, b: np.linalg.solve(T(a), b) + phi = lambda X: np.tril(X) / (1. + np.eye(X.shape[-1])) + + def conjugate_solve(L, X): + return solve_trans(L, T(solve_trans(L, T(X)))) + + s = conjugate_solve(f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout))) + s = (s + T(s)) / 2. + np.copyto(out, s) + + lL = jt.numpy_code( + [x.shape], + [x.dtype], + [x], + forward_code, + [backward_code], + ) + L = lL[0] + return L + + +def solve(a,b): + r""" + Solve a linear matrix equation Ax = B.This is done by calculating x = A^-1B.So A must not be singular. + :param a:(...,M,M) + :param b:(...,M) + :return:solution of Ax = b formula.x in the shape of (...M) + """ + def forward_code(np, data): + a, b = data["inputs"] + L = data["outputs"][0] + ans = np.linalg.solve(a, b) + np.copyto(L, ans) + + def backward_code1(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + dout = data["dout"] + out = data["outputs"][0] + f_out = data["f_outputs"][0] + inp = data["inputs"][0] + updim = lambda x: x if x.ndim == a.ndim else x[..., None] + t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out))) + np.copyto(out, t) + + def backward_code2(np, data): + out = data["outputs"][0] + np.copyto(out, 0) + + l_ans = jt.numpy_code( + [b.shape], + [b.dtype], + [a, b], + forward_code, + [backward_code1, backward_code2], + ) + ans = l_ans[0] + return ans + + +def qr(x): + r""" + do the qr factorization of x in the below formula: + x = QR where Q is orthogonal matrix and R is upper-triangle matrix. + :param x (...,M,M): + :return:q,r as the result of qr factorization.They are both in the shape of (...,M,M). + """ + if isinstance(x, ComplexNumber): + return complex_qr(x) + def forward_code(np, data): + a = data["inputs"][0] + q, r = data["outputs"] + Q, R = np.linalg.qr(a) + np.copyto(q,Q) + np.copyto(r,R) + + def backward_code(np, data): + def T(x): + return np.swapaxes(x, -1, -2) + _dot = partial(np.einsum, '...ij,...jk->...ik') + _harmard = partial(np.einsum, '...ij,...ij->...ij') + dout = data["dout"] + out = data["outputs"][0] + q, r = data["f_outputs"] + out_index = data["out_index"] + #pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags + if out_index == 0: # Q_TERM + q_t = _dot(T(q),dout) + rhs_solve = q_t - T(q_t) + rhs_solve = T(np.tril(rhs_solve,-1)) + qsolve = np.linalg.solve(r,rhs_solve) + qsolve = T(qsolve) + tq = _dot(q,qsolve) + np.copyto(out,tq) + else: #R_TERM + r_t = _dot(r ,T(dout)) + rhs_solve = r_t - T(r_t) + rhs_solve = np.tril(rhs_solve,-1) + rhs_solve = T(rhs_solve) + r_solve = np.linalg.solve(r,rhs_solve) + tr = _dot(q,(T(r_solve) + dout)) + np.copyto(out,tr) + + q, r = jt.numpy_code( + [x.shape,x.shape], + [x.dtype,x.dtype], + [x], + forward_code, + [backward_code], + ) + return q, r + + +def einsum(string, *args): + r""" + do the einsum operation. Using the implementation in https://github.com/HIPS/autograd + :param string, args: + :return: return values depend on the input string kinds. + """ + import numpy as np_cpu + if string == "i,j->ij": + return args[0].broadcast((args[0].shape[0], args[1].shape[0]), dims=[1]).multiply(args[1]) + def forward_code(np, data): + out = data["outputs"][0] + npout = np.einsum(string, *data["inputs"]) + np.copyto(out, npout) + + def backward_code(np, data, argnum=0): + real_len = len(data["inputs"]) - 2 + operands = data["inputs"][:real_len] + _ops = operands + if np_cpu is not np: + # fake a numpy array + _ops = [ np_cpu.zeros((1,)*o.ndim) for o in _ops ] + in_subs, out_subs, _ = np_cpu.core.einsumfunc._parse_einsum_input([string] + _ops) + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + inp = data["inputs"][argnum] + c = data["f_outputs"] + + in_subs_list = in_subs.split(',') + op_num = argnum + subs_wrt = in_subs_list[op_num] + rest_of_ops = operands[:op_num] + operands[op_num+1:] + rest_of_subs = in_subs_list[:op_num] + in_subs_list[op_num+1:] + other_named_subs = set(''.join([out_subs] + rest_of_subs)) + naked_summed = [(i, sub) for i, sub in enumerate(subs_wrt) + if sub not in other_named_subs] + if naked_summed: + naked_summed_dims, ones_subs = zip(*naked_summed) + ones_subs = ''.join(ones_subs) + ones = np_cpu.ones(np_cpu.array(operands[op_num].shape)[list(naked_summed_dims)]) + new_input_subs = ','.join([out_subs, ones_subs] + rest_of_subs) + new_operands = [dout, ones] + rest_of_ops + else: + new_input_subs = ','.join([out_subs] + rest_of_subs) + new_operands = [dout] + rest_of_ops + + new_subscripts = new_input_subs + '->' + subs_wrt + x = np.einsum(new_subscripts, *new_operands) + while np.ndim(x) > np.ndim(inp): + x = np.sum(x, axis=broadcast_idx) + for axis, size in enumerate(inp.shape): + if size == 1: + x = np.sum(x, axis=axis, keepdims=True) + np.copyto(out, x) + + def einsum_outshape(einsum_expr, inputs): + shps = np_cpu.concatenate([in_.shape for in_ in inputs]) + p = einsum_expr.replace(" ", "").split(',') + s = p[:-1] + p[-1].split('->') + rec_shape = [] + ellip_expr = None + const_rep = '1234567890' # assume tensor shape no more than 10 dimensions + for idx, expr in enumerate(s[:-1]): + if "..." in expr: + assert "..." in s[-1] + else: + continue + shp = inputs[idx].shape + ellipsis_pos = len(expr.replace("...", "")) + nellip_expr = const_rep[0 : len(shp) - ellipsis_pos] + if ellip_expr is None: + ellip_expr = nellip_expr + else: + assert ellip_expr == nellip_expr, "Please keep broadcast ellipsis record the same ellipsis." + s[idx] = expr.replace("...", ellip_expr) + if ellip_expr: + s[-1] = s[-1].replace("...", ellip_expr) + if s[-1]=='': + return () + else: + inop = list(map(list,s)) + return tuple(shps[(np_cpu.concatenate(inop[:-1])[:,None]==inop[-1]).argmax(0)].astype(np_cpu.int64)) + + output_shape = [int(x) for x in einsum_outshape(string, args)] + backwards = [partial(backward_code, argnum=idx) for idx in range(len(args))] + a = jt.numpy_code( + [output_shape], + [args[0].dtype], + args, + forward_code, + backwards, + )[0] + return a \ No newline at end of file diff --git a/python/jittor/loss3d/__init__.py b/python/jittor/loss3d/__init__.py new file mode 100644 index 00000000..4081ce05 --- /dev/null +++ b/python/jittor/loss3d/__init__.py @@ -0,0 +1,2 @@ +from .chamfer import chamfer_loss, ChamferLoss +from .emd import earth_mover_distance, EarthMoverDistance diff --git a/python/jittor/loss3d/chamfer.py b/python/jittor/loss3d/chamfer.py new file mode 100644 index 00000000..c0864d63 --- /dev/null +++ b/python/jittor/loss3d/chamfer.py @@ -0,0 +1,153 @@ +# Author: Zheng-Ning Liu +# +# This file implements chamfer loss on both CPU and GPU. +# The implementation does no use extra NxM matrix to store distances, and thus +# supports large point clouds. + +import jittor as jt +import jittor.nn as nn + +cpu_src = ''' + for (int bs = 0; bs < in0_shape0; ++bs) + for (int i = 0; i < in0_shape1; ++i) { + float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * (@in0(bs, i, 0) - @in1(bs, 0, 0)) + + (@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) + + (@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2)); + @out(bs, i) = 0; + for (int j = 1; j < in1_shape1; ++j) { + float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) + + (@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) + + (@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2)); + if (dis < min_dis) { + min_dis = dis; + @out(bs, i) = j; + } + } + } +''' + +cuda_src = ''' + __global__ void chamfer_loss_min_idx_kernel(@ARGS_DEF) { + @PRECALC + int bs = blockIdx.x; + int n = in0_shape1; + int m = in1_shape1; + + for (int i = threadIdx.x; i < n; i += blockDim.x) { + float min_dis = (@in0(bs, i, 0) - @in1(bs, 0, 0)) * (@in0(bs, i, 0) - @in1(bs, 0, 0)) + + (@in0(bs, i, 1) - @in1(bs, 0, 1)) * (@in0(bs, i, 1) - @in1(bs, 0, 1)) + + (@in0(bs, i, 2) - @in1(bs, 0, 2)) * (@in0(bs, i, 2) - @in1(bs, 0, 2)); + @out(bs, i) = 0; + for (int j = 1; j < m; ++j) { + float dis = (@in0(bs, i, 0) - @in1(bs, j, 0)) * (@in0(bs, i, 0) - @in1(bs, j, 0)) + + (@in0(bs, i, 1) - @in1(bs, j, 1)) * (@in0(bs, i, 1) - @in1(bs, j, 1)) + + (@in0(bs, i, 2) - @in1(bs, j, 2)) * (@in0(bs, i, 2) - @in1(bs, j, 2)); + if (dis < min_dis) { + min_dis = dis; + @out(bs, i) = j; + } + } + } + } + + chamfer_loss_min_idx_kernel<<>>(@ARGS); +''' + + +def chamfer_loss(pc1, pc2, reduction='mean', dims='BNC', bidirectional=False): + ''' return the chamfer loss from pc1 to pc2. + + :param pc1: input point cloud + :type pc1: jittor array + + :param pc2: input point cloud + :type pc2: jittor array + + :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'. + :type reduction: str, optional + + :param dims: a string that represents each dimension, can be + '[BNC]' ([batch, number of points, xyz]), or + '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'. + :type dims: str, optional + + Example: + + >>> import jittor as jt + >>> from jittor.loss3d import chamfer_loss + >>> jt.flags.use_cuda = True + >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> cf = chamfer_loss(pc1, pc2, dims='BNC', bidirectional=True) + >>> print('chamfer loss =', cf.item()) + ''' + if bidirectional: + return chamfer_loss(pc1, pc2, reduction, dims) + chamfer_loss(pc2, pc1, reduction, dims) + + assert dims in ['BNC', 'BCN'] + if dims == 'BCN': + pc1, pc2 = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1) + + batch_size_1, N, _ = pc1.shape + batch_size_2, M, _ = pc2.shape + assert batch_size_1 == batch_size_2 + batch_size = batch_size_1 + + idx = jt.code([batch_size, N], 'int32', [pc1, pc2], + cpu_src=cpu_src, + cuda_src=cuda_src) + + nearest_pts = pc2.reindex([batch_size, idx.shape[1], 3], [ + 'i0', + '@e0(i0, i1)', + 'i2' + ], extras=[idx]) + + chamfer_distance = (((pc1 - nearest_pts) ** 2).sum(dim=-1)).sqrt() + if reduction is None: + return chamfer_distance + elif reduction == 'sum': + return jt.sum(chamfer_distance) + elif reduction == 'mean': + return jt.mean(chamfer_distance) + + +class ChamferLoss(nn.Module): + ''' A loss layer that computes the chamfer loss from pc1 to pc2. + + :param pc1: input point cloud + :type pc1: jittor array + + :param pc2: input point cloud + :type pc2: jittor array + + :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'. + :type reduction: str, optional + + :param dims: a string that represents each dimension, can be + '[BNC]' ([batch, number of points, xyz]), or + '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'. + :type dims: str, optional + + Example: + + >>> import jittor as jt + >>> from jittor.loss3d import ChamferLoss + >>> jt.flags.use_cuda = True + >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> CF = ChamferLoss(dims='BNC', bidirectional=True) + >>> cf = CF(pc1, pc2) + >>> print('chamfer loss =', cf.item()) + ''' + + def __init__(self, reduction='mean', dims='BNC', bidirectional=False): + ''' see function @chamfer_loss + ''' + super().__init__() + self.reduction = reduction + self.dims = dims + self.bidirectional = bidirectional + + def execute(self, pc1, pc2): + return chamfer_loss(pc1, pc2, self.reduction, self.dims, self.bidirectional) diff --git a/python/jittor/loss3d/emd.py b/python/jittor/loss3d/emd.py new file mode 100644 index 00000000..df0b60b5 --- /dev/null +++ b/python/jittor/loss3d/emd.py @@ -0,0 +1,440 @@ +# Author: Zheng-Ning Liu +# +# The gpu implementation is original provided by Haoqiang Fan and Kaichun Mo, +# . + +import jittor as jt +from jittor import Function + +EMD_gpu_header = ''' +namespace jittor { +__device__ inline out_type dist2(out_type x1, out_type y1, out_type z1, + out_type x2, out_type y2, out_type z2) { + return (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); +} +} +''' + +approxmatch_gpu_src = ''' + __global__ void approxmatch_gpu_kernel(@ARGS_DEF) { + @PRECALC + @alias(xyz1, in0) + @alias(xyz2, in1) + @alias(match, out) + + int b = in0_shape0; + int n = in0_shape1; + int m = in1_shape1; + + out_type *remainL = in2_p + blockIdx.x * (n + m) * 2; + out_type *remainR = remainL + n; + out_type *ratioL = remainR + m; + out_type *ratioR = ratioL + n; + + const int Block = 1024; + __shared__ out_type buf[Block * 4]; + + for (int i = blockIdx.x; i < b; i += gridDim.x) { + for (int j = threadIdx.x; j < n * m; j += blockDim.x) + match_p[i * n * m + j] = 0; + for (int j = threadIdx.x; j < n; j += blockDim.x) + remainL[j] = n >= m ? 1 : m / n; + for (int j = threadIdx.x; j < m; j += blockDim.x) + remainR[j] = n >= m ? n / m : 1; + __syncthreads(); + + for (int j = 7; j >= -2; j--) { + out_type level = j > -2 ? -powf(4.0f, j) : 0; + + for (int k0 = 0; k0 < n; k0 += blockDim.x) { + int k = k0 + threadIdx.x; + out_type x1 = 0, y1 = 0, z1 = 0; + if (k < n) { + x1 = @xyz1(i, k, 0); + y1 = @xyz1(i, k, 1); + z1 = @xyz1(i, k, 2); + } + + out_type suml = 1e-9f; + for (int l0 = 0; l0 < m; l0 += Block){ + int lend = min(m, l0 + Block) - l0; + for (int l = threadIdx.x; l < lend; l += blockDim.x) { + buf[l * 4 + 0] = @xyz2(i, l0 + l, 0); + buf[l * 4 + 1] = @xyz2(i, l0 + l, 1); + buf[l * 4 + 2] = @xyz2(i, l0 + l, 2); + buf[l * 4 + 3] = remainR[l0 + l]; + } + __syncthreads(); + + for (int l = 0; l < lend; l++){ + out_type x2 = buf[l * 4 + 0]; + out_type y2 = buf[l * 4 + 1]; + out_type z2 = buf[l * 4 + 2]; + out_type d = level * dist2(x1, y1, z1, x2, y2, z2); + out_type w = __expf(d) * buf[l * 4 + 3]; + suml += w; + } + __syncthreads(); + } + if (k < n) + ratioL[k] = remainL[k] / suml; + } + __syncthreads(); + + for (int l0 = 0; l0 < m; l0 += blockDim.x){ + int l = l0 + threadIdx.x; + out_type x2 = 0, y2 = 0, z2 = 0; + if (l < m){ + x2 = @xyz2(i, l, 0); + y2 = @xyz2(i, l, 1); + z2 = @xyz2(i, l, 2); + } + out_type sumr = 0; + for (int k0 = 0; k0 < n; k0 += Block){ + int kend = min(n, k0 + Block) - k0; + for (int k = threadIdx.x; k < kend; k += blockDim.x){ + buf[k * 4 + 0] = @xyz1(i, k0 + k, 0); + buf[k * 4 + 1] = @xyz1(i, k0 + k, 1); + buf[k * 4 + 2] = @xyz1(i, k0 + k, 2); + buf[k * 4 + 3] = ratioL[k0 + k]; + } + __syncthreads(); + + for (int k = 0; k < kend; k++){ + out_type x1 = buf[k * 4 + 0]; + out_type y1 = buf[k * 4 + 1]; + out_type z1 = buf[k * 4 + 2]; + out_type d = level * dist2(x1, y1, z1, x2, y2, z2); + out_type w = __expf(d) * buf[k * 4 + 3]; + sumr += w; + } + __syncthreads(); + } + + if (l < m){ + sumr *= remainR[l]; + out_type consumption = fminf(remainR[l] / (sumr + 1e-9f), 1.0f); + ratioR[l] = consumption * remainR[l]; + remainR[l] = fmaxf(0.0f, remainR[l] - sumr); + } + } + __syncthreads(); + + for (int k0 = 0; k0 < n; k0 += blockDim.x){ + int k = k0 + threadIdx.x; + out_type x1 = 0, y1 = 0, z1 = 0; + if (k < n){ + x1 = @xyz1(i, k, 0); + y1 = @xyz1(i, k, 1); + z1 = @xyz1(i, k, 2); + } + out_type suml = 0; + for (int l0 = 0; l0 < m; l0 += Block){ + int lend = min(m, l0 + Block)-l0; + for (int l = threadIdx.x; l < lend; l += blockDim.x){ + buf[l * 4 + 0] = @xyz2(i, l0 + l, 0); + buf[l * 4 + 1] = @xyz2(i, l0 + l, 1); + buf[l * 4 + 2] = @xyz2(i, l0 + l, 2); + buf[l * 4 + 3] = ratioR[l0 + l]; + } + __syncthreads(); + + out_type rl = ratioL[k]; + if (k < n){ + for (int l = 0; l < lend; l++){ + out_type x2 = buf[l * 4 + 0]; + out_type y2 = buf[l * 4 + 1]; + out_type z2 = buf[l * 4 + 2]; + out_type d = level * dist2(x1, y1, z1, x2, y2, z2); + out_type w = __expf(d) * rl * buf[l*4+3]; + @match(i, l0 + l, k) += w; + suml += w; + } + } + __syncthreads(); + } + if (k < n) + remainL[k] = fmaxf(0.0f, remainL[k] - suml); + } + __syncthreads(); + } + } + } + + approxmatch_gpu_kernel<<<32, 512>>>(@ARGS); +''' + +matchcost_gpu_src = ''' + __global__ void matchcost_gpu_kernel(@ARGS_DEF) { + @PRECALC + @alias(xyz1, in0) + @alias(xyz2, in1) + @alias(match, in2) + + int b = in0_shape0; + int n = in0_shape1; + int m = in1_shape1; + + const int Block = 1024; + __shared__ out_type allsum[512]; + __shared__ out_type buf[Block * 3]; + + for (int i = blockIdx.x; i < b; i += gridDim.x) { + out_type subsum = 0; + for (int k0 = 0; k0 < n; k0 += blockDim.x) { + int k = k0 + threadIdx.x; + out_type x1 = 0, y1 = 0, z1 = 0; + if (k < n) { + x1 = @xyz1(i, k, 0); + y1 = @xyz1(i, k, 1); + z1 = @xyz1(i, k, 2); + } + + for (int l0 = 0; l0 < m; l0 += Block) { + int lend = min(m, l0 + Block) - l0; + for (int l = threadIdx.x; l < lend * 3; l += blockDim.x) + buf[l] = xyz2_p[i * m * 3 + l0 * 3 + l]; + __syncthreads(); + + if (k < n) { + for (int l = 0; l < lend; l++) { + out_type x2 = buf[l * 3 + 0]; + out_type y2 = buf[l * 3 + 1]; + out_type z2 = buf[l * 3 + 2]; + out_type d = dist2(x1, y1, z1, x2, y2, z2); + subsum += d * @match(i, l0 + l, k); + } + } + __syncthreads(); + } + } + + allsum[threadIdx.x] = subsum; + for (int j = 1; j < blockDim.x; j <<= 1) { + __syncthreads(); + if ((threadIdx.x & j) == 0 && threadIdx.x + j < blockDim.x) { + allsum[threadIdx.x] += allsum[threadIdx.x + j]; + } + } + + if (threadIdx.x == 0) + @out(i) = allsum[0]; + __syncthreads(); + } + } + + matchcost_gpu_kernel<<<32, 512>>>(@ARGS); +''' + +matchcost_grad1_gpu_src = ''' + __global__ void matchcost_grad1_gpu_kernel(@ARGS_DEF) { + @PRECALC + @alias(grad, in0) + @alias(xyz1, in1) + @alias(xyz2, in2) + @alias(match, in3) + + int b = grad_shape0; + int n = xyz1_shape1; + int m = xyz2_shape1; + + for (int i = blockIdx.x; i < b ; i += gridDim.x){ + for (int l = threadIdx.x; l < n; l += blockDim.x){ + out_type x1 = @xyz1(i, l, 0); + out_type y1 = @xyz1(i, l, 1); + out_type z1 = @xyz1(i, l, 2); + out_type dx = 0, dy = 0, dz = 0; + for (int k = 0; k < m; k++){ + out_type x2 = @xyz2(i, k, 0); + out_type y2 = @xyz2(i, k, 1); + out_type z2 = @xyz2(i, k, 2); + out_type d = @match(i, k, l) * 2; + dx += (x1 - x2) * d; + dy += (y1 - y2) * d; + dz += (z1 - z2) * d; + } + @out(i, l, 0) = dx * @grad(i); + @out(i, l, 1) = dy * @grad(i); + @out(i, l, 2) = dz * @grad(i); + } + } + } + + matchcost_grad1_gpu_kernel<<<32, 512>>>(@ARGS); +''' + +matchcost_grad2_gpu_src = ''' + __global__ void matchcost_grad2_gpu_kernel(@ARGS_DEF) { + @PRECALC + @alias(grad, in0) + @alias(xyz1, in1) + @alias(xyz2, in2) + @alias(match, in3) + + int b = grad_shape0; + int n = xyz1_shape1; + int m = xyz2_shape1; + + __shared__ out_type sum_grad[256 * 3]; + for (int i = blockIdx.x; i < b; i += gridDim.x) { + int kbeg = m * blockIdx.y / gridDim.y; + int kend = m * (blockIdx.y + 1) / gridDim.y; + for (int k = kbeg; k < kend; k++) { + out_type x2 = @xyz2(i, k, 0); + out_type y2 = @xyz2(i, k, 1); + out_type z2 = @xyz2(i, k, 2); + out_type subsumx = 0, subsumy = 0, subsumz = 0; + for (int j = threadIdx.x; j < n; j += blockDim.x) { + out_type x1 = x2 - @xyz1(i, j, 0); + out_type y1 = y2 - @xyz1(i, j, 1); + out_type z1 = z2 - @xyz1(i, j, 2); + out_type d = @match(i, k, j) * 2; + subsumx += x1 * d; + subsumy += y1 * d; + subsumz += z1 * d; + } + sum_grad[threadIdx.x * 3 + 0] = subsumx; + sum_grad[threadIdx.x * 3 + 1] = subsumy; + sum_grad[threadIdx.x * 3 + 2] = subsumz; + + for (int j = 1; j < blockDim.x; j <<= 1) { + __syncthreads(); + int j1 = threadIdx.x; + int j2 = threadIdx.x + j; + if ((j1 & j) == 0 && j2 < blockDim.x){ + sum_grad[j1 * 3 + 0] += sum_grad[j2 * 3 + 0]; + sum_grad[j1 * 3 + 1] += sum_grad[j2 * 3 + 1]; + sum_grad[j1 * 3 + 2] += sum_grad[j2 * 3 + 2]; + } + } + if (threadIdx.x == 0){ + @out(i, k, 0) = sum_grad[0] * @grad(i); + @out(i, k, 1) = sum_grad[1] * @grad(i); + @out(i, k, 2) = sum_grad[2] * @grad(i); + } + __syncthreads(); + } + } + } + + matchcost_grad2_gpu_kernel<<>>(@ARGS); +''' + +class EarthMoverDistance(Function): + ''' A loss layer that computes Earth Mover's distance from pc1 to pc2. Only supports GPU. + + :param pc1: input point cloud + :type pc1: jittor array + + :param pc2: input point cloud + :type pc2: jittor array + + :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'. + :type reduction: str, optional + + :param dims: a string that represents each dimension, can be + '[BNC]' ([batch, number of points, xyz]), or + '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'. + :type dims: str, optional + + Example: + + >>> import jittor as jt + >>> from jittor.loss3d import EarthMoverDistance + >>> jt.flags.use_cuda = True + >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> EMD = EarthMoverDistance(dims='BNC') + >>> emd = EMD(pc1, pc2) + >>> print('EMD =', emd.item()) + ''' + def execute(self, pc1, pc2, reduction='mean', dims='BNC'): + assert dims in ['BNC', 'BCN'] + if dims == 'BCN': + pc1, pc2 = pc1.permute(0, 2, 1), pc2.permute(0, 2, 1) + + batch_size_1, N, _ = pc1.shape + batch_size_2, M, _ = pc2.shape + assert batch_size_1 == batch_size_2 + batch_size = batch_size_1 + + temp = jt.zeros([batch_size, (N + M) * 2], pc1.dtype) + match = jt.code( + shape=[batch_size, M, N], + dtype=pc1.dtype, + inputs=[pc1, pc2, temp], + cuda_header=EMD_gpu_header, + cuda_src=approxmatch_gpu_src, + ) + + emd = jt.code( + shape=[batch_size], + dtype=pc1.dtype, + inputs=[pc1, pc2, match], + cuda_header=EMD_gpu_header, + cuda_src=matchcost_gpu_src, + ) + + self.saved_vars = (pc1, pc2, match, reduction) + + if reduction is None: + return emd + elif reduction == 'sum': + return emd.sum() + elif reduction == 'mean': + return emd.mean() + + def grad(self, grad): + pc1, pc2, match, reduction = self.saved_vars + + if reduction == 'sum': + grad = jt.ones([pc1.shape[0]]) * grad + elif reduction == 'mean': + grad = jt.ones([pc1.shape[0]]) * grad / pc1.shape[0] + + grad_pc1 = jt.code( + shape=pc1.shape, + dtype=pc1.dtype, + inputs=[grad, pc1, pc2, match], + cuda_src=matchcost_grad1_gpu_src, + ) + + grad_pc2 = jt.code( + shape=pc2.shape, + dtype=pc2.dtype, + inputs=[grad, pc1, pc2, match], + cuda_src=matchcost_grad2_gpu_src, + ) + + return grad_pc1, grad_pc2 + + +def earth_mover_distance(pc1, pc2, reduction='mean', dims='BNC'): + ''' Earth Mover's distance from pc1 to pc2. Only supports GPU. + + :param pc1: input point cloud + :type pc1: jittor array + + :param pc2: input point cloud + :type pc2: jittor array + + :param reduction: reduction method in batches, can be 'mean', 'sum', or None. Default: 'mean'. + :type reduction: str, optional + + :param dims: a string that represents each dimension, can be + '[BNC]' ([batch, number of points, xyz]), or + '[BCN]' ([batch, xyz, number of points]). Default: 'BNC'. + :type dims: str, optional + + + Example: + + >>> import jittor as jt + >>> from jittor.loss3d import earth_mover_distance + >>> jt.flags.use_cuda = True + >>> pc1 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> pc2 = jt.rand([10, 100, 3], dtype=jt.float32) + >>> emd = earth_mover_distance(pc1, pc2, dims='BNC') + >>> print('EMD =', emd.item()) + ''' + return EarthMoverDistance.apply(pc1, pc2, reduction, dims) diff --git a/python/jittor/lr_scheduler.py b/python/jittor/lr_scheduler.py new file mode 100644 index 00000000..9be55a4d --- /dev/null +++ b/python/jittor/lr_scheduler.py @@ -0,0 +1,205 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import jittor as jt +from jittor.optim import Optimizer +import math + +class ReduceLROnPlateau(object): + def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=1e-4, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8): + assert factor < 1.0, "factor should be < 1.0." + assert isinstance(optimizer, Optimizer), '{} is not an Optimizer'.format(type(optimizer).__name__) + assert mode in {'min', 'max'}, 'mode ' + mode + ' is unknown!' + assert threshold_mode in {'rel', 'abs'}, 'threshold mode ' + threshold_mode + ' is unknown!' + + if isinstance(min_lr, list) or isinstance(min_lr, tuple): + assert len(min_lr) == len(optimizer.param_groups), "expected {} min_lrs, got {}".format(len(optimizer.param_groups), len(min_lr)) + self.min_lrs = list(min_lr) + else: + self.min_lrs = [min_lr] * len(optimizer.param_groups) + self.factor = factor + self.optimizer = optimizer + self.patience = patience + self.verbose = verbose + self.cooldown = cooldown + self.n_cd = 0 + self.mode = mode + self.threshold = threshold + self.threshold_mode = threshold_mode + self.loss_best = None + self.n_bad = 0 + self.eps = eps + self.last_epoch = 0 + self.loss_best = math.inf if mode=="min" else -math.inf + + def step(self, loss, epoch=None): + # convert `metrics` to float, in case it's a zero-dim Tensor + loss_now = float(loss) + if epoch is None: + epoch = self.last_epoch + 1 + self.last_epoch = epoch + + if self.better(loss_now, self.loss_best): + self.loss_best = loss_now + self.n_bad = 0 + else: + self.n_bad += 1 + + if self.n_cd > 0: + self.n_cd -= 1 + self.n_bad = 0 + + if self.n_bad > self.patience: + self.update_lr(epoch) + self.n_cd = self.cooldown + self.n_bad = 0 + + def update_lr(self, epoch): + for i, param_group in enumerate(self.optimizer.param_groups): + old_lr = float(param_group.get("lr", self.optimizer.lr)) + new_lr = max(old_lr * self.factor, self.min_lrs[i]) + if old_lr - new_lr > self.eps: + if param_group.get("lr")!=None: + param_group["lr"] = max(param_group["lr"] * self.factor, self.min_lrs[i]) + else: + self.optimizer.lr = new_lr + if self.verbose: + print('Epoch {:5d}: reducing learning rate of group {} from {:.4e} to {:.4e}.'.format(epoch, i, old_lr, new_lr)) + + def better(self, a, b): + if self.mode == 'min' and self.threshold_mode == 'rel': + save = 1.0 - self.threshold + return a < b * save + elif self.mode == 'min' and self.threshold_mode == 'abs': + return a < b - self.threshold + elif self.mode == 'max' and self.threshold_mode == 'rel': + save = self.threshold + 1.0 + return a > b * save + else: + return a > b + self.threshold + +class CosineAnnealingLR(object): + def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1): + self.T_max = T_max + self.eta_min = eta_min + self.optimizer = optimizer + self.last_epoch = last_epoch + self.base_lr = optimizer.lr + self.base_lr_pg = [pg.get("lr") for pg in optimizer.param_groups] + #TODO set last_epoch is not ready + + def get_lr(self, base_lr, now_lr): + if self.last_epoch == 0: + return base_lr + if (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0: + return (now_lr + (base_lr - self.eta_min) * + (1 - math.cos(math.pi / self.T_max)) / 2) + return ((1 + math.cos(math.pi * self.last_epoch / self.T_max)) / + (1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * + (now_lr - self.eta_min) + self.eta_min) + + def step(self): + self.last_epoch += 1 + self.update_lr() + + def update_lr(self): + self.optimizer.lr = self.get_lr(self.base_lr, self.optimizer.lr) + for i, param_group in enumerate(self.optimizer.param_groups): + if param_group.get("lr") != None: + param_group["lr"] = self.get_lr(self.base_lr_pg[i], param_group["lr"]) + + +class ExponentialLR(object): + """ learning rate is multiplied by gamma in each step. + """ + def __init__(self, optimizer, gamma, last_epoch=-1): + self.optimizer = optimizer + self.gamma = gamma + self.last_epoch = last_epoch + self.base_lr = optimizer.lr + self.base_lr_pg = [pg.get("lr") for pg in optimizer.param_groups] + + def get_lr(self, base_lr, now_lr): + if self.last_epoch == 0: + return base_lr + return base_lr * self.gamma ** self.last_epoch + + def step(self): + self.last_epoch += 1 + self.update_lr() + + def update_lr(self): + self.optimizer.lr = self.get_lr(self.base_lr, self.optimizer.lr) + for i, param_group in enumerate(self.optimizer.param_groups): + if param_group.get("lr") != None: + param_group["lr"] = self.get_lr(self.base_lr_pg[i], param_group["lr"]) + + +class StepLR(object): + def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1): + self.optimizer = optimizer + self.step_size = step_size + self.gamma = gamma + self.last_epoch = last_epoch + self.cur_epoch = 0 + + def get_gamma(self): + if self.last_epoch < 0: + if (self.cur_epoch != 0 and (self.cur_epoch + 1) % self.step_size == 0): + return self.gamma + else: + if (self.cur_epoch + 1 + self.last_epoch) % self.step_size == 0: + return self.gamma + return 1. + + def get_lr(self): + return self.optimizer.lr + + def step(self): + self.update_lr() + self.cur_epoch += 1 + + def update_lr(self): + gamma = self.get_gamma() + if gamma != 1.: + self.optimizer.lr = self.optimizer.lr * gamma + for i, param_group in enumerate(self.optimizer.param_groups): + if param_group.get("lr") != None: + param_group["lr"] = param_group["lr"] * gamma + +class MultiStepLR(object): + def __init__(self, optimizer, milestones=[], gamma=0.1, last_epoch=-1): + self.optimizer = optimizer + self.milestones = milestones + self.gamma = gamma + self.last_epoch = last_epoch + #TODO set last_epoch is not ready + + def get_gamma(self): + if (self.last_epoch in self.milestones): + return self.gamma + return 1.0 + + def get_lr(self): + now_lr = self.optimizer.lr + return now_lr * self.get_gamma() + + def step(self): + self.last_epoch += 1 + self.update_lr() + + def update_lr(self): + gamma = self.get_gamma() + if gamma != 1.0: + self.optimizer.lr = self.optimizer.lr * gamma + for i, param_group in enumerate(self.optimizer.param_groups): + if param_group.get("lr") != None: + param_group["lr"] = param_group["lr"] * gamma diff --git a/python/jittor/math_util/__init__.py b/python/jittor/math_util/__init__.py new file mode 100644 index 00000000..f124571b --- /dev/null +++ b/python/jittor/math_util/__init__.py @@ -0,0 +1,2 @@ +from .gamma import digamma, lgamma +from .igamma import igamma diff --git a/python/jittor/math_util/gamma.py b/python/jittor/math_util/gamma.py new file mode 100644 index 00000000..3884c495 --- /dev/null +++ b/python/jittor/math_util/gamma.py @@ -0,0 +1,416 @@ +import numpy as np +import jittor as jt +from jittor import nn + +class lgamma(jt.Function): + def __init__(self): + self.cpu_src = ''' + @alias(x, in0) + @alias(di_x, out0) + int numel = x_shape0 * x_stride0; + for(int i=0;i>>(x_p, lx_p, batch_shape); + ''' + + def execute(self, x): + if jt.flags.use_cuda: + return jt.code(x.shape, x.dtype, [x], cuda_header=self.cuda_header, cuda_src=self.cuda_src) + else: + return jt.code(x.shape, x.dtype, [x], cpu_src=self.cpu_src) + +class polygamma(jt.Function): + def __init__(self): + self.cpu_header = ''' + #ifdef __CUDACC__ + #define C10_HOST_DEVICE __host__ __device__ + #else + #define C10_HOST_DEVICE + #endif + + template C10_HOST_DEVICE static inline scalar_t zeta(scalar_t x, scalar_t q) { + using acc_t = float; + const acc_t MACHEP = acc_t{1.11022302462515654042E-16}; + constexpr acc_t zero = acc_t{0.0}; + constexpr acc_t half = acc_t{0.5}; + constexpr acc_t one = acc_t{1.0}; + static const acc_t A[] = { + 12.0, + -720.0, + 30240.0, + -1209600.0, + 47900160.0, + -1.8924375803183791606e9, /*1.307674368e12/691*/ + 7.47242496e10, + -2.950130727918164224e12, /*1.067062284288e16/3617*/ + 1.1646782814350067249e14, /*5.109094217170944e18/43867*/ + -4.5979787224074726105e15, /*8.028576626982912e20/174611*/ + 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/ + -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/ + }; + + int i = 0; + acc_t a, b, k, s, t, w; + if (x == one) { + return std::numeric_limits::infinity(); + } + + if (x < one) { + return std::numeric_limits::quiet_NaN(); + } + + if (q <= zero) { + if (q == ::floor(q)) { + return std::numeric_limits::infinity(); + } + if (x != ::floor(x)) { + return std::numeric_limits::quiet_NaN(); + } + } + + s = ::pow(q, -x); + a = q; + i = 0; + b = zero; + while ((i < 9) || (a <= acc_t{9.0})) { + i += 1; + a += one; + b = ::pow(a, -x); + s += b; + if ((-MACHEP * s < b) && (b < MACHEP * s)) { + return static_cast(s); + } + }; + + w = a; + s += b * w / (x - one); + s -= half * b; + a = one; + k = zero; + for (int i = 0; i < 12; i++) { + a *= x + k; + b /= w; + t = a * b / A[i]; + s = s + t; + t = ::fabs(t / s); + if (t < MACHEP) { + return static_cast(s); + } + k += one; + a *= x + k; + b /= w; + k += one; + } + return static_cast(s); + } + using scalar_t = float; + ''' + self.cuda_header = self.cpu_header + ''' + __global__ void polygamma_cuda(float* __restrict__ x, + float* out, + int n, + int batch_shape) + { + int tidx = threadIdx.x; + int start = batch_shape / blockDim.x * tidx; + int end = threadIdx.x == blockDim.x - 1 ? batch_shape : start + batch_shape / blockDim.x; + float* bx = x+batch_shape*blockIdx.x; + float* bout = out + batch_shape * blockIdx.x; + for(int i=start;i(n) + 1.0)) * + zeta(static_cast(n + 1), bx[i]); + } + ''' + + def execute(self, x, n): + if jt.flags.use_cuda: + self.cuda_src = f''' + @alias(x, in0) + @alias(px ,out0) + int batch_size = x_stride0 == 1 ? 1 : x_shape0; + int batch_shape = x_shape0 * x_stride0 / batch_size; + polygamma_cuda<<>>(x_p, px_p, {n}, batch_shape); + ''' + return jt.code(x.shape, x.dtype, [x], cuda_header=self.cuda_header, cuda_src=self.cuda_src) + else: + self.cpu_src = f''' + @alias(x, in0) + @alias(px, out0) + int numel = x_shape0 * x_stride0; + for(int i=0;i({n}) + 1.0)) * + zeta(static_cast({n} + 1), x_p[i]); + }} + ''' + return jt.code(x.shape, x.dtype, [x], cpu_header=self.cpu_header, cpu_src=self.cpu_src) + +class digamma(jt.Function): + ''' + digamma(x) = psi(x) = d/dx[ln(gamma(x))] + ''' + def __init__(self): + self.cpu_header = ''' + #include + #define C10_HOST_DEVICE + template + C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) { + T result = 0; + for (size_t i = 0; i <= len; i++) { + result = result * x + A[i]; + } + return result; + } + + static inline float calc_digamma(float x) { + // See [C++ Standard Reference: Gamma Function] + static float PSI_10 = 2.25175258906672110764f; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == truncf(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = std::modf(x, &q); + float pi_over_tan_pi_x = (float)(M_PI / tan(M_PI * r)); + return calc_digamma(1 - x) - pi_over_tan_pi_x; + } + + // Push x to be >= 10 + float result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const float A[] = { + 8.33333333333333333333E-2f, + -2.10927960927960927961E-2f, + 7.57575757575757575758E-3f, + -4.16666666666666666667E-3f, + 3.96825396825396825397E-3f, + -8.33333333333333333333E-3f, + 8.33333333333333333333E-2f, + }; + + float y = 0; + if (x < 1.0e17f) { + float z = 1 / (x * x); + y = z * polevl(z, A, 6); + } + return result + logf(x) - (0.5f / x) - y; + } + ''' + self.cpu_src = ''' + @alias(x, in0) + @alias(di_x, out0) + int numel = x_shape0 * x_stride0; + for(int i=0;i + C10_HOST_DEVICE static inline T polevl(const T x, const T A[], size_t len) { + T result = 0; + for (size_t i = 0; i <= len; i++) { + result = result * x + A[i]; + } + return result; + } + + __device__ static inline float calc_digamma(float x) { + // See [C++ Standard Reference: Gamma Function] + static float PSI_10 = 2.25175258906672110764f; + if (x == 0) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is ±0, ±∞ is returned + return std::copysign(INFINITY, -x); + } + + bool x_is_integer = x == truncf(x); + if (x < 0) { + if (x_is_integer) { + // As per C++ standard for gamma related functions and SciPy, + // If the argument is a negative integer, NaN is returned + return std::numeric_limits::quiet_NaN(); + } + // Extracts the fractional part of x as r, since tan(pi * r) is more numerically + // accurate than tan(pi * x). While these operations are mathematically equivalent + // since both x and r are in radians and tan() has a periodicity of pi, in practice + // the computation of pi * x is a source of error (when |x| > 1). + double q, r; + r = std::modf(x, &q); + float pi_over_tan_pi_x = (float)(M_PI / tan(M_PI * r)); + return calc_digamma(1 - x) - pi_over_tan_pi_x; + } + + // Push x to be >= 10 + float result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10; + } + + // Compute asymptotic digamma + static const float A[] = { + 8.33333333333333333333E-2f, + -2.10927960927960927961E-2f, + 7.57575757575757575758E-3f, + -4.16666666666666666667E-3f, + 3.96825396825396825397E-3f, + -8.33333333333333333333E-3f, + 8.33333333333333333333E-2f, + }; + + float y = 0; + if (x < 1.0e17f) { + float z = 1 / (x * x); + y = z * polevl(z, A, 6); + } + return result + logf(x) - (0.5f / x) - y; + } + + __global__ void digamma_cuda(float* __restrict__ x, + float* out, + int batch_shape) + { + int tidx = threadIdx.x; + int start = batch_shape / blockDim.x * tidx; + int end = threadIdx.x == blockDim.x - 1 ? batch_shape : start + batch_shape / blockDim.x; + float* bx = x+batch_shape*blockIdx.x; + float* bout = out + batch_shape * blockIdx.x; + for(int i=start;i>>(x_p, di_x_p, batch_shape); + ''' + + def execute(self, x): + self.input = x + if jt.flags.use_cuda: + dx = jt.code(x.shape, x.dtype, [x], cuda_header=self.cuda_header, cuda_src=self.cuda_src) + dx.compile_options = {"FLAGS: --expt-relaxed-constexpr":1} + return dx + else: + return jt.code(x.shape, x.dtype, [x], cpu_header=self.cpu_header, cpu_src=self.cpu_src) + + def grad(self, grad_d): + return grad_d * polygamma.apply(self.input, 1) + +def gamma_grad(x, alpha): + cuda_header = open(os.path.join(os.path.realpath(os.path.dirname(__file__)), "src", "gamma_grad.h"), "r").read() + cuda_src = ''' + @alias(x, in0) + @alias(di_x, out0) + int block_num = x_stride0 == 1 ? 1 : x_shape0; + int batch_shape = x_stride0 == 1 ? x_shape0: x_stride0; + float alpha = data["alpha"]; + gamma_grad_kenrel<<>>(x_p, di_x_p, alpha, batch_shape); + ''' + grad = jt.code(x.shape, x.dtype, [x], cuda_header=cuda_header, cuda_src=cuda_src, data={"alpha":alpha}) + return grad + +def sample_gamma(alpha, shape): + cuda_header = ''' + #include + + template + __device__ float sample_gamma(float alpha, curandState& state) { + accscalar_t scale = 1.0f; + + // Boost alpha for higher acceptance probability. + if (alpha < 1.0f) { + if (alpha == 0.f) return 0.f; + scale *= pow(1 - curand_uniform(&state), 1.0f / alpha); + alpha += 1.0f; + } + + // This implements the acceptance-rejection method of Marsaglia and Tsang (2000) + // doi:10.1145/358407.358414 + const accscalar_t d = alpha - 1.0f / 3.0f; + const accscalar_t c = 1.0f / sqrt(9.0f * d + 1e-8); + for (;;) { + accscalar_t x, y; + do { + x = curand_normal(&state); + y = 1.0f + c * x; + } while (y <= 0); + const accscalar_t v = y * y * y; + const accscalar_t u = 1 - curand_uniform(&state); + const accscalar_t xx = x * x; + if (u < 1.0f - 0.0331f * xx * xx) + return static_cast(scale * d * v); + if (log(u) < 0.5f * xx + d * (1.0f - v + log(v))) + return static_cast(scale * d * v); + } + } + + __global__ void sample_gamma_kernel(float* out, + float alpha, + int seed, + int batch_shape) + { + int tidx = threadIdx.x; + int start = batch_shape / blockDim.x * tidx; + int end = threadIdx.x == blockDim.x - 1 ? batch_shape : start + batch_shape / blockDim.x; + if(start > end) + return; + float* bout = out + batch_shape * blockIdx.x; + curandState state; + curand_init(clock64(), threadIdx.x, 0, &state); + for(int i=start;i(alpha, state); + } + ''' + cuda_src = ''' + @alias(lx ,out0) + int batch_size = lx_stride0 == 1 ? 1 : lx_shape0; + int batch_shape = lx_shape0 * lx_stride0 / batch_size; + float alpha = data["alpha"]; + sample_gamma_kernel<<>>(lx_p, alpha, time(NULL), batch_shape); + ''' + samples = jt.code(shape, jt.float32, [], cuda_header=cuda_header, cuda_src=cuda_src, data={"alpha":alpha}) + return samples diff --git a/python/jittor/math_util/igamma.py b/python/jittor/math_util/igamma.py new file mode 100644 index 00000000..a261b717 --- /dev/null +++ b/python/jittor/math_util/igamma.py @@ -0,0 +1,21 @@ +import os + +import numpy as np +import jittor as jt +from jittor import nn + +f = open(os.path.join(os.path.realpath(os.path.dirname(__file__)), "src", "igamma.h"), "r") +cuda_header = f.read() +f.close() + +def igamma(alpha, x): + cuda_src = ''' + @alias(x, in0) + @alias(px ,out0) + int batch_size = x_stride0 == 1 ? 1 : x_shape0; + int batch_shape = x_shape0 * x_stride0 / batch_size; + float alpha = data["alpha"]; + igamma_kernel<<>>(x_p, px_p, alpha, batch_shape); + ''' + out = jt.code(x.shape, x.dtype, [x], cuda_header=cuda_header, cuda_src=cuda_src, data={"alpha": alpha}) + return out diff --git a/python/jittor/math_util/src/gamma_grad.h b/python/jittor/math_util/src/gamma_grad.h new file mode 100644 index 00000000..70ce4ccb --- /dev/null +++ b/python/jittor/math_util/src/gamma_grad.h @@ -0,0 +1,141 @@ +#include + + template + __device__ static inline T polevl(const T x, const T A[], size_t len) { + T result = 0; + for (size_t i = 0; i <= len; i++) { + result = result * x + A[i]; + } + return result; + } + + template + __device__ static inline scalar_t digamma_one(scalar_t x) { + constexpr accscalar_t PSI_10 = 2.25175258906672110764; + if (x == 0) { + return INFINITY; + } + accscalar_t additional_summand = 0; + int x_is_integer = x == floor(x); + if (x < 0) { + if (x_is_integer) { + return INFINITY; + } + // it is more standard to write this as recursion, but + // nvcc does not like that + additional_summand = -M_PI / + tan(M_PI * x); + x = 1 - x; + } + + // Push x to be >= 10 + accscalar_t result = 0; + while (x < 10) { + result -= 1 / x; + x += 1; + } + if (x == 10) { + return result + PSI_10 + additional_summand; + } + + // Compute asymptotic digamma + static const accscalar_t A[] = { + 8.33333333333333333333E-2, + -2.10927960927960927961E-2, + 7.57575757575757575758E-3, + -4.16666666666666666667E-3, + 3.96825396825396825397E-3, + -8.33333333333333333333E-3, + 8.33333333333333333333E-2, + }; + + accscalar_t y = 0; + if (x < 1.0e17f) { + accscalar_t z = 1.0 / (x * x); + y = z * polevl(z, A, 6); + } + return static_cast( + result + log(x) - (0.5f / x) - y + additional_summand); + } + + template + __device__ scalar_t standard_gamma_grad_one(scalar_t alpha_, scalar_t x_) { + // Use a Taylor series expansion for small x. + accscalar_t x = static_cast(x_); + accscalar_t alpha = static_cast(alpha_); + if (x < 0.8f) { + accscalar_t numer = 1; + accscalar_t denom = alpha; + auto series1 = numer / denom; + auto series2 = numer / (denom * denom); + for (int i = 1; i <= 5; ++i) { + numer *= -x / static_cast(i); + denom += 1; + series1 += numer / denom; + series2 += numer / (denom * denom); + } + const auto pow_x_alpha = pow(x, alpha); + const auto gamma_pdf = pow(x, alpha - 1) * exp(-x); + const auto gamma_cdf = pow_x_alpha * series1; + const auto gamma_cdf_alpha = + (log(x) - digamma_one(alpha)) * + gamma_cdf - + pow_x_alpha * series2; + const auto result = -gamma_cdf_alpha / gamma_pdf; + return isnan(result) ? static_cast( 0.f ) : static_cast(result); + } + + // Use a Rice saddle point expansion for large alpha. + if (alpha > 8.0f) { + if (0.9f * alpha <= x && x <= 1.1f * alpha) { + const auto numer_1 = 1 + 24 * alpha * (1 + 12 * alpha); + const auto numer_2 = 1440 * (alpha * alpha) + 6 * x * (53 - 120 * x) + - 65 * x * x / alpha + alpha * (107 + 3600 * x); + const auto denom = 1244160 * (alpha * alpha) * (alpha * alpha); + return static_cast(numer_1 * numer_2 / denom); + } + const auto denom = sqrt(8 * alpha + 1e-8); + const auto term2 = denom / (alpha - x); + const auto term3 = pow( + x - alpha - alpha * log(x / alpha), + static_cast(-1.5)); + const auto term23 = (x < alpha) ? term2 - term3 : term2 + term3; + const auto term1 = log(x / alpha) * term23 - + sqrt(2 / alpha + 1e-8) * (alpha + x) / ((alpha - x) * (alpha - x)); + const auto stirling = 1 + 1 / (12 * alpha) * (1 + 1 / (24 * alpha)); + const auto numer = x * term1; + return static_cast(-stirling * numer / denom); + } + + // Use a bivariate rational approximation to the reparameterized gradient. + const auto u = log(x / alpha); + const auto v = log(alpha); + static const accscalar_t coef_uv[3][8] = { + {0.16009398, -0.094634809, 0.025146376, -0.0030648343, + 1, 0.32668115, 0.10406089, 0.0014179084}, + {0.53487893, 0.1298071, 0.065735949, -0.0015649758, + 0.16639465, 0.020070113, -0.0035938915, -0.00058392623}, + {0.040121004, -0.0065914022, -0.0026286047, -0.0013441777, + 0.017050642, -0.0021309326, 0.00085092367, -1.5247877e-07}, + }; + accscalar_t coef_v[8]; + for (int i = 0; i < 8; ++ i) { + coef_v[i] = coef_uv[0][i] + u * (coef_uv[1][i] + u * coef_uv[2][i]); + } + const auto p = coef_v[0] + v * (coef_v[1] + v * (coef_v[2] + v * coef_v[3])); + const auto q = coef_v[4] + v * (coef_v[5] + v * (coef_v[6] + v * coef_v[7])); + return static_cast(exp(p / q)); + } + + __global__ void gamma_grad_kenrel(float* __restrict__ x, + float* out, + float alpha, + int batch_shape) + { + int tidx = threadIdx.x; + int start = batch_shape / blockDim.x * tidx; + int end = threadIdx.x == blockDim.x - 1 ? batch_shape : start + batch_shape / blockDim.x; + float* bx = x+batch_shape*blockIdx.x; + float* bout = out + batch_shape * blockIdx.x; + for(int i=start;i(alpha, bx[i]); + } diff --git a/python/jittor/math_util/src/igamma.h b/python/jittor/math_util/src/igamma.h new file mode 100644 index 00000000..cb949452 --- /dev/null +++ b/python/jittor/math_util/src/igamma.h @@ -0,0 +1,694 @@ +// THIS FILE ACTS AS THE HEADER OF IGAMMA FUNCTION. +#include +#define C10_DEVICE __host__ __device__ +template +static C10_DEVICE scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M, + const scalar_t denom[], int64_t N) { + // evaluating rational function, i.e., the ratio of two polynomials + // the coefficients for numerator are given by `num` while coeffs for + // denumerator are given by `denom` + + int64_t i, dir; + scalar_t y, num_ans, denom_ans; + scalar_t absx = std::fabs(x); + const scalar_t *p; + + if (absx > 1) { + /* Evaluate as a polynomial in 1/x. */ + dir = -1; + p = num + M; + y = 1 / x; + } + else { + dir = 1; + p = num; + y = x; + } + + /* Evaluate the numerator */ + num_ans = *p; + p += dir; + for (i = 1; i <= M; i++) { + num_ans = num_ans * y + *p; + p += dir; + } + /* Evaluate the denominator */ + if (absx > 1) { + p = denom + N; + } + else { + p = denom; + } + + denom_ans = *p; + p += dir; + for (i = 1; i <= N; i++) { + denom_ans = denom_ans * y + *p; + p += dir; + } + if (absx > 1) { + i = N - M; + return std::pow(x, i) * num_ans / denom_ans; + } + else { + return num_ans / denom_ans; + } +} + +template +static C10_DEVICE scalar_t lanczos_sum_expg_scaled(scalar_t x) { + // lanczos approximation + static const scalar_t lanczos_sum_expg_scaled_num[13] = { + 0.006061842346248906525783753964555936883222, + 0.5098416655656676188125178644804694509993, + 19.51992788247617482847860966235652136208, + 449.9445569063168119446858607650988409623, + 6955.999602515376140356310115515198987526, + 75999.29304014542649875303443598909137092, + 601859.6171681098786670226533699352302507, + 3481712.15498064590882071018964774556468, + 14605578.08768506808414169982791359218571, + 43338889.32467613834773723740590533316085, + 86363131.28813859145546927288977868422342, + 103794043.1163445451906271053616070238554, + 56906521.91347156388090791033559122686859 + }; + static const scalar_t lanczos_sum_expg_scaled_denom[13] = { + 1., + 66., + 1925., + 32670., + 357423., + 2637558., + 13339535., + 45995730., + 105258076., + 150917976., + 120543840., + 39916800., + 0. + }; + return ratevl(x, lanczos_sum_expg_scaled_num, + sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1, + lanczos_sum_expg_scaled_denom, + sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1); +} + +template +static C10_DEVICE scalar_t _igam_helper_fac(scalar_t a, scalar_t x) { + // compute x^a * exp(-a) / gamma(a) + // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with + // exp(a - x). + + scalar_t ax, fac, res, num, numfac; + static scalar_t MAXLOG = std::is_same::value ? + 7.09782712893383996843E2 : 88.72283905206835; + static scalar_t EXP1 = 2.718281828459045; + static scalar_t lanczos_g = 6.024680040776729583740234375; + + if (std::fabs(a - x) > 0.4 * std::fabs(a)) { + ax = a * std::log(x) - x - std::lgamma(a); + if (ax < -MAXLOG) { + return 0.0; + } + return std::exp(ax); + } + + fac = a + lanczos_g - 0.5; + res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a); + + if ((a < 200) && (x < 200)) { + res *= std::exp(a - x) * std::pow(x / fac, a); + } + else { + num = x - a - lanczos_g + 0.5; + numfac = num / fac; + res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac); + } + return res; +} + +template +static C10_DEVICE scalar_t _igam_helper_series(scalar_t a, scalar_t x) { + // Compute igam using DLMF 8.11.4. [igam1] + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static int MAXITER = 2000; + + int i; + scalar_t ans, ax, c, r; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* power series */ + r = a; + c = 1.0; + ans = 1.0; + + for (i = 0; i < MAXITER; i++) { + r += 1.0; + c *= x / r; + ans += c; + if (c <= MACHEP * ans) { + break; + } + } + return (ans * ax / a); +} + +template +static C10_DEVICE scalar_t _igamc_helper_series(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in + // _igam_helper_series but extra care is taken to avoid cancellation. + + int n; + scalar_t fac = 1; + scalar_t sum = 0; + scalar_t term, logx; + static scalar_t MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + + for (n = 1; n < MAXITER; n++) { + fac *= -x / n; + term = fac / (a + n); + sum += term; + if (std::fabs(term) <= MACHEP * std::fabs(sum)) { + break; + } + } + + logx = std::log(x); + term = -std::expm1(a * logx - std::lgamma(1+a)); + return term - std::exp(a * logx - std::lgamma(a)) * sum; +} + +template +static C10_DEVICE scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) { + // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1] + static const scalar_t d[25][25] = + {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2, + 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4, + 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6, + 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9, + 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10, + -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11, + -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13, + -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16, + -1.9752288294349443e-15}, + {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3, + -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7, + -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6, + 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8, + 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9, + 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14, + 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13, + -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14, + -4.13125571381061e-15}, + {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4, + 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5, + -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6, + -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10, + -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9, + 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11, + 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12, + 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17, + 8.8592218725911273e-15}, + {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4, + 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7, + 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6, + -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8, + -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9, + -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14, + -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12, + 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14, + 2.0453671226782849e-14}, + {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4, + -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5, + 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6, + 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11, + 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9, + -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10, + -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12, + -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18, + -5.0453320690800944e-14}, + {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4, + -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7, + -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6, + -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7, + 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9, + 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15, + 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11, + -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13, + -1.3249659916340829e-13}, + {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4, + 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5, + -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6, + -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13, + -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8, + 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10, + 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11, + 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18, + 3.6902800842763467e-13}, + {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4, + 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7, + 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6, + 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7, + -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8, + -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17, + -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11, + 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12, + 1.0865561947091654e-12}, + {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4, + -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4, + 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5, + 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11, + 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8, + 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9, + -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10, + -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18, + -3.3721464474854592e-12}, + {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4, + -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7, + -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5, + -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6, + 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7, + 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14, + 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10, + -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11, + -1.1002224534207726e-11}, + {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3, + 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4, + -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5, + -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11, + -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7, + -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8, + 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9, + 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18, + 3.7647749553543836e-11}, + {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3, + 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7, + 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4, + 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5, + -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6, + -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14, + -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9, + -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10, + 1.3481607129399749e-10}, + {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3, + -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3, + 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4, + 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10, + 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6, + 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7, + -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8, + -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20, + -5.0423112718105824e-10}, + {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3, + -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6, + -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4, + -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4, + 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5, + 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13, + 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8, + 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9, + -1.9661464453856102e-9}, + {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2, + 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2, + -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3, + -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10, + -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5, + -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6, + 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7, + 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17, + 7.9795091026746235e-9}, + {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2, + 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6, + 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3, + 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3, + -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4, + -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12, + -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6, + -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7, + 3.3654425209171788e-8}, + {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1, + -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2, + 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2, + 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9, + 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4, + 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5, + -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6, + -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16, + -1.4729737374018841e-7}, + {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1, + -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5, + -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2, + -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2, + 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3, + 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12, + 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5, + 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6, + -6.6812849447625594e-7}, + {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968, + 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1, + -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1, + -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8, + -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3, + -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3, + 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5, + 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14, + 3.1369106244517615e-6}, + {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906, + 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4, + 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1, + 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1, + -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2, + -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11, + -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4, + 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5, + 1.5227271505597605e-5}, + {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1, + -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1, + 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816, + 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7, + 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1, + 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2, + -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3, + -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11, + -7.6340103696869031e-5}, + {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1, + -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3, + -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1, + -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195, + 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1, + 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10, + 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3, + -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3, + -3.9479941246822517e-4}, + {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2, + 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2, + -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1, + -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7, + -6.2716159907747034, 5.1168999071852637, -2.0319658112299095, + -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1, + 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2, + 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6, + 2.1250180774699461e-3}, + {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2, + 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2, + 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2, + 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1, + -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373, + -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7, + -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1, + 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2, + 1.5109265210467774e-2}, + {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3, + -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3, + 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2, + 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6, + 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1, + -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1, + -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468, + -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1, + 4.8683443692930507e-1}}; + + int k, n, sgn; + int maxpow = 0; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + scalar_t lambda = x / a; + scalar_t sigma = (x - a) / a; + scalar_t eta, res, ck, ckterm, term, absterm; + scalar_t absoldterm = INFINITY; + scalar_t etapow[25] = {1}; + scalar_t sum = 0; + scalar_t afac = 1; + + if (igam) { + sgn = -1; + } + else { + sgn = 1; + } + + if (lambda > 1) { + eta = std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else if (lambda < 1) { + eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma)); + } + else { + eta = 0; + } + res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2)); + + for (k = 0; k < 25; k++) { + ck = d[k][0]; + for (n = 1; n < 25; n++) { + if (n > maxpow) { + etapow[n] = eta * etapow[n-1]; + maxpow += 1; + } + ckterm = d[k][n]*etapow[n]; + ck += ckterm; + if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) { + break; + } + } + term = ck * afac; + absterm = std::fabs(term); + if (absterm > absoldterm) { + break; + } + sum += term; + if (absterm < MACHEP * std::fabs(sum)) { + break; + } + absoldterm = absterm; + afac /= a; + } + res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * M_PI * a); + + return res; +} + +template +static C10_DEVICE scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) { + // Compute igamc using DLMF 8.9.2. [igam1] + int i; + scalar_t ans, ax, c, yc, r, t, y, z; + scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2; + int MAXITER = 2000; + static scalar_t MACHEP = std::is_same::value ? + 1.11022302462515654042E-16 : 5.9604644775390625E-8; + static scalar_t BIG = std::is_same::value ? + 4.503599627370496e15 : 16777216.; + static scalar_t BIGINV = std::is_same::value ? + 2.22044604925031308085e-16 : 5.9604644775390625E-8; + + ax = _igam_helper_fac(a, x); + if (ax == 0.0) { + return 0.0; + } + + /* continued fraction */ + y = 1.0 - a; + z = x + y + 1.0; + c = 0.0; + pkm2 = 1.0; + qkm2 = x; + pkm1 = x + 1.0; + qkm1 = z * x; + ans = pkm1 / qkm1; + + for (i = 0; i < MAXITER; i++) { + c += 1.0; + y += 1.0; + z += 2.0; + yc = y * c; + pk = pkm1 * z - pkm2 * yc; + qk = qkm1 * z - qkm2 * yc; + if (qk != 0) { + r = pk / qk; + t = std::fabs((ans - r) / r); + ans = r; + } + else { + t = 1.0; + } + pkm2 = pkm1; + pkm1 = pk; + qkm2 = qkm1; + qkm1 = qk; + if (std::fabs(pk) > BIG) { + pkm2 *= BIGINV; + pkm1 *= BIGINV; + qkm2 *= BIGINV; + qkm1 *= BIGINV; + } + if (t <= MACHEP) { + break; + } + } + return ans * ax; +} + +template +static C10_DEVICE inline scalar_t calc_igammac(scalar_t a, scalar_t x) { + /* the calculation of the regularized upper incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.4 [igam1]) + * - if x > 1.1 and x < a, using the substraction from the regularized lower + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (5) + */ + scalar_t absxma_a; + + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 0.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 1.0; + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 1.0; + } + else if (std::isinf(x)) { + return 0.0; + } + + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 0); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 0); + } + + if (x > 1.1) { + if (x < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_continued_fraction(a, x); + } + } + else if (x <= 0.5) { + if (-0.4 / std::log(x) < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } + else { + if (x * 1.1 < a) { + return 1.0 - _igam_helper_series(a, x); + } + else { + return _igamc_helper_series(a, x); + } + } +} + +template +static C10_DEVICE inline scalar_t calc_igamma(scalar_t a, scalar_t x) { + /* the calculation of the regularized lower incomplete gamma function + * is done differently based on the values of a and x: + * - if x and/or a is at the boundary of defined region, then assign the + * result at the boundary + * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for + * Large Parameter (see DLMF 8.12.3 [igam1]) + * - if x > 1 and x > a, using the substraction from the regularized upper + * incomplete gamma + * - otherwise, calculate the series from [igam2] eq (4) + */ + scalar_t absxma_a; + static scalar_t SMALL = 20.0; + static scalar_t LARGE = 200.0; + static scalar_t SMALLRATIO = 0.3; + static scalar_t LARGERATIO = 4.5; + + // boundary values following SciPy + // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e., + // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0. + if ((x < 0) || (a < 0)) { + // out of defined-region of the function + return std::numeric_limits::quiet_NaN(); + } + else if (a == 0) { + if (x > 0) { + return 1.0; + } + else { + return std::numeric_limits::quiet_NaN(); + } + } + else if (x == 0) { + return 0.0; // zero integration limit + } + else if (std::isinf(a)) { + if (std::isinf(x)) { + return std::numeric_limits::quiet_NaN(); + } + return 0.0; + } + else if (std::isinf(x)) { + return 1.0; + } + + /* Asymptotic regime where a ~ x. See [igam2] */ + absxma_a = std::fabs(x - a) / a; + if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) { + return _igam_helper_asymptotic_series(a, x, 1); + } + else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) { + return _igam_helper_asymptotic_series(a, x, 1); + } + + if ((x > 1.0) && (x > a)) { + return 1.0 - calc_igammac(a, x); + } + + return _igam_helper_series(a, x); +} + +__global__ void igamma_kernel(float* __restrict__ x, + float* out, + float alpha, + int batch_shape) +{ + int tidx = threadIdx.x; + int start = batch_shape / blockDim.x * tidx; + int end = threadIdx.x == blockDim.x - 1 ? batch_shape : start + batch_shape / blockDim.x; + float* bx = x+batch_shape*blockIdx.x; + float* bout = out + batch_shape * blockIdx.x; + for(int i=start;i. +# Wenyang Zhou <576825820@qq.com> +# Guoye Yang <498731903@qq.com> +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np +import math +from collections.abc import Sequence,Iterable + +def knn(unknown, known, k): + ''' find k neighbors for unknown array from known array + + Args: + + unknown (var): shape [b, n, c] + known (var): shape [b, m, c] + k (int) + + ''' + b, n, c = unknown.shape + _, m, _ = known.shape + dists2 = jt.empty((b, n, k), dtype="float") + idx = jt.empty((b, n, k), dtype="int") + src = ''' +__inline_static__ +@python.jittor.auto_parallel(2, block_num=256) +void knn_kernel(int b, int batch_index, int n, int index, int m, + const float *__restrict__ unknown, + const float *__restrict__ known, + float *__restrict__ dist2, + int *__restrict__ idx) { + +#define K %s + unknown += batch_index * n * 3; + known += batch_index * m * 3; + dist2 += batch_index * n * K; + idx += batch_index * n * K; + int j = index; + { + float ux = unknown[j * 3 + 0]; + float uy = unknown[j * 3 + 1]; + float uz = unknown[j * 3 + 2]; + + float tmp_dist[K]; + int tmp_idx[K]; + #pragma unroll + for (int i=0; i first) { + tmp_dist[K-1-i] = tmp_dist[K-2-i]; + tmp_idx[K-1-i] = tmp_idx[K-2-i]; + } + tmp_dist[first] = d; + tmp_idx[first] = k; + } + #pragma unroll + for (int i=0; ishape[0], 0, in0->shape[1], 0, in1->shape[1], in0_p, in1_p, out0_p, out1_p); + ''' % k + return jt.code([unknown, known], [dists2, idx], + cpu_src=src, + cuda_src=src) + +def index_add_(x, dim, index, tensor): + """ Take out each index subscript vector of the dim dimension and add the corresponding tensor variable. + + Example: + + x = jt.ones((5,3)) + tensor = jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + index = jt.array([0,4,2]) + x.index_add_(0, index, tensor) + print(x) + + >>> jt.Var([[ 2., 3., 4.], + [ 1., 1., 1.], + [ 8., 9., 10.], + [ 1., 1., 1.], + [ 5., 6., 7.]]) + """ + assert len(index.shape) == 1 + assert tensor.shape[0] == index.shape[0] + x[(slice(None,),)*dim+(index,)] += tensor +jt.Var.index_add_ = index_add_ + +def __copy__(x): + return x.copy().detach() +jt.Var.__copy__ = __copy__ + +def __deepcopy__(x,memo): + result = x.copy().detach() + memo[id(x)]=result + return result +jt.Var.__deepcopy__ = __deepcopy__ + +def __len__(x): + return x.shape[0] +jt.Var.__len__ = __len__ + +def __iter__(x): + result = [] + for i in range(x.shape[0]): + result.append(x[i]) + return result.__iter__() +jt.Var.__iter__ = __iter__ + +def __contains__(x, key): + return bool((x == key).any()) +jt.Var.__contains__ = __contains__ + +def new(x, *args): + if len(args) != 1 or isinstance(args[0], int): + return jt.empty(args, x.dtype) + return jt.array(args[0]).cast(x.dtype) +jt.Var.new = new + +def __index__(x): + return int(x.item()) +jt.Var.__index__ = __index__ + +def sort(input, dim=-1, descending=False, stable=False): + index, value = jt.argsort(input, dim, descending) + return value, index +jt.Var.sort = sort + +def all(x, dim=()): + return x.all_(dim).bool() +jt.Var.all = all + +def any(x,dim=()): + return x.any_(dim).bool() +jt.Var.any = any + +def bernoulli(input): + return (input>jt.rand_like(input)).cast(input.dtype) + +def repeat(x, *shape): + r''' + Repeats this var along the specified dimensions. + + Args: + + x (var): jittor var. + + shape (tuple): int or tuple. The number of times to repeat this var along each dimension. + + Example: + + >>> x = jt.array([1, 2, 3]) + + >>> x.repeat(4, 2) + [[ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3], + [ 1, 2, 3, 1, 2, 3]] + + >>> x.repeat(4, 2, 1).size() + [4, 2, 3,] + ''' + if len(shape) == 1 and isinstance(shape[0], Sequence): + shape = shape[0] + len_x_shape = len(x.shape) + len_shape = len(shape) + x_shape = x.shape + rep_shape = shape + if len_x_shape < len_shape: + x_shape = (len_shape - len_x_shape) * [1] + x.shape + x = x.broadcast(x_shape) + elif len_x_shape > len_shape: + rep_shape = (len_x_shape - len_shape) * [1] + list(shape) + + reshape_shape = [] + broadcast_shape = [] + for x_s,r_s in zip(x_shape,rep_shape): + if r_s != 1: + reshape_shape.append(1) + broadcast_shape.append(r_s) + reshape_shape.append(x_s) + broadcast_shape.append(1) + + x = x.reshape(reshape_shape) + x = x.broadcast(broadcast_shape) + + tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist() + + x = x.reshape(tar_shape) + return x + +jt.Var.repeat = repeat +# tile = jt.Var.tile = repeat +ne = jt.Var.ne = jt.Var.not_equal + +def repeat_interleave(x,repeats,dim=None): + # TODO repeats is jt.Var + assert isinstance(repeats,int) + if dim == None: + x = x.reshape(-1) + dim=0 + if dim<0: dim+=x.ndim + + tar_shape = list(x.shape) + x_shape = list(x.shape) + tar_shape[dim] = tar_shape[dim]*repeats + dims = [] + for i in range(len(tar_shape)): + if dim==i: + dims.append(f"i{i}/{repeats}") + else: + dims.append(f"i{i}") + return x.reindex(tar_shape,dims) + +jt.Var.repeat_interleave = repeat_interleave + +def chunk(x, chunks, dim=0): + r''' + Splits a var into a specific number of chunks. Each chunk is a view of the input var. + + Last chunk will be smaller if the var size along the given dimension dim is not divisible by chunks. + + Args: + + input (var) – the var to split. + + chunks (int) – number of chunks to return. + + dim (int) – dimension along which to split the var. + + Example: + + >>> x = jt.random((10,3,3)) + + >>> res = jt.chunk(x, 2, 0) + + >>> print(res[0].shape, res[1].shape) + [5,3,3,] [5,3,3,] + ''' + if dim<0: + dim += x.ndim + l = x.shape[dim] + res = [] + if l <= chunks: + for i in range(l): + res.append(x[(slice(None,),)*dim+([i,],)]) + else: + nums = (l-1) // chunks + 1 + for i in range(chunks-1): + res.append(x[(slice(None,),)*dim+(slice(i*nums,(i+1)*nums),)]) + if (i+1)*nums < l: + res.append(x[(slice(None,),)*dim+(slice((i+1)*nums,None),)]) + return res +jt.Var.chunk = chunk + + +def expand(x, *shape): + ''' Expand and broadcast this array, -1 represents this dimension is not changed. + +Example:: + + a = jt.zeros((3,1)) + b = a.expand(3, 4) + assert b.shape == (3,4) + b = a.expand(-1, 4) + assert b.shape == (3,4) + b = a.expand((3, 4)) + assert b.shape == (3,4) + b = a.expand((-1, 4)) + assert b.shape == (3,4) + + ''' + if len(shape) == 1 and isinstance(shape[0], (tuple,list,jt.NanoVector)): + shape = shape[0] + shape = list(shape) + offset = len(shape) - len(x.shape) + for i in range(len(x.shape)): + if shape[offset + i] == -1: + shape[offset + i] = x.shape[i] + return x.broadcast(shape) +jt.Var.expand = expand + + +def t(x): + pose = [i for i in range(x.ndim)] + pose[-1], pose[-2] = pose[-2], pose[-1] + return x.transpose(*pose) +jt.Var.t = t + +def median(x,dim=None,keepdim=False, keepdims=False): + keepdim = keepdim or keepdims + if dim is None: + x = x.reshape(-1) + dim=0 + _,x = jt.argsort(x, dim) + slices = [slice(None) for i in range(dim-1)] + k = (x.shape[dim]-1)//2 + if keepdim: + slices.append(slice(k,k+1)) + else: + slices.append(k) + return x[tuple(slices)] + +jt.Var.median = median + +def stack(x, dim=0): + r''' + Concatenates sequence of vars along a new dimension. + + All vars need to be of the same size. + + Args: + + x (sequence of vars) – sequence of vars to concatenate. + + dim (int) – dimension to insert. Has to be between 0 and the number of dimensions of concatenated vars (inclusive). + + Example: + + >>> a1 = jt.array([[1,2,3]]) + + >>> a2 = jt.array([[4,5,6]]) + + >>> jt.stack([a1, a2], 0) + [[[1 2 3] + [[4 5 6]]] + ''' + assert isinstance(x, Sequence) + for i,x_ in enumerate(x): + x[i] = jt.array(x_) + if len(x) < 2: + return x[0].unsqueeze(dim) + + res = [x_.unsqueeze(dim) for x_ in x] + return jt.concat(res, dim=dim) +jt.Var.stack = stack + +def flip(x, dim=0): + r''' + Reverse the order of a n-D var along given axis in dims. + + Args: + + input (var) – the input var. + + dims (a list or tuple) – axis to flip on. + + Example: + + >>> x = jt.array([[1,2,3,4]]) + + >>> x.flip(1) + [[4 3 2 1]] + ''' + if isinstance(dim, int): + dim = [dim] + for i in range(len(dim)): + if dim[i]<0: + dim[i] += x.ndim + assert dim[i]>=0 and dim[i]>> input = jt.random((6,3)) + + >>> other = jt.random((6,3)) + + >>> jt.cross(input, other, dim=1) + [[-0.42732686 0.6827885 -0.49206433] + [ 0.4651107 0.27036983 -0.5580432 ] + [-0.31933784 0.10543461 0.09676848] + [-0.58346975 -0.21417202 0.55176204] + [-0.40861478 0.01496297 0.38638002] + [ 0.18393655 -0.04907863 -0.17928357]] + + >>> jt.cross(input, other) + [[-0.42732686 0.6827885 -0.49206433] + [ 0.4651107 0.27036983 -0.5580432 ] + [-0.31933784 0.10543461 0.09676848] + [-0.58346975 -0.21417202 0.55176204] + [-0.40861478 0.01496297 0.38638002] + [ 0.18393655 -0.04907863 -0.17928357]] + ''' + assert input.shape==other.shape, "input shape and other shape must be same" + if dim < 0: dim += len(input.shape) + assert input.shape[dim] == 3, "input dim shape must be 3" + a1 = input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(2,)]-input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(1,)] + a2 = input[(slice(None,),)*dim+(2,)]*other[(slice(None,),)*dim+(0,)]-input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(2,)] + a3 = input[(slice(None,),)*dim+(0,)]*other[(slice(None,),)*dim+(1,)]-input[(slice(None,),)*dim+(1,)]*other[(slice(None,),)*dim+(0,)] + return jt.concat([a1.unsqueeze(dim),a2.unsqueeze(dim),a3.unsqueeze(dim)], dim=dim) +jt.Var.cross = cross + +def normalize(input, p=2, dim=1, eps=1e-30): + r''' + Performs L_p normalization of inputs over specified dimension. + + Args: + + input – input array of any shape + + p (float) – the exponent value in the norm formulation. Default: 2 + + dim (int) – the dimension to reduce. Default: 1 + + eps (float) – small value to avoid division by zero. Default: 1e-12 + + Example: + + >>> x = jt.random((6,3)) + [[0.18777736 0.9739261 0.77647036] + [0.13710196 0.27282116 0.30533272] + [0.7272278 0.5174613 0.9719775 ] + [0.02566639 0.37504175 0.32676998] + [0.0231761 0.5207773 0.70337296] + [0.58966476 0.49547017 0.36724383]] + + >>> jt.normalize(x) + [[0.14907198 0.7731768 0.61642134] + [0.31750825 0.63181424 0.7071063 ] + [0.5510936 0.39213243 0.736565 ] + [0.05152962 0.7529597 0.656046 ] + [0.02647221 0.59484214 0.80340654] + [0.6910677 0.58067477 0.4303977 ]] + ''' + return input / input.norm(p, dim, True, eps) +jt.Var.normalize = normalize + +def unbind(x, dim=0): + r''' + Removes a var dimension. + + Returns a tuple of all slices along a given dimension, already without it. + + Args: + + input (var) – the var to unbind + + dim (int) – dimension to remove + + Example: + + a = jt.random((3,3)) + b = jt.unbind(a, 0) + + ''' + if dim < 0: dim += len(x.shape) + return [x[(slice(None),)*dim+(i,)] for i in range(x.shape[dim])] + +jt.Var.unbind = unbind + +def make_grid(x, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0): + assert isinstance(range, tuple) or range is None + assert scale_each == False + if isinstance(x, list): x = jt.stack(x) + assert isinstance(x, jt.Var) + if normalize: + if range is None: x = (x - x.min()) / (x.max() - x.min()) + else: x = (x - range[0]) / (range[1] - range[0]) + if x.ndim < 4: return x + if x.ndim == 4 and x.shape[0] <= 1: return x + nrow = min(nrow, x.shape[0]) + b,c,h,w = x.shape + ncol = math.ceil(b / nrow) + return x.reindex([c, h*ncol+(ncol+1)*padding, w*nrow+(nrow+1)*padding], + [f"i1/{padding+h}*{nrow}+i2/{padding+w}", "i0", + f"i1-i1/{padding+h}*{padding+h}-{padding}", f"i2-i2/{padding+w}*{padding+w}-{padding}"], overflow_value=pad_value) + +def save_image( + x, + filepath, + nrow: int = 8, + padding: int = 2, + normalize: bool = False, + range = None, + scale_each = False, + pad_value = 0, + format = None +): + from PIL import Image + grid = make_grid(x, nrow=nrow, padding=padding, pad_value=pad_value, + normalize=normalize, range=range, scale_each=scale_each) + + ndarr = (grid*255+0.5).clamp(0, 255).permute(1, 2, 0).uint8().numpy() + im = Image.fromarray(ndarr) + im.save(filepath, format=format) + + +def _ntuple(n): + def parse(x): + if isinstance(x, Iterable): + return x + return tuple([x]*n) + return parse + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def unique( + input: jt.Var, + return_inverse: bool=False, + return_counts: bool=False, + dim: int=None): + + r''' + Returns the unique elements of the input tensor. + + Args: + + input (var) – the input var + + return_inverse (bool) – Whether to also return the indices for where elements in the original input ended up in the returned unique list. default: False + + return_counts (bool) – Whether to also return the counts for each unique element. default: False + + dim (int) – the dimension to apply unique. If None, the unique of the flattened input is returned. default: None + + Example: + + >>> jittor.unique(jittor.array([1, 3, 2, 3])) + jt.Var([1 2 3], dtype=int32) + + >>> jittor.unique(jittor.array([1, 3, 2, 3, 2]), return_inverse=True, return_counts=True) + (jt.Var([1 2 3], dtype=int32), jt.Var([0 2 1 2 1], dtype=int32), jt.Var([1 2 2], dtype=int32)) + + >>> jittor.unique(jittor.array([[1, 3], [2, 3]]), return_inverse=True) + (jt.Var([1 2 3], dtype=int32), jt.Var([[0 2] + [1 2]], dtype=int32)) + + >>> jittor.unique(jittor.array([[1, 3], [1, 3]]), dim=0) + jt.Var([[1 3]], dtype=int32) + ''' + + temp_shape = None + if dim == None: + temp_shape = list(input.shape) + input_flatten = input.flatten() + dim = 0 + else: + input_flatten = input + + input_flatten = input_flatten.transpose(dim, 0) + orig_shape = input_flatten.shape + input_flatten = input_flatten.view(orig_shape[0], -1) + + with jt.flag_scope(compile_options = {"FLAGS: --extended-lambda ": 1} if jt.flags.use_cuda else {}): + indice = jt.code((input_flatten.shape[0], ), 'int32', [input_flatten], + cpu_header=''' + #include + ''', + cpu_src=''' + @alias(input_flatten, in0) + @alias(indice, out) + + int dimlen = input_flatten_shape0, dimsize = input_flatten_shape1; + for(int i = 0; i < dimlen; ++i) @indice(i) = i; + std::sort(&@indice(0), &@indice(dimlen), [&](int a, int b){ + for(int i = 0; i < dimsize; ++i) { + int lhs = @input_flatten(a, i), rhs = @input_flatten(b, i); + if (lhs != rhs) return lhs < rhs; + } + return false; + }); + ''', + cuda_header=''' + #undef out + #include + #include + #include + #include + #include + + #include + #include + ''', + cuda_src= + ''' + @alias(input_flatten, in0) + @alias(indice, out) + int dimlen = indice_shape0, dimsize = input_flatten_shape1; + + if (dimsize == 1) { + size_t raw_allocation, d_allocation, temp_storage_bytes = 0; + void *d_temp_storage = NULL; + int32_t* raw_ptr = (int32_t*)exe.allocator->alloc(dimlen * (sizeof(int32_t) + sizeof(input_flatten_type)), raw_allocation); + + thrust::device_ptr arange_ptr = thrust::device_pointer_cast(raw_ptr); + thrust::sequence(arange_ptr, arange_ptr + dimlen); + + cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, input_flatten_p, + (input_flatten_type*)(raw_ptr + dimlen), thrust::raw_pointer_cast(arange_ptr), indice_p, dimlen); + d_temp_storage = exe.allocator->alloc(temp_storage_bytes, d_allocation); + cub::DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, input_flatten_p, + (input_flatten_type*)(raw_ptr + dimlen), thrust::raw_pointer_cast(arange_ptr), indice_p, dimlen); + + exe.allocator->free(raw_ptr, dimlen * (sizeof(int) + sizeof(input_flatten_type)), raw_allocation); + exe.allocator->free(d_temp_storage, temp_storage_bytes, d_allocation); + } else { + thrust::device_ptr input_ptr = thrust::device_pointer_cast(input_flatten_p); + thrust::device_ptr indice_ptr = thrust::device_pointer_cast(indice_p); + + thrust::sequence(indice_ptr, indice_ptr + dimlen); + thrust::sort(thrust::device, indice_ptr, indice_ptr + dimlen, + [=] __device__ (int32_t a, int32_t b)->bool { + for(int i = 0; i < dimsize; ++i) { + input_flatten_type lhs = input_ptr[i + a * dimsize], + rhs = input_ptr[i + b * dimsize]; + if (lhs != rhs) return lhs < rhs; + } + return false; + }); + } + ''' + ) + input_sorted = input_flatten[indice][:] + + dimlen = indice.shape[0] + + diff = jt.logical_not(jt.all(input_sorted[1:] == input_sorted[: -1], 1)) + diff = jt.concat([jt.Var([False]), diff], 0) + diff = jt.array(diff, dtype = jt.int32) + + with jt.flag_scope(compile_options = {"FLAGS: --extended-lambda ": 1} if jt.flags.use_cuda else {}): + output, inverse = jt.code( + [(-input_sorted.shape[0], ), (indice.shape)], + [input_sorted.dtype, indice.dtype], + [input_sorted, diff, indice], + cpu_header=''' + #include + @alias(input_sorted, in0) + @alias(diff, in1) + @alias(indice, in2) + @alias(output, out0) + @alias(inverse, out1) + ''', + cpu_src= + f"bool return_inverse = {int(return_inverse)};" + + ''' + int tot = -1; + for (int i = 0; i < input_sorted_shape0; ++i) { + if (i == 0 || @diff(i)) { + ++tot; @output(tot) = i; + } + if (return_inverse) + @inverse(@indice(i)) = tot; + } + output->set_shape({tot + 1}); + ''', + cuda_header=''' + #undef out + + #include + #include + #include + #include + #include + + @alias(input_sorted, in0) + @alias(diff, in1) + @alias(indice, in2) + @alias(output, out0) + @alias(inverse, out1) + ''', + cuda_src= + f"bool return_inverse = {int(return_inverse)};" + + ''' + int dimlen = input_sorted_shape0, dimsize = input_sorted_shape1; + size_t raw_allocation; + int32_t* raw_ptr = (int32_t*)exe.allocator->alloc(2 * dimlen * sizeof(int), raw_allocation); + + thrust::device_ptr diff_ptr = thrust::device_pointer_cast(diff_p), + inverse_ptr = thrust::device_pointer_cast(inverse_p), + array_ptr = thrust::device_pointer_cast(raw_ptr), + sum_ptr = thrust::device_pointer_cast(raw_ptr + dimlen), + indice_ptr = thrust::device_pointer_cast(indice_p); + thrust::device_ptr input_ptr = thrust::device_pointer_cast(input_sorted_p); + + if (return_inverse) { + thrust::inclusive_scan(diff_ptr, diff_ptr + dimlen, sum_ptr); + thrust::scatter(sum_ptr, sum_ptr + dimlen, indice_ptr, inverse_ptr); + } + + thrust::sequence(array_ptr, array_ptr + dimlen); + int32_t num = thrust::unique(array_ptr, array_ptr + dimlen, + [=] __device__ (int32_t a, int32_t b)->bool { + for(int i = 0; i < dimsize; ++i) { + input_sorted_type lhs = input_ptr[i + a * dimsize], + rhs = input_ptr[i + b * dimsize]; + if (lhs != rhs) return false; + } + return true; + }) - array_ptr; + + cudaMemcpy(output_p, raw_ptr, sizeof(int32_t) * num, cudaMemcpyDeviceToDevice); + exe.allocator->free(raw_ptr, 2 * dimlen * sizeof(int32_t), raw_allocation); + output->set_shape({ num }); + ''' + ) + indice_shape = (output.shape[0], ) + output = input_sorted[output][:] + + new_shape = list(orig_shape[1:]) + new_shape.insert(0, -1) + output = output.view(new_shape).transpose(dim, 0) + if temp_shape != None: + inverse = inverse.view(temp_shape).transpose(dim, 0) + + if return_inverse: + if return_counts: + counts = jt.zeros(indice_shape, dtype=jt.int32) + jt.scatter_(counts, 0, inverse.flatten(), jt.ones(dimlen), reduce='add') + return output, inverse, counts + else: + return output, inverse + else: + return output + +jt.Var.unique = unique + + +def hypot(a,b): + return jt.sqrt(a.sqr()+b.sqr()) + +def rad2deg(x): + return 180 * x / np.pi + +jt.Var.rad2deg = rad2deg + +def deg2rad(x): + return x * np.pi / 180. + +jt.Var.deg2rad = deg2rad + +def arctan2(y,x): + angle = jt.zeros(x.shape,dtype=x.dtype) + x = (x!=0.0).ternary(x, 1e-30) + angle = (y/x).arctan() + mask = (x<0)&(y<0) + angle = angle - mask*np.pi + mask = (x<0)&(y>=0) + angle = angle + mask*np.pi + return angle +atan2 = arctan2 + + +def nonzero(x): + r''' + Return the index of the elements of input tensor which are not equal to zero. + ''' + x = jt.where(x) + x = [xx.unsqueeze(1) for xx in x] + if len(x)<2: + return x[0] + x = jt.concat(x,dim=1) + return x + +jt.Var.nonzero = nonzero + + +def arange(start=0, end=None, step=1,dtype=None): + if isinstance(start, jt.Var): start = start.item() + if isinstance(end, jt.Var): end = end.item() + if isinstance(step, jt.Var): step = step.item() + if end is None: + end,start = start,0 + l = round((end-start)//step)+1 + if (l-1)*step+start>=end: + l-=1 + x = jt.index((l,),0) + x = x*step+start + if dtype is not None: + x= x.cast(dtype) + return x + +def log2(x): + return jt.log(x)/math.log(2.0) + +jt.Var.log2 = log2 + +def meshgrid(*tensors): + r''' + Take N tensors, each of which can be 1-dimensional vector, and create N n-dimensional grids, + where the i th grid is defined by expanding the i th input over dimensions defined by other inputs. + ''' + if len(tensors)==1 and isinstance(tensors[0], list): + tensors = tensors[0] + size = len(tensors) + shape = [] + for i in range(size): + assert isinstance(tensors[i],jt.Var) and tensors[i].ndim==1 + shape.append(tensors[i].shape[0]) + grids = [] + view_shape = [1]*size + for i in range(size): + vs = view_shape[:] + vs[i]=-1 + grids.append(tensors[i].reshape(vs).expand(shape)) + + return grids + + +def split(d, split_size, dim=0): + r''' + Splits the tensor into chunks. Each chunk is a view of the original tensor. + + If split_size is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimension dim is not divisible by split_size. + + If split_size is a list, then tensor will be split into len(split_size) chunks with sizes in dim according to split_size_or_sections. + + Args: + d (Tensor) – tensor to split. + + split_size (int) or (list(int)) – size of a single chunk or list of sizes for each chunk + + dim (int) – dimension along which to split the tensor. + ''' + if isinstance(split_size,int): + shape = d.shape[dim] + if shape % split_size == 0: + split_size = [split_size]*(shape//split_size) + else: + split_size = [split_size]*(shape//split_size)+[shape%split_size] + if isinstance(split_size, Iterable): + assert sum(split_size)==d.shape[dim] + + if dim<0: + dim+=d.ndim + + ans = [] + last = 0 + s_last = len(split_size)-1 + gopt_disable = jt.flags.gopt_disable or jt.flags.use_acl + for j, i in enumerate(split_size): + if i==0: + shape = list(d.shape) + shape[dim]=0 + new_d = jt.zeros(tuple(shape),dtype=d.dtype) + ans.append(new_d) + continue + + ss = (slice(None),)*dim+(slice(last,last+i),) + if gopt_disable: + new_d = d.getitem(ss) + else: + new_d, d = d.getitem(ss, int(j==s_last)) + + last +=i + ans.append(new_d) + return tuple(ans) + +jt.Var.split = split + +def tolist(x): + return x.numpy().tolist() +jt.Var.tolist = tolist + +def view_as(x,y): + return x.reshape(y.shape) +jt.Var.view_as = view_as + +def diag(x,diagonal=0): + assert x.ndim==1 or (x.ndim==2 and x.shape[0]==x.shape[1]) + d = diagonal if diagonal>=0 else -diagonal + d_str = f'+{diagonal}' if diagonal>=0 else f'{diagonal}' + + if x.ndim==1: + output_shape = (x.shape[0]+d,)*2 + return x.reindex(output_shape,[f'i1-{d}' if diagonal>=0 else f'i0-{d}'],overflow_conditions=[f'i0{d_str}!=i1']) + else: + output_shape = (x.shape[0]-d,) + return x.reindex(output_shape,[f'i0+{d}' if diagonal<=0 else 'i0',f'i0+{d}' if diagonal>=0 else 'i0']) + +jt.Var.diag = diag + + +def topk(input, k, dim=None, largest=True, sorted=True): + if input.numel()==0: + return jt.array([],dtype=input.dtype),jt.array([],dtype='int32') + if dim is None: + dim = -1 + if dim<0: + dim+=input.ndim + + index,values = jt.argsort(input,dim=dim,descending=largest) + dims = (slice(None),)*dim+(slice(0,k),) + indices = index[dims] + values = values[dims] + return values,indices + +jt.Var.topk = topk + +def kthvalue(input, k, dim=None, keepdim=False, keepdims=False): + keepdim = keepdim or keepdims + if dim is None: + dim = -1 + if dim<0: + dim+=input.ndim + index,values = jt.argsort(input,dim=dim) + dims = (slice(None),)*dim+(slice(k-1,k),) + indices = index[dims] + values = values[dims] + if not keepdim and indices.ndim>1: + indices = indices.squeeze(dim) + values = values.squeeze(dim) + return values,indices + +jt.Var.kthvalue = kthvalue + +def _prod(x,dim=0): + x = jt.log(x) + x = x.sum(dim=dim) + return jt.exp(x) + + +def numpy_cumsum(x, dim=None): + ''' cumsum implemented with numpy or cupy. + + This function should not be called directly. Instead, jittor.misc.cumsum is recommended. + ''' + def cumsum_forward(np, data): + a = data['inputs'][0] + b = data['outputs'][0] + np.cumsum(a, axis=dim, out=b) + + def cumsum_backward(np, data): + dout = data['dout'] + out = data['outputs'][0] + np.cumsum(np.flip(dout, dim), axis=dim, out=out) + np.copyto(out, np.flip(out, dim)) + if (dim == None): + dim = -1 + assert(dim >= -1 and dim < len(x.shape)) + return jt.numpy_code(x.shape, x.dtype, [x], cumsum_forward, [cumsum_backward]) + +def cub_cumsum(x, dim=None): + ''' cumsum implemented with CUB. + + This function should not be called directly. Instead, jittor.misc.cumsum is recommended. + ''' + if (dim == None): + dim = -1 + assert(dim >= -1 and dim < len(x.shape)) + shape = list(x.shape) + if (dim != -1 and dim != len(shape) - 1): + order = list(range(len(shape))) + order[dim], order[-1] = order[-1], order[dim] + shape[dim], shape[-1] = shape[-1], shape[dim] + x = x.permute(order) + if (len(shape) > 2): + x = x.reshape([-1, shape[-1]]) + x = jt.compile_extern.cub_ops.cub_cumsum(x) + if (len(shape) > 2): + x = x.reshape(shape) + if (dim != -1 and dim != len(shape) - 1): + x = x.permute(order) + return x + +def cumsum(x, dim=None): + ''' + Parameters: + ----------- + x: jt.var + dim: int + + Returns: + -------- + the cumulative sum in dim of x + ''' + if (dim == None): + dim = -1 + assert(dim >= -1 and dim < len(x.shape)) + if jt.flags.use_cuda: + return cub_cumsum(x, dim) + else: + return numpy_cumsum(x, dim) + +jt.Var.cumsum = cumsum + +def cumprod(x,dim=None): + x = jt.log(x) + x = cumsum(x,dim=dim) + return jt.exp(x) + +jt.Var.cumprod=cumprod + +def nms(dets,thresh): + ''' + dets jt.array [x1,y1,x2,y2,score] + x(:,0)->x1,x(:,1)->y1,x(:,2)->x2,x(:,3)->y2,x(:,4)->score + ''' + threshold = str(thresh) + order = jt.argsort(dets[:,4],descending=True)[0] + dets = dets[order] + s_1 = '(@x(j,2)-@x(j,0)+1)*(@x(j,3)-@x(j,1)+1)' + s_2 = '(@x(i,2)-@x(i,0)+1)*(@x(i,3)-@x(i,1)+1)' + s_inter_w = 'max((Tx)0,min(@x(j,2),@x(i,2))-max(@x(j,0),@x(i,0))+1)' + s_inter_h = 'max((Tx)0,min(@x(j,3),@x(i,3))-max(@x(j,1),@x(i,1))+1)' + s_inter = s_inter_h+'*'+s_inter_w + iou = s_inter + '/(' + s_1 +'+' + s_2 + '-' + s_inter + ')' + fail_cond = iou+'>'+threshold + selected = jt.candidate(dets, fail_cond) + return order[selected] + + +jt.Var.expand_as = jt.Var.broadcast_var + + +def index_fill_(x,dim,indexs,val): + r''' + Fills the elements of the input tensor with value val by selecting the indices in the order given in index. + + Args: + x - the input tensor + dim - dimension along which to index + index – indices of input tensor to fill in + val – the value to fill with + ''' + overflow_conditions = [f'i{dim}=={i}'for i in indexs] + indexs = [f'i{i}' for i in range(len(x.shape))] + return x.reindex(shape = x.shape,indexes = indexs,overflow_conditions=overflow_conditions,overflow_value=val) + +# def triu_(x,diagonal=0): +# r''' +# Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0. + +# The upper triangular part of the matrix is defined as the elements on and above the diagonal. + +# Args: +# x – the input tensor. + +# diagonal – the diagonal to consider,default =0 +# ''' +# l = len(x.shape) +# assert l>1 +# overflow_conditions=[f'i{l-1} 1: out += "..." + out += '\n' + if (build_by == 0): + for p in now['path']: + out += prefix2+p+'\n' + else: + out += prefix2+now['path'] + '\n' + if (len(now['children']) > 0): + out += prefix2 + tab + '| ' + '\n' + else: + out += prefix2 + '\n' + for i in range(len(now['children'])): + c = now['children'][i] + if i < len(now['children']) - 1: + prefix1_ = prefix2 + tab + '├─' + prefix2_ = prefix2 + tab + '| ' + else: + prefix1_ = prefix2 + tab + '└─' + prefix2_ = prefix2 + tab + ' ' + out += print_tree(c, max_memory_size, prefix1_, prefix2_, build_by) + return out + +def get_max_memory_treemap(build_by=0, do_print=True): + '''show treemap of max memory consumption + +Example:: + + net = jt.models.resnet18() + with jt.flag_scope(trace_py_var=3, profile_memory_enable=1): + imgs = jt.randn((1,3,224,224)) + net(imgs).sync() + jt.get_max_memory_treemap() + +Output:: + + | + ├─./python/jittor/test/test_memory_profiler.py:100(test_sample) + | [19.03 MB; 29.67%] + | ./python/jittor/test/test_memory_profiler.py:100 + | | + | └─./python/jittor/__init__.py:730(__call__) + | [19.03 MB; 29.67%] + | ./python/jittor/__init__.py:730 + | | + | └─./python/jittor/models/resnet.py:152(execute) + | [19.03 MB; 29.67%] + | ./python/jittor/models/resnet.py:152 + | | + | ├─./python/jittor/models/resnet.py:142(_forward_impl) + | | [6.13 MB; 9.55%] + | | ./python/jittor/models/resnet.py:142 + | | | + + + + ''' + div1 = "[!@#div1!@#]" + div2 = "[!@#div2!@#]" + div3 = "[!@#div3!@#]" + info = jt.get_max_memory_info() + + vars = [] + vars_ = info.split(div1) + max_memory_size = int(vars_[0]) + vars_ = vars_[1:] + for v_ in vars_: + v__ = v_.split(div2) + vinfo = v__[0].split("{")[0] + var = {'size':int(v__[1]), 'stack':[], 'cnt':1, "vinfo":vinfo} + v__ = v__[2:-1] + for s_ in v__: + s__ = s_.split(div3) + s = {'path':s__[0], 'name':s__[1], 'type':s__[2]} + var['stack'].append(s) + vars.append(var) + if (build_by == 0): # build tree by name + tree = {'name':'root', "children":[], 'size':0, 'cnt':1, 'path':[], 'type':'', 'vinfo':[]} + + def find_child(now, key): + for c in now['children']: + if (c['name'] == key): + return c + return None + for v in vars: + now = tree + now['size'] += v['size'] + now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) + for s in v['stack']: + ch = find_child(now, s['name']) + if (ch is not None): + if (not s['path'] in ch['path']): + ch['path'].append(s['path']) + assert(ch['type']==s['type']) + now = ch + now['size'] += v['size'] + now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) + else: + now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':[s['path']], 'type':s['type'], 'vinfo':[v['vinfo']]} + now['children'].append(now_) + now = now_ + elif (build_by == 1): # build tree by path + tree = {'name':'root', "children":[], 'size':0, 'cnt':0, 'path':'_root_', 'type':'', 'vinfo':[]} + + def find_child(now, key): + for c in now['children']: + if (c['path'] == key): + return c + return None + for v in vars: + now = tree + now['size'] += v['size'] + now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) + for s in v['stack']: + ch = find_child(now, s['path']) + if (ch is not None): + now = ch + now['size'] += v['size'] + now['cnt'] += v['cnt'] + now['vinfo'].append(v['vinfo']) + else: + now_ = {'name':s['name'], "children":[], 'size':v['size'], 'cnt':v['cnt'], 'path':s['path'], 'type':s['type'], 'vinfo':[v['vinfo']]} + now['children'].append(now_) + now = now_ + else: + assert(False) + + def sort_tree(now): + def takeSize(elem): + return elem['size'] + now['children'].sort(key=takeSize, reverse=True) + for c in now['children']: + sort_tree(c) + sort_tree(tree) + out = print_tree(tree, max_memory_size, '', '', build_by) + if (do_print): + print(out) + return tree, out + +def python_pass_wrapper(mod_func, args, kw): + import importlib + mod, func = mod_func.rsplit(".", 1) + mod = importlib.import_module(mod) + func = getattr(mod, func) + args = args + ("**kw",) + args = ",".join(args) + return eval(f"func({args})") + +def auto_parallel(n, src, block_num=1024, **kw): + """ + auto parallel(CPU and GPU) n-d for loop function like below: + + Before: + + void inner_func(int n0, int i0, int n1, int i1) { + ... + } + + for (int i0=0; i0= n*2, (args, n) + oargs = args[n*2:] + pargs = args[:n*2] + piargs = pargs[1::2] + pnargs = pargs[0::2] + pnargs2 = [ a.split()[-1] for a in pnargs ] + oargs2 = [ a.split()[-1] for a in oargs ] + entry_func_args_def = ",".join(["int tn"+str(i) for i in range(n)] + + pnargs + oargs) + entry_func_args = ",".join(["tn"+str(i) for i in range(n)] + + pnargs2 + oargs2) + tid_def = "" + tid_loop = "" + call_args = [] + for i in reversed(range(n)): + tid_def += f"\nauto tid{i} = tid & ((1<>tn{i};" + for i in range(n): + tid_loop += f"\nfor (int i{i}=tid{i}; i{i}<{pnargs2[i]}; i{i}+=tnum{i})" + call_args.append(pnargs2[i]) + call_args.append(f"i{i}") + call_args += oargs2 + call_args = ",".join(call_args) + xn = '\n' + new_src = f""" +#ifdef JIT_cuda +__device__ +#endif +{src.replace(func_name, func_name+"_inner", 1)} + +#ifdef JIT_cuda +__global__ static void {func_name}_entry({entry_func_args_def}) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + {tid_def} + {tid_loop} + {func_name}_inner({call_args}); +}} +#endif + +inline static void {func_name}({",".join(pargs+oargs)}) {{ +#ifdef JIT_cuda + int thread_num = 256*{block_num}; + {xn.join([f"int tn{i} = NanoVector::get_nbits(std::min(thread_num, {pnargs2[i]})) - 2;thread_num >>= tn{i};" for i in reversed(range(n))])} + thread_num = 1<<({"+".join([f"tn{i}" for i in range(n)])}); + int p1 = std::max(thread_num/{block_num}, 1); + int p2 = std::min(thread_num, {block_num}); + {func_name}_entry<<>>({entry_func_args}); +#else + {xn.join([f"for (int i{i}=0; i{i}<{pnargs2[i]}; i{i}++)" for i in range(n)])} + {func_name}_inner({call_args}); +#endif +}} +""" + return new_src + + +def numpy_cumprod(a, dim): + class CumprodFunc(jt.Function): + def forward_code(self, np, data): + a = data["inputs"][0] + b = data["outputs"][0] + out = np.cumprod(a, self.dim) + np.copyto(b, out) + + def backward_code(self, np, data): + a, b, dout = data["inputs"] + out = data["outputs"][0] + + sdim = a.shape[self.dim] + dim = (len(a.shape)+1)*[1] + dim[self.dim+1] = sdim + res = np.tile(np.expand_dims(b, self.dim+1), dim) + dout = np.tile(np.expand_dims(dout, self.dim+1), dim) + + dim[self.dim]=sdim + dim[self.dim+1]=1 + a = np.tile(np.expand_dims(a, self.dim), dim) + res = res/a + + mask = np.tril(np.ones((sdim, sdim))) + for i in range(self.dim): + mask = np.expand_dims(mask, 0) + for i in range(len(a.shape)-self.dim-2): + mask = np.expand_dims(mask, -1) + res = np.sum(mask*res*dout, self.dim) + + np.copyto(out, res) + + def execute(self, a, dim): + self.save_vars = a + self.dim = dim + self.res = jt.numpy_code( + a.shape, + a.dtype, + [a], + self.forward_code, + ) + return self.res + + def grad(self, grad_a): + a = self.save_vars + b = self.res + return jt.numpy_code( + a.shape, + a.dtype, + [a, b, grad_a], + self.backward_code, + ) + + func = CumprodFunc() + if dim<0: + dim+=len(a.shape) + return func(a, dim) + +def linspace(start, end, steps): + if steps > 1: + res = jt.index((steps,))[0] + res = res*float((end-start)/(steps-1))+start + else: + res = jt.array([start]) + return res + +def randperm(n, dtype="int32"): + key = jt.random((n,)) + index, _ = jt.argsort(key) + return index.cast(dtype) + +def set_global_seed(seed, different_seed_for_mpi=True): + ''' Sets the seeds of the random number generators of Python, numpy and jittor, + simultaneously. + + .. note:: + Jittor also gurantees each worker of jittor.dataset.Dataset to hold a different seed, + also gurantees each process hold a different seed which using mpi, + which is (global_seed ^ (worker_id*1167)) ^ 1234 + jt.rank * 2591 + ''' + if (different_seed_for_mpi): + seed = seed + jt.rank * 2591 + import random + random.seed(seed) + jt.set_seed(seed) + np.random.seed(seed) + try: + import cupy + cupy.random.seed(seed) + except: + pass + +import time +set_global_seed(int(time.time() * 1000000) % 100000007) + +def searchsorted(sorted, values, right=False): + """ + Find the indices from the innermost dimension of `sorted` for each `values`. + +Example:: + + sorted = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + values = jt.array([[3, 6, 9], [3, 6, 9]]) + ret = jt.searchsorted(sorted, values) + assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret + + ret = jt.searchsorted(sorted, values, right=True) + assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret + + sorted_1d = jt.array([1, 3, 5, 7, 9]) + ret = jt.searchsorted(sorted_1d, values) + assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret + + + """ + _searchsorted_header = f""" +namespace jittor {{ + +@python.jittor.auto_parallel(2) +inline static void searchsorted( + int batch_num, int batch_id, int value_num, int value_id, + int sorted_num, int batch_stride, + {sorted.dtype}* __restrict__ sort_p, {values.dtype}* __restrict__ value_p, + int32* __restrict__ index_p) {{ + int32 l = batch_id * batch_stride; + int32 r = l + sorted_num; + auto v = value_p[batch_id * value_num + value_id]; + while (lshape[in1->shape.size()-1]; + int sorted_num = in0->shape[in0->shape.size()-1]; + int32 batch_num = in0->num / sorted_num; + int32 batch_num2 = in1->num / value_num; + int32 batch_stride = batch_num == 1 ? 0 : sorted_num; + CHECK(batch_num == batch_num2 || batch_num == 1); + + searchsorted(batch_num2, 0, value_num, 0, sorted_num, batch_stride, in0_p, in1_p, out0_p); +""" + return jt.code(values.shape, "int32", [sorted, values], + cpu_header=_searchsorted_header, + cpu_src=_searchsorted_src, + cuda_header=_searchsorted_header, + cuda_src=_searchsorted_src) + + +def scatter(x:jt.Var, dim:int, index:jt.Var, src:jt.Var, reduce='void'): + ''' if x is a 3-D array, rewrite x like: + + self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 + self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 + self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2 + +Parameters:: + + * x (jt.Var) – input array + * dim (int) – the axis along which to index + * index (jt.Var) – the indices of elements to scatter, can be either empty or of the same dimensionality as src. When empty, the operation returns self unchanged. + * src (jt.Var) – the source element(s) to scatter. + * reduce (str, optional) – reduction operation to apply, can be either 'add' or 'multiply'. + +Example:: + + src = jt.arange(1, 11).reshape((2, 5)) + index = jt.array([[0, 1, 2, 0]]) + x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src) + assert (x.data == + [[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]).all() + index = jt.array([[0, 1, 2], [0, 1, 4]]) + x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src) + assert (x.data == + [[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]).all() + x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]), + jt.array(1.23), reduce='multiply') + assert np.allclose(x.data, + [[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]), x + x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]), + jt.array(1.23), reduce='add') + assert np.allclose(x.data, + [[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + ''' + shape = index.shape + if src.shape != shape and src.numel() != 1: + src = src[tuple( slice(None,s) for s in shape )] + indexes = [ f'i{i}' for i in range(len(shape)) ] + indexes[dim] = index + return x.setitem(tuple(indexes), src, reduce) + +def scatter_(x, dim, index, src, reduce='void'): + return x.assign(x.scatter(dim, index, src, reduce)) + +jt.Var.scatter = scatter +jt.Var.scatter_ = scatter_ + +def gather(x, dim, index): + ''' if x is a 3-D array, reindex x like: + + out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 + out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 + out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 + + +Parameters:: + + * x (jt.Var) – the source array + * dim (int) – the axis along which to index + * index (jt.Var) – the indices of elements to gather + +Example:: + + t = jt.array([[1, 2], [3, 4]]) + data = t.gather(1, jt.array([[0, 0], [1, 0]])) + assert (data.data == [[ 1, 1], [ 4, 3]]).all() + data = t.gather(0, jt.array([[0, 0], [1, 0]])) + assert (data.data == [[ 1, 2], [ 3, 2]]).all() + + ''' + shape = index.shape + indexes = [ f'i{i}' for i in range(len(shape)) ] + indexes[dim] = index + return x.getitem(tuple(indexes)) + +jt.Var.gather = gather + +def roll(x, shifts, dims=None): + '''Roll the tensor along the given dimension(s). + +Parameters:: + + * x (jt.Var) – the source array + * shifts (int or tuple) – shift offset of dims + * dims (int or tuple) – shift dims + +Examples:: + + x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + y = x.roll(1, 0) + assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all() + y = x.roll(-1, 0) + assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all() + y = x.roll(shifts=(2, 1), dims=(0, 1)) + assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all() + + ''' + if isinstance(shifts, int): + shifts = (shifts,) + if dims is None: + dims = tuple(range(len(shifts))) + elif isinstance(dims, int): + dims = (dims,) + assert len(dims) == len(shifts) + ids = [ f'i{i}' for i in range(x.ndim) ] + for i in range(len(dims)): + shift = shifts[i] + d = dims[i] + size = x.shape[d] + shift = shift % size + if shift<0: shift += size + ids[d] = f'(i{d}<{shift}?i{d}+{size-shift}:(i{d}-{shift}))' + return x.reindex(x.shape, ids) + +jt.Var.roll = roll + +def safe_log(x): + return jt.safe_clip(x, 1e-30, 1e30).log() +jt.Var.safe_log = safe_log + +class _CTCLossFunction(jt.Function): + def execute(self, log_probs, targets, input_lengths, target_lengths, blank=0, zero_infinity=False): + self.blank = blank + T, N, C = log_probs.shape + _N, S = targets.shape + assert _N == N + log_alpha = jt.full([T,N,S*2+1], -1e30) + result = jt.empty((N,)) + jt.code([log_probs, targets, input_lengths, target_lengths], [log_alpha, result], cpu_src=f""" + constexpr int blank = {blank}; + for (int i=0; i1 && k%2) target_2 = @in1(i,k/2-1); + out_type l1 = @out0(j-1,i,k); + out_type l2 = -1e30; + if (k>0) l2 = @out0(j-1,i,k-1); + out_type l3 = -1e30; + if (k>1 && target_2 != target) + l3 = @out0(j-1,i,k-2); + out_type m = std::max(l1, std::max(l2, l3)); + @out0(j,i,k) = std::log( + std::exp(l1-m) + + std::exp(l2-m) + + std::exp(l3-m) + ) + m + @in0(j,i,target); + }} + if (input_len==0) + @out1(i) = @out0(0,i,0); + else {{ + out_type l1 = @out0(input_len-1, i, target_len*2); + out_type l2 = -1e30; + if (target_len) + l2 = @out0(input_len-1, i, target_len*2-1); + out_type m = std::max(l1, l2); + out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m; + @out1(i) = -log_likelihood; + }} + }} + """, cuda_src=f""" + __global__ void kernel(@ARGS_DEF) {{ + @PRECALC; + constexpr int blank = {blank}; + for (int i=blockIdx.x; i=target_len*2+1) + continue; + int target = k%2 ? @in1(i,k/2) : blank; + int target_2 = target; + if (k>1 && k%2) target_2 = @in1(i,k/2-1); + out_type l1 = @out0(j-1,i,k); + out_type l2 = -1e30; + if (k>0) l2 = @out0(j-1,i,k-1); + out_type l3 = -1e30; + if (k>1 && target_2 != target) + l3 = @out0(j-1,i,k-2); + out_type m = ::max(l1, ::max(l2, l3)); + @out0(j,i,k) = ::log( + ::exp(l1-m) + + ::exp(l2-m) + + ::exp(l3-m) + ) + m + @in0(j,i,target); + }} + __syncthreads(); + if (input_len==0) + @out1(i) = @out0(0,i,0); + else {{ + out_type l1 = @out0(input_len-1, i, target_len*2); + out_type l2 = -1e30; + if (target_len) + l2 = @out0(input_len-1, i, target_len*2-1); + out_type m = ::max(l1, l2); + out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m; + @out1(i) = -log_likelihood; + }} + }} + }} + kernel<<>>(@ARGS); + """) + self.saved_var = [log_probs, targets, input_lengths, target_lengths, log_alpha, result] + return result + + def grad(self, dout): + blank = self.blank + inputs = self.saved_var + [dout] + dlog_probs = jt.zeros_like(inputs[0]) + dlog_alpha = jt.zeros_like(inputs[4]) + jt.code(inputs, [dlog_probs, dlog_alpha], cpu_src=f""" + constexpr int blank = {blank}; + for (int i=0; i read in6 + // out1(i) = out0(0,i,0); + @out1(0,i,0) = @in6(i); + else {{ + out_type l1 = @in4(input_len-1, i, target_len*2); + out_type l2 = -1e30; + if (target_len) + l2 = @in4(input_len-1, i, target_len*2-1); + out_type m = std::max(l1, l2); + // out_type log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m; + // out1(i) = -log_likelihood; + out_type l1_exp = std::exp(l1-m); + out_type l2_exp = std::exp(l2-m); + out_type sumexp = l1_exp + l2_exp; + + out_type dlog_likelihood = -@in6(i); + out_type dl1 = dlog_likelihood * l1_exp / sumexp; + out_type dl2 = dlog_likelihood * l2_exp / sumexp; + + @out1(input_len-1, i, target_len*2) = dl1; + if (target_len) + @out1(input_len-1, i, target_len*2-1) = dl2; + }} + for (int j=input_len-1; j>0; j--) + for (int k=0; k1 && k%2) target_2 = @in1(i,k/2-1); + out_type l1 = @in4(j-1,i,k); + out_type l2 = -1e30; + if (k>0) l2 = @in4(j-1,i,k-1); + out_type l3 = -1e30; + if (k>1 && target_2 != target) + l3 = @in4(j-1,i,k-2); + out_type m = std::max(l1, std::max(l2, l3)); + out_type l1_exp = std::exp(l1-m); + out_type l2_exp = std::exp(l2-m); + out_type l3_exp = std::exp(l3-m); + out_type sumexp = l1_exp + l2_exp + l3_exp; + out_type dalpha = @out1(j,i,k); + + @out0(j,i,target) += dalpha; + + @out1(j-1,i,k) += dalpha * l1_exp / sumexp; + if (k>0) + @out1(j-1,i,k-1) += dalpha * l2_exp / sumexp; + if (k>1 && target_2 != target) + @out1(j-1,i,k-2) += dalpha * l3_exp / sumexp; + }} + // read in0 -> white out0 + // write out0 ->read out1 + // out0(0,i,0) = in0(0,i,blank); + @out0(0,i,blank) += @out1(0,i,0); + if (target_len) + @out0(0,i,@in1(i,0)) += @out1(0,i,1); + }} + """, cuda_src=f""" + __global__ void kernel(@ARGS_DEF) {{ + @PRECALC; + constexpr int blank = {blank}; + for (int i=blockIdx.x; i read in6 + // out1(i) = out0(0,i,0); + @out1(0,i,0) = @in6(i); + else {{ + out_type l1 = @in4(input_len-1, i, target_len*2); + out_type l2 = -1e30; + if (target_len) + l2 = @in4(input_len-1, i, target_len*2-1); + out_type m = ::max(l1, l2); + // out_type log_likelihood = ::log(::exp(l1-m)+::exp(l2-m))+m; + // out1(i) = -log_likelihood; + out_type l1_exp = ::exp(l1-m); + out_type l2_exp = ::exp(l2-m); + out_type sumexp = l1_exp + l2_exp; + + out_type dlog_likelihood = -@in6(i); + out_type dl1 = dlog_likelihood * l1_exp / sumexp; + out_type dl2 = dlog_likelihood * l2_exp / sumexp; + + @out1(input_len-1, i, target_len*2) = dl1; + if (target_len) + @out1(input_len-1, i, target_len*2-1) = dl2; + }} + for (int j=input_len-1; j>0; j--) + for (int k=threadIdx.x; k-threadIdx.x=target_len*2+1) + continue; + int target = k%2 ? @in1(i,k/2) : blank; + int target_2 = target; + if (k>1 && k%2) target_2 = @in1(i,k/2-1); + out_type l1 = @in4(j-1,i,k); + out_type l2 = -1e30; + if (k>0) l2 = @in4(j-1,i,k-1); + out_type l3 = -1e30; + if (k>1 && target_2 != target) + l3 = @in4(j-1,i,k-2); + out_type m = ::max(l1, ::max(l2, l3)); + out_type l1_exp = ::exp(l1-m); + out_type l2_exp = ::exp(l2-m); + out_type l3_exp = ::exp(l3-m); + out_type sumexp = l1_exp + l2_exp + l3_exp; + out_type dalpha = @out1(j,i,k); + + atomicAdd(&@out0(j,i,target), dalpha); + + atomicAdd(&@out1(j-1,i,k), dalpha * l1_exp / sumexp); + if (k>0) + atomicAdd(&@out1(j-1,i,k-1), dalpha * l2_exp / sumexp); + if (k>1 && target_2 != target) + atomicAdd(&@out1(j-1,i,k-2), dalpha * l3_exp / sumexp); + }} + // read in0 -> white out0 + // write out0 ->read out1 + // out0(0,i,0) = in0(0,i,blank); + __syncthreads(); + if (threadIdx.x==0) {{ + @out0(0,i,blank) += @out1(0,i,0); + if (target_len) + @out0(0,i,@in1(i,0)) += @out1(0,i,1); + }} + }} + }} + kernel<<>>(@ARGS); + """) + return (dlog_probs,) + + +def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False): + '''The Connectionist Temporal Classification loss. + + + Reference: + A. Graves et al.: Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks: + https://www.cs.toronto.edu/~graves/icml_2006.pdf + + Input: + + log_probs: shape is [T, N, C], T is the sequence length, N is the batch size, C is the class number. + targets: shape is [N, S], N is the batch size, S is the target sequence length, element should between [0,C). + input_lengths: shape is [N], which represents the length of input, element should between [0,T]. + target_lengths: shape is N, which represents the length of target, element should between [0,S]. + blank (int, default 0): blank label index + reduction (string): reduce batch loss, + if reduction is none, it will return (N,) array, + if reduction is mean or sum, it will return one scalar + zero_infinity (bool, default False): + zero_infinity for grad + + Example: + + import jittor as jt + T = 50 # Input sequence length + C = 20 # Number of classes (including blank) + N = 16 # Batch size + S = 30 # Target sequence length of longest target in batch (padding length) + S_min = 10 # Minimum target length, for demonstration purposes + + input = jt.randn(T, N, C).log_softmax(2) + # Initialize random batch of targets (0 = blank, 1:C = classes) + target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int) + + input_lengths = jt.full((N,), T, dtype=jt.int) + target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int) + loss = jt.ctc_loss(input, target, input_lengths, target_lengths) + + dinput = jt.grad(loss, input) + + ''' + result = _CTCLossFunction.apply(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity) + if reduction=="mean": + return result.mean() + elif reduction=="sum": + return result.sum() + assert reduction=="none" + return result + + +class CTCLoss(jt.Module): + '''The Connectionist Temporal Classification loss. + + + Reference: + A. Graves et al.: Connectionist Temporal Classification: + Labelling Unsegmented Sequence Data with Recurrent Neural Networks: + https://www.cs.toronto.edu/~graves/icml_2006.pdf + + + Args: + + blank (int, default 0): blank label index + reduction (string): reduce batch loss, + if reduction is none, it will return (N,) array, + if reduction is mean or sum, it will return one scalar + zero_infinity (bool, default False): + zero_infinity for grad + + Input: + + log_probs: shape is [T, N, C], T is the sequence length, N is the batch size, C is the class number. + targets: shape is [N, S], N is the batch size, S is the target sequence length, element should between [0,C). + input_lengths: shape is [N], which represents the length of input, element should between [0,T]. + target_lengths: shape is N, which represents the length of target, element should between [0,S]. + + Example: + + import jittor as jt + T = 50 # Input sequence length + C = 20 # Number of classes (including blank) + N = 16 # Batch size + S = 30 # Target sequence length of longest target in batch (padding length) + S_min = 10 # Minimum target length, for demonstration purposes + + input = jt.randn(T, N, C).log_softmax(2) + # Initialize random batch of targets (0 = blank, 1:C = classes) + target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int) + + input_lengths = jt.full((N,), T, dtype=jt.int) + target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int) + ctc_loss = jt.CTCLoss() + loss = ctc_loss(input, target, input_lengths, target_lengths) + + dinput = jt.grad(loss, input) + + ''' + def __init__(self, blank=0, reduction='mean', zero_infinity=False): + self.blank = blank + self.reduction = reduction + self.zero_infinity = zero_infinity + + def execute(self, log_probs, targets, input_lengths, target_lengths): + return ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction, self.zero_infinity) + +def _simple_for(x, func): + with jt.flag_scope(compile_options={"FLAGS: -O2 ":1}): + src = f''' + __inline_static__ + @python.jittor.auto_parallel(1) + void kernel(int n0, int i0, in0_type* _x, out0_type* y) {{ + using namespace std; + auto x = _x[i0]; + y[i0] = {func}; + }} + kernel(in0->num, 0, in0_p, out0_p); + ''' + return jt.code(x.shape, "bool", [x], cpu_src=src, cuda_src=src) + +def isnan(x): return _simple_for(x, "isnan(float(x))") +jt.Var.isnan = isnan +def isfinite(x): return _simple_for(x, "!isnan(float(x)) && !isinf(float(x))") +jt.Var.isfinite = isfinite +def isinf(x): return _simple_for(x, "isinf(float(x))") +jt.Var.isinf = isinf +def isneginf(x): return _simple_for(x, "x<0 && isinf(float(x))") +jt.Var.isneginf = isneginf +def isposinf(x): return _simple_for(x, "x>0 && isinf(float(x))") +jt.Var.isposinf = isposinf + +# fake torch interface +def contiguous(x): return x.clone() +jt.Var.contiguous = contiguous +def cpu(x): return x.clone() +jt.Var.cpu = cpu +def to(x, *args, **kargs): + args += tuple(kargs.values()) + if len(args) >= 1: + s = args[0] + if isinstance(s, jt.NanoString) or callable(s): + return x.cast(s) + s = str(s) + if "cuda" in s: + jt.flags.use_cuda = 1 + elif "cpu" in s: + jt.flags.use_cuda = 0 + return x.clone() +jt.Var.to = to + +def rsqrt(x): + return 1/jt.sqrt(x) +jt.Var.rsqrt = rsqrt + +def from_torch(x): + ''' + Convert torch Tensor to Jittor Var + ''' + return jt.Var(x.cpu().numpy()) + +def triu(input: jt.Var, diagonal:int=0) -> jt.Var: + ''' Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0. + + :param input: the input tensor. + :param diagonal: the diagonal to consider(int). + + Example:: + + a = jt.ones(3, 3) + b = jt.triu(a) + assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]]) + + b = jt.triu(a, diagonal=1) + assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]]) + + b = jt.triu(a, diagonal=-1) + assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]]) + + ''' + index = input.index() + mask = index[-2] <= index[-1] - diagonal + return jt.ternary(mask, input, jt.zeros_like(input)) +jt.Var.triu = triu +jt.Var.triu_ = lambda x,diagonal=0: x.assign(x.triu(diagonal)) + +def tril(input: jt.Var, diagonal:int=0) -> jt.Var: + ''' Returns the lower triangular part of a matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0. + + :param input: the input tensor. + :param diagonal: the diagonal to consider(int). + + Example:: + + a = jt.ones(3, 3) + b = jt.tril(a) + assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]]) + + b = jt.tril(a, diagonal=1) + assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]]) + + b = jt.tril(a, diagonal=-1) + assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]]) + + ''' + index = input.index() + mask = index[-2] >= index[-1] - diagonal + return jt.ternary(mask, input, jt.zeros_like(input)) +jt.Var.tril = tril +jt.Var.tril_ = lambda x: x.assign(x.tril()) + +def all_equal(a: jt.Var, b: jt.Var) -> bool: + return (a == b).all().item() +jt.all_equal = all_equal + +def _to_float(x: jt.Var) -> jt.Var: + if x.dtype != "float64": x = x.float() + return x +jt.Var._to_float = _to_float + +def index_select(x: jt.Var, dim:int, index: jt.Var) -> jt.Var: + '''Returns a new var which indexes the x var along dimension dim using the entries in index. + +The returned var has the same number of dimensions as the original var (x). The dimth dimension has the same size as the length of index; other dimensions have the same size as in the original tensor. + + :param x: the input tensor. + :param dim: the dimension to index. + :param index: the 1-D tensor containing the indices to index. + + Example:: + + x = jt.randn(3, 4) + indices = torch.tensor([2, 1]) + y = jt.index_select(x, 0, indices) + assert jt.all_equal(y, x[indices]) + y = jt.index_select(x, 1, indices) + assert jt.all_equal(y, x[:, indices]) + + + ''' + return x.getitem(((slice(None),)*dim)+(index,)) +jt.index_select = index_select + +def multinomial(weights: jt.Var, num_samples: int, replacement: bool=False) -> jt.Var: + ''' Returns a var where each row contains num_samples indices sampled from the multinomial probability distribution located in the corresponding row of input weights. + + :param weights: the input probability. + :param num_samples: number of samples. + :param replacement: whether to draw with replacement or not. + + + Example:: + + weights = jt.float32([0, 10, 3, 0]) + x = jt.multinomial(weights, 2) + assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1]) + x = jt.multinomial(weights, 4, replacement=True) + assert x.shape == (4, ) + + weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]]) + x = jt.multinomial(weights, 1) + assert jt.all_equal(x, [[2],[1],[0]]) + + ''' + if replacement: + cum_probs = jt.cumsum(weights)[..., None, :] + cum_probs_l = cum_probs[..., :-1] + cum_probs_r = cum_probs[..., 1:] + shape = weights.shape[:-1] + (num_samples, 1) + rand = jt.rand(shape) * cum_probs[..., :1, -1:] + one_hot = jt.logical_and(cum_probs_l < rand, rand <= cum_probs_r) + index = one_hot.index(one_hot.ndim - 1) + 1 + return (one_hot * index).sum(-1) + else: + # A-Res algorithm + # Pavlos S. Efraimidis and Paul G. Spirakis, 2006, Weighted random sampling with a reservoir + assert num_samples <= weights.shape[-1], "num_samples larger than the input" + # prevent rand generate 1, 1^inf = 1, with override other result + a = jt.rand(weights.shape).minimum(0.999999) + rand = a ** (1/weights) + _, indices = jt.topk(rand, num_samples) + return indices + +def histc(input, bins, min=0., max=0.): + ''' Return the histogram of the input N-d array. + + :param input: the input array. + :param bins: number of bins. + :param min: min of the range. + :param max: max of the range. + + Example:: + + inputs = jt.randn((40,40)) + joup = jt.histc(x, bins=10) + + ''' + if min == 0 and max == 0: + min, max = input.min(), input.max() + assert min < max + if bins <= 0: + raise RuntimeError(f"bins must be > 0, but got {bins}") + bin_length = (max - min) / bins + histc = jt.floor((input[jt.logical_and(input >= min, input <= max)] - min) / bin_length).int().reshape(-1) + hist = jt.ones_like(histc).float().reindex_reduce("add", [bins,], ["@e0(i0)"], extras=[histc]) + if hist.sum() != histc.shape[0]: + hist[-1] += 1 + return hist + +def peek_s(x): + if isinstance(x, jt.Var): + return x.peek() + if isinstance(x, (list, tuple)): + res = "[" + for a in x: + res += peek_s(a) + res += ", " + res += "]" + return res + if isinstance(x, dict): + res = "{" + for a in x: + res += a + res += ":" + res += peek_s(x[a]) + res += ", " + res += "}" + return res + if isinstance(x, str): + return x + return x.__class__.__name__ + +def peek(x): + print(peek_s(x)) + +class Finfo: + pass +bfloat16_finfo = Finfo() +bfloat16_finfo.min = -1e38 +bfloat16_finfo.max = 1e38 + +def finfo(dtype): + if dtype == "bfloat16": + return bfloat16_finfo + return np.finfo(str(dtype).split('.')[-1]) + +def iinfo(dtype): + return np.iinfo(str(dtype).split('.')[-1]) + + +def index_select(input,dim,indices): + return input[(None,)*dim+(indices,)] + +jt.Var.index_select = index_select + +def cuda(x): + jt.flags.use_cuda = 1 + return x +jt.Var.cuda = cuda +jt.Var.npu = cuda + +def expm1(x): + return jt.exp(x) - 1 + + +def isin(elements, test_elements, assume_unique=False, invert=False): + + elements = elements.unsqueeze(-1) + test_elements = test_elements.unsqueeze(0) + comparison = elements == test_elements + result = comparison.any(dim=-1) + + if invert: + result = jt.logical_not(result) + + return result \ No newline at end of file diff --git a/python/jittor/models/__init__.py b/python/jittor/models/__init__.py new file mode 100644 index 00000000..253bd296 --- /dev/null +++ b/python/jittor/models/__init__.py @@ -0,0 +1,21 @@ +from . import resnet +from .resnet import * +from . import vgg +from .vgg import * +from . import alexnet +from .alexnet import * +from . import squeezenet +from .squeezenet import * +from . import inception +from .inception import * +from . import googlenet +from .googlenet import * +from . import mobilenet +from .mobilenet import * +from . import mnasnet +from .mnasnet import * +from . import shufflenetv2 +from .shufflenetv2 import * +from .res2net import res2net50, res2net101 +from . import densenet +from .densenet import * diff --git a/python/jittor/models/alexnet.py b/python/jittor/models/alexnet.py new file mode 100644 index 00000000..804c7697 --- /dev/null +++ b/python/jittor/models/alexnet.py @@ -0,0 +1,68 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. +import jittor as jt +import jittor.nn as nn + +__all__ = ['AlexNet', 'alexnet'] + +class AlexNet(nn.Module): + """ AlexNet model architecture. + + Args: + + * num_classes: Number of classes. Default: 1000. + + Example:: + + model = jittor.models.AlexNet(500) + x = jittor.random([10,3,224,224]) + y = model(x) # [10, 500] + + """ + + def __init__(self, num_classes=1000): + super(AlexNet, self).__init__() + self.features = nn.Sequential( + nn.Conv(3, 64, kernel_size=11, stride=4, padding=2), + nn.Relu(), + nn.Pool(kernel_size=3, stride=2, op='maximum'), + nn.Conv(64, 192, kernel_size=5, padding=2), + nn.Relu(), nn.Pool(kernel_size=3, stride=2, op='maximum'), + nn.Conv(192, 384, kernel_size=3, padding=1), + nn.Relu(), + nn.Conv(384, 256, kernel_size=3, padding=1), + nn.Relu(), + nn.Conv(256, 256, kernel_size=3, padding=1), + nn.Relu(), + nn.Pool(kernel_size=3, stride=2, op='maximum') + ) + self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) + self.classifier = nn.Sequential( + nn.Dropout(), + nn.Linear(((256 * 6) * 6), 4096), + nn.Relu(), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.Relu(), + nn.Linear(4096, num_classes) + ) + + def execute(self, x): + x = self.features(x) + x = self.avgpool(x) + x = jt.reshape(x, (x.shape[0], (- 1))) + x = self.classifier(x) + return x + +def alexnet(pretrained=False, **kwargs): + model = AlexNet(**kwargs) + if pretrained: model.load("jittorhub://alexnet.pkl") + return model diff --git a/python/jittor/models/densenet.py b/python/jittor/models/densenet.py new file mode 100644 index 00000000..3ae36752 --- /dev/null +++ b/python/jittor/models/densenet.py @@ -0,0 +1,145 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. +import jittor as jt +from jittor import nn +from jittor import init +from collections import OrderedDict + + +def densenet121(pretrained=False, **kwargs): + '''Densenet-121 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) + if pretrained: model.load("jittorhub://densenet121.pkl") + return model + +def densenet161(pretrained=False, **kwargs): + '''Densenet-161 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) + if pretrained: model.load("jittorhub://densenet161.pkl") + return model + +def densenet169(pretrained=False, **kwargs): + '''Densenet-169 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) + if pretrained: model.load("jittorhub://densenet169.pkl") + return model + +def densenet201(pretrained=False, **kwargs): + '''Densenet-201 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) + if pretrained: model.load("jittorhub://densenet201.pkl") + return model + + +class _DenseLayer(nn.Sequential): + + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): + super(_DenseLayer, self).__init__() + self.add_module('norm1', nn.BatchNorm(num_input_features)) + self.add_module('relu1', nn.ReLU()) + self.add_module('conv1', nn.Conv(num_input_features, (bn_size * growth_rate), 1, stride=1, bias=False)) + self.add_module('norm2', nn.BatchNorm((bn_size * growth_rate))) + self.add_module('relu2', nn.ReLU()) + self.add_module('conv2', nn.Conv((bn_size * growth_rate), growth_rate, 3, stride=1, padding=1, bias=False)) + self.drop_rate = drop_rate + self.drop = nn.Dropout(self.drop_rate) + + def execute(self, x): + new_features = super(_DenseLayer, self).execute(x) + if (self.drop_rate > 0): + new_features = self.drop(new_features) + return jt.concat([x, new_features], dim=1) + +class _DenseBlock(nn.Sequential): + + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer((num_input_features + (i * growth_rate)), growth_rate, bn_size, drop_rate) + self.add_module('denselayer%d' % (i + 1), layer) + +class _Transition(nn.Sequential): + + def __init__(self, num_input_features, num_output_features): + super(_Transition, self).__init__() + self.add_module('norm', nn.BatchNorm(num_input_features)) + self.add_module('relu', nn.ReLU()) + self.add_module('conv', nn.Conv(num_input_features, num_output_features, 1, stride=1, bias=False)) + self.add_module('pool', nn.Pool(2, stride=2, op='mean')) + +class DenseNet(nn.Module): + '''Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" `_ + + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 4 ints) - how many layers in each pooling block + num_init_features (int) - the number of filters to learn in the first convolution layer + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate after each dense layer + num_classes (int) - number of classification classes + ''' + + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): + super(DenseNet, self).__init__() + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv(3, num_init_features, 7, stride=2, padding=3, bias=False)), + ('norm0', nn.BatchNorm(num_init_features)), + ('relu0', nn.ReLU()), + ('pool0', nn.Pool(3, stride=2, padding=1, op='maximum')), + ])) + num_features = num_init_features + for (i, num_layers) in enumerate(block_config): + block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) + self.features.add_module('denseblock%d' % (i + 1), block) + num_features = (num_features + (num_layers * growth_rate)) + if (i != (len(block_config) - 1)): + trans = _Transition(num_input_features=num_features, num_output_features=(num_features // 2)) + self.features.add_module('transition%d' % (i + 1), trans) + num_features = (num_features // 2) + self.features.add_module('norm5', nn.BatchNorm(num_features)) + self.classifier = nn.Linear(num_features, num_classes) + for m in self.modules(): + if isinstance(m, nn.Conv): + nn.init.invariant_uniform_(m.weight) + elif isinstance(m, nn.BatchNorm): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + def execute(self, x): + features = self.features(x) + out = nn.relu(features) + out = out.mean([2,3]) + out = self.classifier(out) + return out diff --git a/python/jittor/models/googlenet.py b/python/jittor/models/googlenet.py new file mode 100644 index 00000000..586a3829 --- /dev/null +++ b/python/jittor/models/googlenet.py @@ -0,0 +1,155 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. +import jittor as jt +from jittor import nn + +__all__ = ['GoogLeNet', 'googlenet'] + +def googlenet(pretrained=False, **kwargs): + model = GoogLeNet(**kwargs) + if pretrained: model.load("jittorhub://googlenet.pkl") + return model + +class GoogLeNet(nn.Module): + """ GoogLeNet model architecture. + + Args: + + * num_classes: Number of classes. Default: 1000. + * aux_logits: If True, add an auxiliary branch that can improve training. Default: True + * init_weights: Defualt: True. + * blocks: List of three blocks, [conv_block, inception_block, inception_aux_block]. If None, will use [BasicConv2d, Inception, InceptionAux] instead. Default: None. + """ + + def __init__(self, num_classes=1000, aux_logits=True, init_weights=True, blocks=None): + super(GoogLeNet, self).__init__() + if (blocks is None): + blocks = [BasicConv2d, Inception, InceptionAux] + assert (len(blocks) == 3) + conv_block = blocks[0] + inception_block = blocks[1] + inception_aux_block = blocks[2] + self.aux_logits = aux_logits + self.conv1 = conv_block(3, 64, kernel_size=7, stride=2, padding=3) + self.maxpool1 = nn.Pool(3, stride=2, ceil_mode=True, op='maximum') + self.conv2 = conv_block(64, 64, kernel_size=1) + self.conv3 = conv_block(64, 192, kernel_size=3, padding=1) + self.maxpool2 = nn.Pool(3, stride=2, ceil_mode=True, op='maximum') + self.inception3a = inception_block(192, 64, 96, 128, 16, 32, 32) + self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64) + self.maxpool3 = nn.Pool(3, stride=2, ceil_mode=True, op='maximum') + self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64) + self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64) + self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64) + self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64) + self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128) + self.maxpool4 = nn.Pool(2, stride=2, ceil_mode=True, op='maximum') + self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128) + self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128) + if aux_logits: + self.aux1 = inception_aux_block(512, num_classes) + self.aux2 = inception_aux_block(528, num_classes) + else: + self.aux1 = None + self.aux2 = None + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.dropout = nn.Dropout(0.2) + self.fc = nn.Linear(1024, num_classes) + + def _forward(self, x): + x = self.conv1(x) + x = self.maxpool1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.maxpool2(x) + x = self.inception3a(x) + x = self.inception3b(x) + x = self.maxpool3(x) + x = self.inception4a(x) + if (self.aux1 is not None): + aux1 = self.aux1(x) + x = self.inception4b(x) + x = self.inception4c(x) + x = self.inception4d(x) + if (self.aux2 is not None): + aux2 = self.aux2(x) + x = self.inception4e(x) + x = self.maxpool4(x) + x = self.inception5a(x) + x = self.inception5b(x) + x = self.avgpool(x) + + x = jt.reshape(x, (x.shape[0], (- 1))) + x = self.dropout(x) + x = self.fc(x) + return (x, aux2, aux1) + + def eager_outputs(self, x, aux2, aux1): + return x + + def execute(self, x): + (x, aux1, aux2) = self._forward(x) + aux_defined = (self.aux_logits) + return self.eager_outputs(x, aux2, aux1) + +class Inception(nn.Module): + + def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj, conv_block=None): + super(Inception, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.branch1 = conv_block(in_channels, ch1x1, kernel_size=1) + self.branch2 = nn.Sequential(conv_block(in_channels, ch3x3red, kernel_size=1), conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1)) + self.branch3 = nn.Sequential(conv_block(in_channels, ch5x5red, kernel_size=1), conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1)) + self.branch4 = nn.Sequential(nn.Pool(kernel_size=3, stride=1, padding=1, ceil_mode=True, op='maximum'), conv_block(in_channels, pool_proj, kernel_size=1)) + + def _forward(self, x): + branch1 = self.branch1(x) + branch2 = self.branch2(x) + branch3 = self.branch3(x) + branch4 = self.branch4(x) + outputs = [branch1, branch2, branch3, branch4] + return outputs + + def execute(self, x): + outputs = self._forward(x) + return jt.concat(outputs, dim=1) + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.conv = conv_block(in_channels, 128, kernel_size=1) + self.fc1 = nn.Linear(2048, 1024) + self.fc2 = nn.Linear(1024, num_classes) + + def execute(self, x): + x = nn.AdaptiveAvgPool2d(4)(x) + x = self.conv(x) + x = jt.reshape(x, (x.shape[0], (- 1))) + x = nn.relu(self.fc1(x)) + x = nn.Dropout(0.7)(x) + x = self.fc2(x) + return x + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm(out_channels, eps=0.001) + + def execute(self, x): + x = self.conv(x) + x = self.bn(x) + return nn.relu(x) diff --git a/python/jittor/models/inception.py b/python/jittor/models/inception.py new file mode 100644 index 00000000..55ffe0ba --- /dev/null +++ b/python/jittor/models/inception.py @@ -0,0 +1,279 @@ + +import jittor as jt +from jittor import nn +__all__ = ['Inception3', 'inception_v3'] + +def inception_v3(pretrained=False, progress=True, **kwargs): + model = Inception3(**kwargs) + if pretrained: model.load("jittorhub://inception_v3.pkl") + return model + +class Inception3(nn.Module): + """ Inceptionv3 model architecture. + + Args: + + * num_classes: Number of classes. Default: 1000. + * aux_logits: If True, add an auxiliary branch that can improve training. Default: True + * inception_blocks: List of seven blocks, [conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux]. If None, will use [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] instead. Default: None. + * init_weights: Defualt: True. + """ + + def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None, init_weights=True): + super(Inception3, self).__init__() + if (inception_blocks is None): + inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] + assert (len(inception_blocks) == 7) + conv_block = inception_blocks[0] + inception_a = inception_blocks[1] + inception_b = inception_blocks[2] + inception_c = inception_blocks[3] + inception_d = inception_blocks[4] + inception_e = inception_blocks[5] + inception_aux = inception_blocks[6] + self.aux_logits = aux_logits + self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2) + self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3) + self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1) + self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1) + self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3) + self.Mixed_5b = inception_a(192, pool_features=32) + self.Mixed_5c = inception_a(256, pool_features=64) + self.Mixed_5d = inception_a(288, pool_features=64) + self.Mixed_6a = inception_b(288) + self.Mixed_6b = inception_c(768, channels_7x7=128) + self.Mixed_6c = inception_c(768, channels_7x7=160) + self.Mixed_6d = inception_c(768, channels_7x7=160) + self.Mixed_6e = inception_c(768, channels_7x7=192) + if aux_logits: + self.AuxLogits = inception_aux(768, num_classes) + self.Mixed_7a = inception_d(768) + self.Mixed_7b = inception_e(1280) + self.Mixed_7c = inception_e(2048) + self.fc = nn.Linear(2048, num_classes) + + def _forward(self, x): + x = self.Conv2d_1a_3x3(x) + x = self.Conv2d_2a_3x3(x) + x = self.Conv2d_2b_3x3(x) + x = nn.pool(x, 3, "maximum", stride=2) + x = self.Conv2d_3b_1x1(x) + x = self.Conv2d_4a_3x3(x) + x = nn.pool(x, 3, "maximum", stride=2) + x = self.Mixed_5b(x) + x = self.Mixed_5c(x) + x = self.Mixed_5d(x) + x = self.Mixed_6a(x) + x = self.Mixed_6b(x) + x = self.Mixed_6c(x) + x = self.Mixed_6d(x) + x = self.Mixed_6e(x) + aux_defined = self.aux_logits + if aux_defined: + aux = self.AuxLogits(x) + else: + aux = None + x = self.Mixed_7a(x) + x = self.Mixed_7b(x) + x = self.Mixed_7c(x) + x = nn.AdaptiveAvgPool2d(1)(x) + x = nn.Dropout()(x) + x = jt.reshape(x, (x.shape[0], (- 1))) + x = self.fc(x) + return (x, aux) + + def eager_outputs(self, x, aux): + return x + + def execute(self, x): + (x, aux) = self._forward(x) + aux_defined = self.aux_logits + return self.eager_outputs(x, aux) + +class InceptionA(nn.Module): + + def __init__(self, in_channels, pool_features, conv_block=None): + super(InceptionA, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 64, kernel_size=1) + self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1) + self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2) + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1) + self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + branch5x5 = self.branch5x5_1(x) + branch5x5 = self.branch5x5_2(branch5x5) + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + branch_pool = nn.pool(x, 3, "mean", stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] + return outputs + + def execute(self, x): + outputs = self._forward(x) + return jt.concat(outputs, dim=1) + +class InceptionB(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionB, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2) + self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1) + self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1) + self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3(x) + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) + branch_pool = nn.pool(x, 3, "maximum", stride=2) + outputs = [branch3x3, branch3x3dbl, branch_pool] + return outputs + + def execute(self, x): + outputs = self._forward(x) + return jt.concat(outputs, dim=1) + +class InceptionC(nn.Module): + + def __init__(self, in_channels, channels_7x7, conv_block=None): + super(InceptionC, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 192, kernel_size=1) + c7 = channels_7x7 + self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1) + self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + branch7x7 = self.branch7x7_1(x) + branch7x7 = self.branch7x7_2(branch7x7) + branch7x7 = self.branch7x7_3(branch7x7) + branch7x7dbl = self.branch7x7dbl_1(x) + branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) + branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) + branch_pool = nn.pool(x, kernel_size=3, op="mean", stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] + return outputs + + def execute(self, x): + outputs = self._forward(x) + return jt.concat(outputs, dim=1) + +class InceptionD(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionD, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2) + self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1) + self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3)) + self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0)) + self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2) + + def _forward(self, x): + branch3x3 = self.branch3x3_1(x) + branch3x3 = self.branch3x3_2(branch3x3) + branch7x7x3 = self.branch7x7x3_1(x) + branch7x7x3 = self.branch7x7x3_2(branch7x7x3) + branch7x7x3 = self.branch7x7x3_3(branch7x7x3) + branch7x7x3 = self.branch7x7x3_4(branch7x7x3) + branch_pool = nn.pool(x, kernel_size=3, op="maximum", stride=2) + outputs = [branch3x3, branch7x7x3, branch_pool] + return outputs + + def execute(self, x): + outputs = self._forward(x) + return jt.concat(outputs, dim=1) + +class InceptionE(nn.Module): + + def __init__(self, in_channels, conv_block=None): + super(InceptionE, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.branch1x1 = conv_block(in_channels, 320, kernel_size=1) + self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1) + self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1) + self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1) + self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1)) + self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0)) + self.branch_pool = conv_block(in_channels, 192, kernel_size=1) + + def _forward(self, x): + branch1x1 = self.branch1x1(x) + branch3x3 = self.branch3x3_1(x) + branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)] + branch3x3 = jt.concat(branch3x3, dim=1) + branch3x3dbl = self.branch3x3dbl_1(x) + branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) + branch3x3dbl = [self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl)] + branch3x3dbl = jt.concat(branch3x3dbl, dim=1) + branch_pool = nn.pool(x, kernel_size=3, op="mean", stride=1, padding=1) + branch_pool = self.branch_pool(branch_pool) + outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] + return outputs + + def execute(self, x): + outputs = self._forward(x) + return jt.concat(outputs, dim=1) + +class InceptionAux(nn.Module): + + def __init__(self, in_channels, num_classes, conv_block=None): + super(InceptionAux, self).__init__() + if (conv_block is None): + conv_block = BasicConv2d + self.conv0 = conv_block(in_channels, 128, kernel_size=1) + self.conv1 = conv_block(128, 768, kernel_size=5) + self.conv1.stddev = 0.01 + self.fc = nn.Linear(768, num_classes) + self.fc.stddev = 0.001 + + def execute(self, x): + x = nn.pool(x, kernel_size=5, op="mean", stride=3) + x = self.conv0(x) + x = self.conv1(x) + + + x = nn.AdaptiveAvgPool2d(1)(x) + x = jt.reshape(x, (x.shape[0], (- 1))) + x = self.fc(x) + return x + +class BasicConv2d(nn.Module): + + def __init__(self, in_channels, out_channels, **kwargs): + super(BasicConv2d, self).__init__() + self.conv = nn.Conv(in_channels, out_channels, bias=False, **kwargs) + self.bn = nn.BatchNorm(out_channels, eps=0.001) + + def execute(self, x): + x = self.conv(x) + x = self.bn(x) + return nn.relu(x) diff --git a/python/jittor/models/mnasnet.py b/python/jittor/models/mnasnet.py new file mode 100644 index 00000000..f734af9d --- /dev/null +++ b/python/jittor/models/mnasnet.py @@ -0,0 +1,112 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. + +import jittor as jt +from jittor import nn +__all__ = ['MNASNet', 'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3'] +_BN_MOMENTUM = (1 - 0.9997) + +class _InvertedResidual(nn.Module): + + def __init__(self, in_ch, out_ch, kernel_size, stride, expansion_factor, bn_momentum=0.1): + super(_InvertedResidual, self).__init__() + assert (stride in [1, 2]) + assert (kernel_size in [3, 5]) + mid_ch = (in_ch * expansion_factor) + self.apply_residual = ((in_ch == out_ch) and (stride == 1)) + self.layers = nn.Sequential(nn.Conv(in_ch, mid_ch, 1, bias=False), nn.BatchNorm(mid_ch, momentum=bn_momentum), nn.Relu(), nn.Conv(mid_ch, mid_ch, kernel_size, padding=(kernel_size // 2), stride=stride, groups=mid_ch, bias=False), nn.BatchNorm(mid_ch, momentum=bn_momentum), nn.Relu(), nn.Conv(mid_ch, out_ch, 1, bias=False), nn.BatchNorm(out_ch, momentum=bn_momentum)) + + def execute(self, input): + if self.apply_residual: + return (self.layers(input) + input) + else: + return self.layers(input) + +def _stack(in_ch, out_ch, kernel_size, stride, exp_factor, repeats, bn_momentum): + assert (repeats >= 1) + first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum) + remaining = [] + for _ in range(1, repeats): + remaining.append(_InvertedResidual(out_ch, out_ch, kernel_size, 1, exp_factor, bn_momentum=bn_momentum)) + return nn.Sequential(first, *remaining) + +def _round_to_multiple_of(val, divisor, round_up_bias=0.9): + assert (0.0 < round_up_bias < 1.0) + new_val = max(divisor, ((int((val + (divisor / 2))) // divisor) * divisor)) + return (new_val if (new_val >= (round_up_bias * val)) else (new_val + divisor)) + +def _get_depths(alpha): + depths = [24, 40, 80, 96, 192, 320] + return [_round_to_multiple_of((depth * alpha), 8) for depth in depths] + +class MNASNet(nn.Module): + """ MNASNet model architecture. version=2. + + Args: + + * alpha: Depth multiplier. + * num_classes: Number of classes. Default: 1000. + * dropout: Dropout probability of dropout layer. + """ + _version = 2 + + def __init__(self, alpha, num_classes=1000, dropout=0.2): + super(MNASNet, self).__init__() + assert (alpha > 0.0) + self.alpha = alpha + self.num_classes = num_classes + depths = _get_depths(alpha) + layers = [ + nn.Conv(3, 32, 3, padding=1, stride=2, bias=False), + nn.BatchNorm(32, momentum=_BN_MOMENTUM), + nn.Relu(), + nn.Conv(32, 32, 3, padding=1, stride=1, groups=32, bias=False), + nn.BatchNorm(32, momentum=_BN_MOMENTUM), + nn.Relu(), + nn.Conv(32, 16, 1, padding=0, stride=1, bias=False), + nn.BatchNorm(16, momentum=_BN_MOMENTUM), + _stack(16, depths[0], 3, 2, 3, 3, _BN_MOMENTUM), + _stack(depths[0], depths[1], 5, 2, 3, 3, _BN_MOMENTUM), + _stack(depths[1], depths[2], 5, 2, 6, 3, _BN_MOMENTUM), + _stack(depths[2], depths[3], 3, 1, 6, 2, _BN_MOMENTUM), + _stack(depths[3], depths[4], 5, 2, 6, 4, _BN_MOMENTUM), + _stack(depths[4], depths[5], 3, 1, 6, 1, _BN_MOMENTUM), + nn.Conv(depths[5], 1280, 1, padding=0, stride=1, bias=False), + nn.BatchNorm(1280, momentum=_BN_MOMENTUM), + nn.Relu() + ] + self.layers = nn.Sequential(*layers) + self.classifier = nn.Sequential(nn.Dropout(p=dropout), nn.Linear(1280, num_classes)) + + def execute(self, x): + x = self.layers(x) + x = x.mean([2, 3]) + return self.classifier(x) + +def mnasnet0_5(pretrained=False, **kwargs): + model = MNASNet(0.5, **kwargs) + if pretrained: model.load("jittorhub://mnasnet0_5.pkl") + return model + +def mnasnet0_75(pretrained=False, **kwargs): + model = MNASNet(0.75, **kwargs) + if pretrained: model.load("jittorhub://mnasnet0_75.pkl") + return model + +def mnasnet1_0(pretrained=False, **kwargs): + model = MNASNet(1.0, **kwargs) + if pretrained: model.load("jittorhub://mnasnet1_0.pkl") + return model + +def mnasnet1_3(pretrained=False, **kwargs): + model = MNASNet(1.3, **kwargs) + if pretrained: model.load("jittorhub://mnasnet1_3.pkl") + return model diff --git a/python/jittor/models/mobilenet.py b/python/jittor/models/mobilenet.py new file mode 100644 index 00000000..72b6dad8 --- /dev/null +++ b/python/jittor/models/mobilenet.py @@ -0,0 +1,101 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. + +import jittor as jt +from jittor import init +from jittor import nn +__all__ = ['MobileNetV2', 'mobilenet_v2'] + +def _make_divisible(v, divisor, min_value=None): + if (min_value is None): + min_value = divisor + new_v = max(min_value, ((int((v + (divisor / 2))) // divisor) * divisor)) + if (new_v < (0.9 * v)): + new_v += divisor + return new_v + +class ConvBNReLU(nn.Sequential): + + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + padding = ((kernel_size - 1) // 2) + super(ConvBNReLU, self).__init__(nn.Conv(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), nn.BatchNorm(out_planes), nn.ReLU6()) + +class InvertedResidual(nn.Module): + + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert (stride in [1, 2]) + hidden_dim = int(round((inp * expand_ratio))) + self.use_res_connect = ((self.stride == 1) and (inp == oup)) + layers = [] + if (expand_ratio != 1): + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), nn.Conv(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm(oup)]) + self.conv = nn.Sequential(*layers) + + def execute(self, x): + if self.use_res_connect: + return (x + self.conv(x)) + else: + return self.conv(x) + +class MobileNetV2(nn.Module): + """ MobileNetV2 model architecture. + + Args: + + * num_classes: Number of classes. Default: 1000. + * width_mult: Width multiplier - adjusts number of channels in each layer by this amount. Default: 1.0. + * init_weights: Defualt: True. + * inverted_residual_setting: Network structure + * round_nearest: Round the number of channels in each layer to be a multiple of this number. Set to 1 to turn off rounding. Default: 8. + * block: Module specifying inverted residual building block for mobilenet. If None, use InvertedResidual instead. Default: None. + """ + + def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None): + super(MobileNetV2, self).__init__() + if (block is None): + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + if (inverted_residual_setting is None): + inverted_residual_setting = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2], [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1]] + if ((len(inverted_residual_setting) == 0) or (len(inverted_residual_setting[0]) != 4)): + raise ValueError('inverted_residual_setting should be non-empty or a 4-element list, got {}'.format(inverted_residual_setting)) + input_channel = _make_divisible((input_channel * width_mult), round_nearest) + self.last_channel = _make_divisible((last_channel * max(1.0, width_mult)), round_nearest) + features = [ConvBNReLU(3, input_channel, stride=2)] + for (t, c, n, s) in inverted_residual_setting: + output_channel = _make_divisible((c * width_mult), round_nearest) + for i in range(n): + stride = (s if (i == 0) else 1) + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) + self.features = nn.Sequential(*features) + self.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(self.last_channel, num_classes)) + + def _forward_impl(self, x): + x = self.features(x) + x = nn.AdaptiveAvgPool2d(1)(x) + x = jt.reshape(x, (x.shape[0], -1)) + x = self.classifier(x) + return x + + def execute(self, x): + return self._forward_impl(x) + +def mobilenet_v2(pretrained=False): + model = MobileNetV2() + if pretrained: model.load("jittorhub://mobilenet_v2.pkl") + return model + diff --git a/python/jittor/models/res2net.py b/python/jittor/models/res2net.py new file mode 100644 index 00000000..859e9c05 --- /dev/null +++ b/python/jittor/models/res2net.py @@ -0,0 +1,231 @@ +import jittor as jt +from jittor import nn +from jittor import Module +from jittor import init +from jittor.contrib import concat, argmax_pool +import math + + +model_urls = { + 'res2net50_14w_8s': 'jittorhub://res2net50_14w_8s.pkl', + 'res2net50_26w_4s': 'jittorhub://res2net50_26w_4s.pkl', + 'res2net50_26w_6s': 'jittorhub://res2net50_26w_6s.pkl', + 'res2net50_26w_8s': 'jittorhub://res2net50_26w_8s.pkl', + 'res2net50_48w_2s': 'jittorhub://res2net50_48w_2s.pkl', + 'res2net101_26w_4s': 'jittorhub://res2net101_26w_4s.pkl', +} + + +class Bottle2neck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale = 4, stype='normal'): + """ Constructor + Args: + inplanes: input channel dimensionality + planes: output channel dimensionality + stride: conv stride. Replaces pooling layer. + downsample: None when stride = 1 + baseWidth: basic width of conv3x3 + scale: number of scale. + type: 'normal': normal set. 'stage': first block of a new stage. + """ + super(Bottle2neck, self).__init__() + + width = int(math.floor(planes * (baseWidth/64.0))) + self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(width*scale) + + if scale == 1: + self.nums = 1 + else: + self.nums = scale -1 + if stype == 'stage': + self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) + convs = [] + bns = [] + for i in range(self.nums): + convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, padding=1, bias=False)) + bns.append(nn.BatchNorm2d(width)) + self.convs = nn.ModuleList(convs) + self.bns = nn.ModuleList(bns) + + self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + + self.relu = nn.ReLU() + self.downsample = downsample + self.stype = stype + self.scale = scale + self.width = width + + def execute(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + spx = jt.split(out, self.width, 1) + for i in range(self.nums): + if i==0 or self.stype=='stage': + sp = spx[i] + else: + sp = sp + spx[i] + sp = self.convs[i](sp) + sp = self.relu(self.bns[i](sp)) + if i==0: + out = sp + else: + out = jt.concat((out, sp), 1) + if self.scale != 1 and self.stype=='normal': + out = jt.concat((out, spx[self.nums]),1) + elif self.scale != 1 and self.stype=='stage': + out = jt.concat((out, self.pool(spx[self.nums])),1) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + +class Res2Net(nn.Module): + + def __init__(self, block, layers, baseWidth = 26, scale = 4, num_classes=1000): + self.inplanes = 64 + super(Res2Net, self).__init__() + self.baseWidth = baseWidth + self.scale = scale + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample=downsample, + stype='stage', baseWidth = self.baseWidth, scale=self.scale)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, baseWidth = self.baseWidth, scale=self.scale)) + + return nn.Sequential(*layers) + + def execute(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + + +def res2net50(pretrained=False, **kwargs): + """Constructs a Res2Net-50 model. + Res2Net-50 refers to the Res2Net-50_26w_4s. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) + if pretrained: + model.load(model_urls['res2net50_26w_4s']) + return model + +def res2net50_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-50_26w_4s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 4, **kwargs) + if pretrained: + model.load(model_urls['res2net50_26w_4s']) + return model + +def res2net101_26w_4s(pretrained=False, **kwargs): + """Constructs a Res2Net-50_26w_4s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth = 26, scale = 4, **kwargs) + if pretrained: + model.load(model_urls['res2net101_26w_4s']) + return model + +res2net101 = res2net101_26w_4s + +def res2net50_26w_6s(pretrained=False, **kwargs): + """Constructs a Res2Net-50_26w_4s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 6, **kwargs) + if pretrained: + model.load(model_urls['res2net50_26w_6s']) + return model + +def res2net50_26w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50_26w_4s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 26, scale = 8, **kwargs) + if pretrained: + model.load(model_urls['res2net50_26w_8s']) + return model + +def res2net50_48w_2s(pretrained=False, **kwargs): + """Constructs a Res2Net-50_48w_2s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 48, scale = 2, **kwargs) + if pretrained: + model.load(model_urls['res2net50_48w_2s']) + return model + +def res2net50_14w_8s(pretrained=False, **kwargs): + """Constructs a Res2Net-50_14w_8s model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth = 14, scale = 8, **kwargs) + if pretrained: + model.load(model_urls['res2net50_14w_8s']) + return model + diff --git a/python/jittor/models/resnet.py b/python/jittor/models/resnet.py new file mode 100644 index 00000000..0a093972 --- /dev/null +++ b/python/jittor/models/resnet.py @@ -0,0 +1,239 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. +import jittor as jt +from jittor import nn + +__all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet26', 'Resnet38', 'Resnet50', 'Resnet101', 'Resnet152', 'Resnext50_32x4d', 'Resnext101_32x8d', 'Wide_resnet50_2', 'Wide_resnet101_2', + 'resnet18', 'resnet34', 'resnet26', 'resnet38', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation) + jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out") + return conv + +def conv1x1(in_planes, out_planes, stride=1): + conv=nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out") + return conv + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if (norm_layer is None): + norm_layer = nn.BatchNorm + if ((groups != 1) or (base_width != 64)): + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if (dilation > 1): + raise NotImplementedError('Dilation > 1 not supported in BasicBlock') + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.Relu() + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def execute(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if (self.downsample is not None): + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if (norm_layer is None): + norm_layer = nn.BatchNorm + width = (int((planes * (base_width / 64.0))) * groups) + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, (planes * self.expansion)) + self.bn3 = norm_layer((planes * self.expansion)) + self.relu = nn.Relu() + self.downsample = downsample + self.stride = stride + + def execute(self, x): + identity = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + if (self.downsample is not None): + identity = self.downsample(x) + out += identity + out = self.relu(out) + return out + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): + super(ResNet, self).__init__() + if (norm_layer is None): + norm_layer = nn.BatchNorm + self._norm_layer = norm_layer + self.inplanes = 64 + self.dilation = 1 + if (replace_stride_with_dilation is None): + replace_stride_with_dilation = [False, False, False] + if (len(replace_stride_with_dilation) != 3): + raise ValueError('replace_stride_with_dilation should be None or a 3-element tuple, got {}'.format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out") + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.Relu() + # self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum') + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + # self.fc = nn.Linear((512 * block.expansion), num_classes) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if ((stride != 1) or (self.inplanes != (planes * block.expansion))): + downsample = nn.Sequential(conv1x1(self.inplanes, (planes * block.expansion), stride), norm_layer((planes * block.expansion))) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) + self.inplanes = (planes * block.expansion) + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer)) + return nn.Sequential(*layers) + + def _forward_impl(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + # x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avgpool(x).float_auto() + x = jt.reshape(x, (x.shape[0], -1)) + # x = self.fc(x) + return x + + def execute(self, x): + return self._forward_impl(x) + +def _resnet(block, layers, **kwargs): + model = ResNet(block, layers, **kwargs) + return model + +def Resnet18(pretrained=False, **kwargs): + model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: model.load("jittorhub://resnet18.pkl") + return model +resnet18 = Resnet18 + +def Resnet34(pretrained=False, **kwargs): + model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: model.load("jittorhub://resnet34.pkl") + return model +resnet34 = Resnet34 + +def Resnet50(pretrained=False, **kwargs): + model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: model.load("jittorhub://resnet50.pkl") + return model + +resnet50 = Resnet50 + +def Resnet38(pretrained=False, **kwargs): + model = _resnet(Bottleneck, [2, 3, 5, 2], **kwargs) + if pretrained: model.load("jittorhub://resnet38.pkl") + return model +resnet38 = Resnet38 + +def Resnet26(pretrained=False, **kwargs): + model = _resnet(Bottleneck, [1, 2, 4, 1], **kwargs) + if pretrained: model.load("jittorhub://resnet26.pkl") + return model +resnet26 = Resnet26 + +def Resnet101(pretrained=False, **kwargs): + """ + ResNet-101 model architecture. + + Example:: + + model = jittor.models.Resnet101() + x = jittor.random([10,3,224,224]) + y = model(x) # [10, 1000] + + """ + model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: model.load("jittorhub://resnet101.pkl") + return model +resnet101 = Resnet101 + +def Resnet152(pretrained=False, **kwargs): + model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: model.load("jittorhub://resnet152.pkl") + return model +resnet152 = Resnet152 + +def Resnext50_32x4d(pretrained=False, **kwargs): + kwargs['groups'] = 32 + kwargs['width_per_group'] = 4 + model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: model.load("jittorhub://resnext50_32x4d.pkl") + return model +resnext50_32x4d = Resnext50_32x4d + +def Resnext101_32x8d(pretrained=False, **kwargs): + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: model.load("jittorhub://resnext101_32x8d.pkl") + return model +resnext101_32x8d = Resnext101_32x8d + +def Wide_resnet50_2(pretrained=False, **kwargs): + kwargs['width_per_group'] = (64 * 2) + model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: model.load("jittorhub://wide_resnet50_2.pkl") + return model +wide_resnet50_2 = Wide_resnet50_2 + +def Wide_resnet101_2(pretrained=False, **kwargs): + kwargs['width_per_group'] = (64 * 2) + model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: model.load("jittorhub://wide_resnet101_2.pkl") + return model +wide_resnet101_2 = Wide_resnet101_2 \ No newline at end of file diff --git a/python/jittor/models/shufflenetv2.py b/python/jittor/models/shufflenetv2.py new file mode 100644 index 00000000..21856b4d --- /dev/null +++ b/python/jittor/models/shufflenetv2.py @@ -0,0 +1,115 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. +import jittor as jt +from jittor import nn + +__all__ = ['ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'] + +def channel_shuffle(x, groups): + (batchsize, num_channels, height, width) = x.data.shape + channels_per_group = (num_channels // groups) + x = jt.reshape(x, [batchsize, groups, channels_per_group, height, width]) + x = jt.transpose(x, (0,2,1,3,4)) + x = jt.reshape(x, [batchsize, (- 1), height, width]) + return x + +class InvertedResidual(nn.Module): + + def __init__(self, inp, oup, stride): + super(InvertedResidual, self).__init__() + if (not (1 <= stride <= 3)): + raise ValueError('illegal stride value') + self.stride = stride + branch_features = (oup // 2) + assert ((self.stride != 1) or (inp == (branch_features << 1))) + if (self.stride > 1): + self.branch1 = nn.Sequential(self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), nn.BatchNorm(inp), nn.Conv(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm(branch_features), nn.Relu()) + else: + self.branch1 = nn.Sequential() + self.branch2 = nn.Sequential(nn.Conv((inp if (self.stride > 1) else branch_features), branch_features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm(branch_features), nn.Relu(), self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), nn.BatchNorm(branch_features), nn.Conv(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm(branch_features), nn.Relu()) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def execute(self, x): + if (self.stride == 1): + x1 = x[:,0:x.shape[1]//2] + x2 = x[:,x.shape[1]//2:x.shape[1]] + out = jt.concat([x1, self.branch2(x2)], dim=1) + else: + out = jt.concat([self.branch1(x), self.branch2(x)], dim=1) + out = channel_shuffle(out, 2) + return out + +class ShuffleNetV2(nn.Module): + + def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual): + super(ShuffleNetV2, self).__init__() + if (len(stages_repeats) != 3): + raise ValueError('expected stages_repeats as list of 3 positive ints') + if (len(stages_out_channels) != 5): + raise ValueError('expected stages_out_channels as list of 5 positive ints') + self._stage_out_channels = stages_out_channels + input_channels = 3 + output_channels = self._stage_out_channels[0] + self.conv1 = nn.Sequential(nn.Conv(input_channels, output_channels, 3, 2, 1, bias=False), nn.BatchNorm(output_channels), nn.Relu()) + input_channels = output_channels + self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum') + stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] + for (name, repeats, output_channels) in zip(stage_names, stages_repeats, self._stage_out_channels[1:]): + seq = [inverted_residual(input_channels, output_channels, 2)] + for i in range((repeats - 1)): + seq.append(inverted_residual(output_channels, output_channels, 1)) + setattr(self, name, nn.Sequential(*seq)) + input_channels = output_channels + output_channels = self._stage_out_channels[(- 1)] + self.conv5 = nn.Sequential(nn.Conv(input_channels, output_channels, 1, 1, 0, bias=False), nn.BatchNorm(output_channels), nn.Relu()) + self.fc = nn.Linear(output_channels, num_classes) + + def _forward_impl(self, x): + x = self.conv1(x) + x = self.maxpool(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.stage4(x) + x = self.conv5(x) + x = x.mean([2, 3]) + x = self.fc(x) + return x + + def execute(self, x): + return self._forward_impl(x) + +def _shufflenetv2(arch, *args): + model = ShuffleNetV2(*args) + return model + +def shufflenet_v2_x0_5(pretrained=False): + model = _shufflenetv2('shufflenetv2_x0.5', [4, 8, 4], [24, 48, 96, 192, 1024]) + if pretrained: model.load("jittorhub://shufflenet_v2_x0_5.pkl") + return model + +def shufflenet_v2_x1_0(pretrained=False): + model = _shufflenetv2('shufflenetv2_x1.0', [4, 8, 4], [24, 116, 232, 464, 1024]) + if pretrained: model.load("jittorhub://shufflenet_v2_x1_0.pkl") + return model + +def shufflenet_v2_x1_5(pretrained=False): + model = _shufflenetv2('shufflenetv2_x1.5', [4, 8, 4], [24, 176, 352, 704, 1024]) + if pretrained: model.load("jittorhub://shufflenet_v2_x1_5.pkl") + return model + +def shufflenet_v2_x2_0(pretrained=False): + model = _shufflenetv2('shufflenetv2_x2.0', [4, 8, 4], [24, 244, 488, 976, 2048]) + if pretrained: model.load("jittorhub://shufflenet_v2_x2_0.pkl") + return model diff --git a/python/jittor/models/squeezenet.py b/python/jittor/models/squeezenet.py new file mode 100644 index 00000000..212e837b --- /dev/null +++ b/python/jittor/models/squeezenet.py @@ -0,0 +1,95 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. +import jittor as jt +from jittor import nn +__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1'] + +class Fire(nn.Module): + + def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes): + super(Fire, self).__init__() + self.inplanes = inplanes + self.squeeze = nn.Conv(inplanes, squeeze_planes, kernel_size=1) + self.squeeze_activation = nn.Relu() + self.expand1x1 = nn.Conv(squeeze_planes, expand1x1_planes, kernel_size=1) + self.expand1x1_activation = nn.Relu() + self.expand3x3 = nn.Conv(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1) + self.expand3x3_activation = nn.Relu() + + def execute(self, x): + x = self.squeeze_activation(self.squeeze(x)) + return jt.concat([self.expand1x1_activation(self.expand1x1(x)), self.expand3x3_activation(self.expand3x3(x))], dim=1) + +class SqueezeNet(nn.Module): + + def __init__(self, version='1_0', num_classes=1000): + super(SqueezeNet, self).__init__() + self.num_classes = num_classes + if (version == '1_0'): + self.features = nn.Sequential( + nn.Conv(3, 96, kernel_size=7, stride=2), + nn.Relu(), + nn.Pool(kernel_size=3, stride=2, ceil_mode=True, op='maximum'), + Fire(96, 16, 64, 64), + Fire(128, 16, 64, 64), + Fire(128, 32, 128, 128), + nn.Pool(kernel_size=3, stride=2, ceil_mode=True, op='maximum'), + Fire(256, 32, 128, 128), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + nn.Pool(kernel_size=3, stride=2, ceil_mode=True, op='maximum'), + Fire(512, 64, 256, 256) + ) + elif (version == '1_1'): + self.features = nn.Sequential( + nn.Conv(3, 64, kernel_size=3, stride=2), + nn.Relu(), + nn.Pool(kernel_size=3, stride=2, ceil_mode=True, op='maximum'), + Fire(64, 16, 64, 64), + Fire(128, 16, 64, 64), + nn.Pool(kernel_size=3, stride=2, ceil_mode=True, op='maximum'), + Fire(128, 32, 128, 128), + Fire(256, 32, 128, 128), + nn.Pool(kernel_size=3, stride=2, ceil_mode=True, op='maximum'), + Fire(256, 48, 192, 192), + Fire(384, 48, 192, 192), + Fire(384, 64, 256, 256), + Fire(512, 64, 256, 256) + ) + else: + raise ValueError('Unsupported SqueezeNet version {version}:1_0 or 1_1 expected'.format(version=version)) + final_conv = nn.Conv(512, self.num_classes, kernel_size=1) + self.classifier = nn.Sequential( + nn.Dropout(p=0.5), + final_conv, + nn.Relu(), + nn.AdaptiveAvgPool2d((1, 1)) + ) + + def execute(self, x): + x = self.features(x) + x = self.classifier(x) + return jt.reshape(x, (x.shape[0], (- 1))) + +def _squeezenet(version, **kwargs): + model = SqueezeNet(version, **kwargs) + return model + +def squeezenet1_0(pretrained=False, **kwargs): + model = _squeezenet('1_0', **kwargs) + if pretrained: model.load("jittorhub://squeezenet1_0.pkl") + return model + +def squeezenet1_1(pretrained=False, **kwargs): + model = _squeezenet('1_1', **kwargs) + if pretrained: model.load("jittorhub://squeezenet1_1.pkl") + return model diff --git a/python/jittor/models/vgg.py b/python/jittor/models/vgg.py new file mode 100644 index 00000000..b106567e --- /dev/null +++ b/python/jittor/models/vgg.py @@ -0,0 +1,116 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +# This model is generated by pytorch converter. +import jittor as jt +from jittor import nn + +__all__ = [ + 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', + 'vgg19_bn', 'vgg19', +] + +class VGG(nn.Module): + + def __init__(self, features, num_classes=1000, init_weights=True): + super(VGG, self).__init__() + self.features = features + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, 4096), + nn.ReLU(), + nn.Dropout(), + nn.Linear(4096, num_classes), + ) + + def execute(self, x): + x = self.features(x) + x = self.avgpool(x) + x = jt.reshape(x, [x.shape[0],-1]) + x = self.classifier(x) + return x + +def make_layers(cfg, batch_norm=False): + layers = [] + in_channels = 3 + for v in cfg: + if v == 'M': + layers += [nn.Pool(kernel_size=2, stride=2, op="maximum")] + else: + conv2d = nn.Conv(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm(v), nn.ReLU()] + else: + layers += [conv2d, nn.ReLU()] + in_channels = v + return nn.Sequential(*layers) + + +cfgs = { + 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], + 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], + 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], +} + + +def _vgg(arch, cfg, batch_norm, **kwargs): + model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) + return model + + +def vgg11(pretrained=False, **kwargs): + model = _vgg('vgg11', 'A', False, **kwargs) + if pretrained: model.load("jittorhub://vgg11.pkl") + return model + + +def vgg11_bn(pretrained=False, **kwargs): + model = _vgg('vgg11_bn', 'A', True, **kwargs) + if pretrained: model.load("jittorhub://vgg11_bn.pkl") + return model + + +def vgg13(pretrained=False, **kwargs): + model = _vgg('vgg13', 'B', False, **kwargs) + if pretrained: model.load("jittorhub://vgg13.pkl") + return model + + +def vgg13_bn(pretrained=False, **kwargs): + model = _vgg('vgg13_bn', 'B', True, **kwargs) + if pretrained: model.load("jittorhub://vgg13_bn.pkl") + return model + + +def vgg16(pretrained=False, **kwargs): + model = _vgg('vgg16', 'D', False, **kwargs) + if pretrained: model.load("jittorhub://vgg16.pkl") + return model + + +def vgg16_bn(pretrained=False, **kwargs): + model = _vgg('vgg16_bn', 'D', True, **kwargs) + if pretrained: model.load("jittorhub://vgg16_bn.pkl") + return model + + +def vgg19(pretrained=False, **kwargs): + model = _vgg('vgg19', 'E', False, **kwargs) + if pretrained: model.load("jittorhub://vgg19.pkl") + return model + + +def vgg19_bn(pretrained=False, **kwargs): + model = _vgg('vgg19_bn', 'E', True, **kwargs) + if pretrained: model.load("jittorhub://vgg19_bn.pkl") + return model \ No newline at end of file diff --git a/python/jittor/nn.py b/python/jittor/nn.py new file mode 100644 index 00000000..623daeeb --- /dev/null +++ b/python/jittor/nn.py @@ -0,0 +1,3379 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Guoye Yang <498731903@qq.com> +# Wenyang Zhou <576825820@qq.com> +# Meng-Hao Guo +# Dun Liang . +# Zheng-Ning Liu +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from abc import abstractmethod +import jittor as jt +from jittor import flatten, init, Module +import numpy as np +import collections +import math +from collections import OrderedDict +from jittor.pool import * +from jittor.optim import * +from jittor.misc import _pair, _triple +from jittor_utils import LOG +from functools import partial + + +def matmul_transpose(a, b): + ''' + returns a * b^T + ''' + assert a.shape[-1] == b.shape[-1], (a.shape, b.shape) + if len(a.shape) != 2: + aa = a.reshape((-1, a.shape[-1])) + cc = matmul_transpose(aa, b) + return cc.reshape(a.shape[:-1]+(-1,)) + assert len(a.shape) == 2 and len(b.shape) == 2 + + shape = list(a.shape)[:-1] + list(b.shape) + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 36): + a = a.broadcast(shape, [len(shape)-2]) + b = b.broadcast(shape) + return (a*b).sum(len(shape)-1) + + +def bmm_transpose(a, b): + ''' + returns a * b^T + ''' + if jt.flags.use_cuda and jt.compile_extern.cublas_ops: + return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 1) + t = list(range(b.ndim)) + t[-1], t[-2] = t[-2], t[-1] + return bmm(a, b.transpose(t)) + + +def bmm(a, b): + ''' batch matrix multiply, +shape of input a is [batch, n, m], +shape of input b is [batch, m, k], +return shape is [batch, n, k] + +Example:: + + import jittor as jt + from jittor import nn + + batch, n, m, k = 100, 5, 6, 7 + + a = jt.random((batch, n, m)) + b = jt.random((batch, m, k)) + c = nn.bmm(a, b) + ''' + assert len(a.shape) > 2 and len(b.shape) > 2 + return matmul(a, b) + +def baddbmm(input, batch1, batch2, beta=1, alpha=1): + res = bmm(batch1, batch2) + if alpha != 1: res = res * alpha + if beta == 0: return res + return beta * input + res + +def matmul(a, b): + ''' matrix multiply, + +Example:: + + a = jt.random([3]) + b = jt.random([3]) + c = jt.matmul(a, b) + assert c.shape == [1] + + a = jt.random([3, 4]) + b = jt.random([4]) + c = jt.matmul(a, b) + assert c.shape == [3] + + a = jt.random([10, 3, 4]) + b = jt.random([4]) + c = jt.matmul(a, b) + assert c.shape == [10, 3] + + a = jt.random([10, 3, 4]) + b = jt.random([4, 5]) + c = jt.matmul(a, b) + assert c.shape == [10, 3, 5] + + a = jt.random([10, 3, 4]) + b = jt.random([10, 4, 5]) + c = jt.matmul(a, b) + assert c.shape == [10, 3, 5] + + a = jt.random([8, 1, 3, 4]) + b = jt.random([10, 4, 5]) + c = jt.matmul(a, b) + assert c.shape == [8, 10, 3, 5] + ''' + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 36): + len_a = len(a.shape) + len_b = len(b.shape) + if len_b == 1: + # a: [n, m], b:[m], c:[n] + return (a*b).sum(-1) + if len_a == 1: + # a: [n], b:[n,k], c:[k] + return (a.broadcast(b, [-1]) * b).sum(0) + if len_a>=3 and len_a==len_b: + # bmm + # a: [..., n, m], b: [..., m, k], c:[..., n, k] + if jt.flags.use_cuda and jt.compile_extern.cublas_ops: + return jt.compile_extern.cublas_ops.cublas_batched_matmul(a, b, 0, 0) + shape = [] + len_c = max(len_a, len_b) + (n, m), (m_, k) = a.shape[-2:], b.shape[-2:] + assert m == m_, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}" + # a: [..., n, m] + # b: [..., m, k] + # cc:[..., n, m, k] + # --> + # 012 + if len_b == 2 and len_a>2: + # TODO:ugly implementation for tuner + aa = a.reshape((-1, m)) + cc = matmul(aa, b) + # print(a.shape, b.shape, cc.shape) + return cc.reshape(a.shape[:-1] + [k]) + for i in range(len_c-2): + ai = len_a-(len_c-i) + bi = len_b-(len_c-i) + an = a.shape[ai] if ai>=0 else 1 + bn = b.shape[bi] if bi>=0 else 1 + if an!=1 and bn!=1: + assert an == bn, f"dimension not match, a.shape:{a.shape}, b.shape:{b.shape}" + cn = max(an, bn) + shape.append(cn) + shape.extend([n, m, k]) + a = a.broadcast(shape, [-1]) + b = b.broadcast(shape, [-3]) + return (a*b).sum(-2) +jt.Var.matmul = jt.Var.__matmul__ = matmul +jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b)) + +def get_init_var_rand(shape, dtype): + return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32)) + +def relu(x): + r''' Applies the element-wise function: + + .. math:: + \text{ReLU6}(x) = \max(0,x) + + :param x: the input var + :type x: jt.Var + + Example: + >>> a = jt.randn(3) + >>> a + jt.Var([-0.38380373 1.1338731 6.128115 ], dtype=float32) + >>> nn.relu(a) + jt.Var([0. 1.1338731 6.128115 ], dtype=float32) + ''' + cond = x>0.0 + return jt.ternary_out_hint(cond, x, 0.0) + + +def leaky_relu(x, scale=0.01): + r''' Applies the element-wise function: + + .. math:: + \text{LeakyRELU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + \text{scale} \times x, & \text{ otherwise } + \end{cases} + + :param x: the input var + :type x: jt.Var + + :param scale: the :math:`\scale` value for the leaky relu formulation. Default: 0.01 + :param scale: float, optional + + Example: + >>> a = jt.randn(3) + >>> a + jt.Var([-0.38380373 1.1338731 6.128115 ], dtype=float32) + >>> nn.leaky_relu(a) + jt.Var([-3.8380371e-03 1.1338731e+00 6.1281152e+00], dtype=float32) + ''' + return jt.ternary(x>0, x, x*scale) + +def relu6(x): + r''' Applies the element-wise function: + + .. math:: + \text{ReLU6}(x) = \min(\max(0,x), 6) + + :param x: the input var + :type x: jt.Var + + Example: + >>> a = jt.randn(3) + >>> a + jt.Var([-0.38380373 1.1338731 6.128115 ], dtype=float32) + >>> nn.relu6(a) + jt.Var([0. 1.1338731 6. ], dtype=float32) + ''' + return jt.minimum(jt.maximum(x, 0.0), 6.0) + +def elu(x: jt.Var, alpha: float = 1.0) -> jt.Var: + r''' Applies the element-wise function: + + .. math:: + \text{ELU}(x) = \begin{cases} + x, & \text{ if } x > 0\\ + \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 + \end{cases} + + :param x: the input var + :type x: jt.Var + + :param alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 + :param alpha: float, optional + + Example: + >>> a = jt.randn(3) + >>> a + jt.Var([-0.38380373 -1.1338731 2.128115 ], dtype=float32) + >>> nn.elu(a) + jt.Var([-0.31873488 -0.6782155 2.128115 ], dtype=float32) + ''' + return jt.ternary(x>0,x,alpha*(x.exp()-1)) + +def sign(x: jt.Var) -> jt.Var: + ''' returns the signs of elements of x + + :param x: the input Var + :type x: jt.Var + + Example: + >>> a = jt.float32([0.99, 0, -0.99]) + >>> nn.sign(a) + jt.Var([ 1. 0. -1.], dtype=float32) + ''' + one = jt.ones(x.shape) + x = jt.ternary(x>0, one, x) + return jt.ternary(x<0, -one, x) + +def gelu(x): + r''' Applies the element-wise function: + + .. math:: + \text{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution. + + :param x: the input var + :type x: jt.Var + + Example: + >>> a = jt.randn(3) + >>> a + jt.Var([-0.38380373 -1.1338731 2.128115 ], dtype=float32) + >>> nn.gelu(a) + jt.Var([-0.134547 0.9882567 6.128115 ], dtype=float32) + ''' + _sqrt2 = 1.4142135623730951 + erf = jt.erf(x/_sqrt2)+1 + r = erf*x*.5 + return r + +def silu(x): + r''' Applies the element-wise function: + + .. math:: + \text{SILU}(x) = x * Sigmoid(x) + + :param x: the input var + :type x: jt.Var + + Example: + >>> a = jt.randn(3) + >>> a + jt.Var([-0.38380373 -1.1338731 2.128115 ], dtype=float32) + >>> nn.silu(a) + jt.Var([-0.15552104 -0.27603802 1.9016962 ], dtype=float32) + ''' + return x * x.sigmoid() + +class ELU(Module): + r''' Applies the element-wise function: + + .. math:: + \text{ELU}(x) = \begin{cases} + x, & \text{ if } x > 0\\ + \alpha * (\exp(x) - 1), & \text{ if } x \leq 0 + \end{cases} + + :param x: the input var + :type x: jt.Var + + :param alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0 + :param alpha: float, optional + + Example: + >>> a = jt.randn(3) + >>> a + jt.Var([-0.38380373 -1.1338731 2.128115 ], dtype=float32) + >>> nn.elu(a) + jt.Var([-0.31873488 -0.6782155 2.128115 ], dtype=float32) + ''' + def __init__(self,alpha=1.0): + self.alpha=alpha + + def execute(self,x): + return elu(x,self.alpha) + +class PReLU(Module): + r''' Applies the element-wise function: + + .. math:: + \text{PReLU}(x) = + \begin{cases} + x, & \text{ if } x \geq 0 \\ + ax, & \text{ otherwise } + \end{cases} + + :param x: the input var + :type x: jt.Var + + :param num_parameters: number of :math:`a` to learn, can be either 1 or the number of channels at input. Default: 1 + :type num_parameters: int, optional + + :param init: the initial value of :math:`a`. Default: 0.25 + :param init: float, optional + + Example: + >>> a = jt.randn(3) + >>> prelu = nn.PReLU() + >>> prelu(a) + jt.Var([-0.09595093 1.1338731 6.128115 ], dtype=float32) + ''' + + def __init__(self, num_parameters=1, init_=0.25): + self.num_parameters = num_parameters + self.weight = init.constant((num_parameters,), "float32", init_) + + def execute(self, x): + if self.num_parameters != 1: + assert self.num_parameters == x.size(1), f"num_parameters does not match input channels in PReLU" + return jt.maximum(0, x) + self.weight.broadcast(x, [0,2,3]) * jt.minimum(0, x) + else: + return jt.maximum(0, x) + self.weight * jt.minimum(0, x) + +#TODO dims is 4 will cause slowly execution +def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction='mean'): + target_shape = target.shape + if len(output.shape) == 4: + c_dim = output.shape[1] + output = output.transpose((0, 2, 3, 1)) + output = output.reshape((-1, c_dim)) + + target = target.reshape((-1, )) + target_weight = ((target >= 0) & (target < output.shape[1])).float32() + if weight is not None: + target_weight = weight[target] + if ignore_index is not None: + target_weight = jt.ternary( + target==ignore_index, + jt.array(0).broadcast(target_weight).type_as(target_weight), + target_weight + ) + + target = target.broadcast(output, [1]) + target = target.index(1) == target + + output = output - output.max([1], keepdims=True) + logsum = output.exp().sum(1).log() + loss = (logsum - (output*target).sum(1)) * target_weight + if reduction == 'sum': + return loss.sum() + elif reduction == 'mean': + return loss.mean() / target_weight.mean() + else: + return loss.reshape(target_shape) + +def mse_loss(output, target, reduction="mean"): + return (output-target).sqr().reduce(reduction) + +def bce_loss(output, target, weight=None, size_average=True): + loss = - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))) + + if weight is not None: + loss *= weight + + if size_average: + return loss.mean() + else: + return loss.sum() + +def l1_loss(output, target): + return (output-target).abs().mean() + + +def smooth_l1_loss(y_true, y_pred,reduction="mean"): + """Implements Smooth-L1 loss. + y_true and y_pred are typically: [N, 4], but could be any shape. + + Args: + y_true - ground truth + y_pred - predictions + reduction - the mode of cal loss which must be in ['mean','sum','none'] + """ + diff = jt.abs(y_true - y_pred) + less_than_one = (diff<1.0).float32() + loss = (less_than_one * 0.5 * diff.sqr()) + (1 - less_than_one) * (diff - 0.5) + if reduction=="mean": + return loss.mean() + elif reduction=="sum": + return loss.sum() + elif reduction=="none": + return loss + else: + raise ValueError(f'not support {reduction}') + +def nll_loss(output,target,weight=None,ignore_index=-100,reduction='mean'): + assert output.ndim<=2 and output.ndim>0 and target.ndim==1 + n_classes = output.shape[-1] + assert weight is None or weight.numel()==n_classes + assert ignore_index<0 or ignore_index0: + weight[ignore_index]=0 + if output.ndim==2: + index = jt.index((output.shape[0],),dim=0) + loss = -output[index,target]*weight[target] + else: + loss = -output[target[0]]*weight[target[0]] + if reduction=="mean": + total_weight = weight[target].sum() if output.ndim==2 else weight[target[0]].sum() + return loss.sum()/total_weight + elif reduction=="sum": + return loss.sum() + elif reduction=="none": + return loss + else: + raise ValueError(f'not support {reduction}') + +class CrossEntropyLoss(Module): + def __init__(self, weight=None, ignore_index=None): + self.weight = weight + self.ignore_index = ignore_index + + def execute(self, output, target): + return cross_entropy_loss(output, target, self.weight, self.ignore_index) + +class MSELoss(Module): + def __init__(self, reduction='mean'): + self.reduction = reduction + def execute(self, output, target): + return mse_loss(output, target, self.reduction) + +class BCELoss(Module): + def __init__(self, weight=None, size_average=True): + self.weight = weight + self.size_average = size_average + def execute(self, output, target): + return bce_loss(output, target, self.weight, self.size_average) + +class L1Loss(Module): + def __init__(self): + pass + def execute(self, output, target): + return l1_loss(output, target) + +def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True): + if not (target.shape == output.shape): + raise ValueError(f"Target size ({target.shape}) must be the same as output size ({output.shape})") + + max_val = jt.clamp(-output,min_v=0) + if pos_weight is not None: + log_weight = (pos_weight-1)*target + 1 + loss = (1-target)*output+(log_weight*(((-max_val).exp()+(-output - max_val).exp()).log()+max_val)) + else: + loss = (1-target)*output+max_val+((-max_val).exp()+(-output -max_val).exp()).log() + if weight is not None: + loss *=weight + + if size_average: + return loss.mean() + else: + return loss.sum() + +class BCEWithLogitsLoss(Module): + def __init__(self, weight=None, pos_weight=None, size_average=True): + self.pos_weight = pos_weight + self.weight = weight + self.size_average = size_average + + def execute(self, output, target): + return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average) + +def softmax(x, dim=None, log=False): + import jittor.other.code_softmax as code_softmax + if code_softmax.can_softmax_v1(x, dim) and jt.compiler.is_cuda: + return code_softmax.softmax_v1(x, log) + if dim is None: dim = () + dtype, x = x.dtype, x._to_float() + if log: + a = x - jt.max(x, dim, keepdims=True) + ret = a - a.exp().sum(dim, keepdims=True).log() + else: + x = (x - jt.max(x, dim, keepdims=True)).exp() + ret = x / x.sum(dim, keepdims=True) + return ret.cast(dtype) +jt.Var.softmax = softmax + +def log_softmax(x,dim=None): + return softmax(x,dim=dim, log=True) +jt.Var.log_softmax = log_softmax + +def log_sigmoid(x): + return jt.log(jt.sigmoid(x)) +jt.Var.log_sigmoid = log_sigmoid + +def logsumexp(x, dim, keepdims=False, keepdim=False): + return x.exp().sum(dim, keepdim or keepdims).log() +jt.Var.logsumexp = logsumexp + +class Identity(Module): + def __init__(self, *args, **kwargs): + super(Identity, self).__init__() + + def execute(self, input): + return input + +def identity(input): return input + +class Dropout(Module): + def __init__(self, p=0.5, is_train=False): + assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p) + self.p = p + self.is_train = is_train + #TODO: test model.train() to change self.is_train + def execute(self, input): + output = input + if self.p > 0 and self.is_train: + if self.p == 1: + noise = jt.zeros(input.shape) + output = output * noise + else: + noise = jt.random(input.shape) + noise = (noise > self.p).int() + output = output * noise / (1.0 - self.p) # div keep prob + return output + +def dropout(x,p=0.5,is_train=False): + return Dropout(p=p,is_train=is_train)(x) + +class Dropout2d(Module): + def __init__(self, p=0.5, is_train=False): + ''' + Randomly zero out entire channels, from "Efficient Object Localization Using Convolutional Networks" + input: + x: [N,C,H,W] or [N,C,L] + output: + y: same shape as x + ''' + assert p >= 0 and p <= 1, "dropout probability has to be between 0 and 1, but got {}".format(p) + self.p = p + self.is_train = is_train + #TODO: test model.train() to change self.is_train + def execute(self, input): + output = input + if (input.dim() != 4) and (input.dim() != 3): + raise RuntimeError(f'Expected 3D (unbatched) or 4D (batched) input to Dropout2d, but got input of size: {input.shape}') + shape = input.shape[:-2] + if self.p > 0 and self.is_train: + if self.p == 1: + output = jt.zeros(input.shape) + else: + noise = jt.random(shape) + noise = (noise > self.p).int() + output = output * noise.broadcast(input.shape, dims=[-2,-1]) / (1.0 - self.p) # div keep prob + return output + +def dropout2d(x,p=0.5,is_train=False): + return Dropout2d(p=p,is_train=is_train)(x) + +class DropPath(Module): + '''Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + ''' + def __init__(self, p=0.5, is_train=False): + ''' + :param p: Specifies the probability of each batch retention. Defaults to 0.5. + :type p: float dtype + :param is_train: Specify whether it is a training model. Defaults to False. + :type is_train: bool + ''' + self.p = p + self.is_train = is_train + #TODO: test model.train() to change self.is_train + def execute(self, x): + if self.p == 0. or not self.is_train: + return x + keep_prob = 1 - self.p + shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) + random_tensor = keep_prob + jt.rand(shape, dtype=x.dtype) + output = x.divide(keep_prob) * random_tensor.floor() + return output + +def droppath(x,p=0.5,is_train=False): + return DropPath(p=p,is_train=is_train)(x) + +class Linear(Module): + def __init__(self, in_features, out_features, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = init.invariant_uniform((out_features, in_features), "float32") + bound = 1.0/math.sqrt(in_features) + self.bias = init.uniform((out_features,), "float32",-bound,bound) if bias else None + + def execute(self, x): + x = matmul_transpose(x, self.weight) + if self.bias is not None: + return x + self.bias + return x + +def linear(x, weight, bias=None): + ''' Returns x * weight^T + ''' + x = matmul_transpose(x, weight) + if bias is not None: + return x + bias + return x + +class BatchNorm(Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True): + self.sync = sync + self.num_features = num_features + self.is_train = is_train + self.eps = eps + self.momentum = momentum + self.affine = affine + self.weight = init.constant((num_features,), "float32", 1.0) if affine else 1.0 + self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0 + self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad() + self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad() + + def execute(self, x): + dims = [0]+list(range(2,x.ndim)) + if self.is_train: + xmean = jt.mean(x, dims=dims) + x2mean = jt.mean(x*x, dims=dims) + if self.sync and jt.in_mpi: + xmean = xmean.mpi_all_reduce("mean") + x2mean = x2mean.mpi_all_reduce("mean") + + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar+self.eps) + b = self.bias - xmean * w + norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims) + + self.running_mean.update(self.running_mean + + (xmean.reshape((-1,)) - self.running_mean) * self.momentum) + self.running_var.update(self.running_var + + (xvar.reshape((-1,))-self.running_var)*self.momentum) + return norm_x + else: + w = self.weight / jt.sqrt(self.running_var+self.eps) + b = self.bias - self.running_mean * w + norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims) + return norm_x + +BatchNorm3d = BatchNorm2d = BatchNorm1d = BatchNorm + +def batch_norm(x, running_mean, running_var, weight=1, bias=0, training=False, momentum=0.1, eps=1e-05): + dims = [0]+list(range(2,x.ndim)) + assert not training + w = weight / jt.sqrt(running_var+eps) + b = bias - running_mean * w + norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims) + return norm_x + + +class InstanceNorm(Module): + def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train=True, sync=True): + self.sync = sync + self.num_features = num_features + self.is_train = is_train + self.eps = eps + self.momentum = momentum + + self.affine = affine + self.weight = init.constant((num_features,), "float32", 1.0) if affine else 1.0 + self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0 + + def execute(self, x): + dims = list(range(2,x.ndim)) + xmean = jt.mean(x, dims=dims) + x2mean = jt.mean(x*x, dims=dims) + + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar+self.eps) + b = self.bias - xmean * w + return x * w.broadcast(x, dims) + b.broadcast(x, dims) + +InstanceNorm3d = InstanceNorm2d = InstanceNorm1d = InstanceNorm + +def fp32_guard(func): + def wrapper(*args, **kw): + if jt.flags.amp_level == 0: + return func(*args, **kw) + new_args = [] + need_cast = False + dtype = None + for a in args: + if isinstance(a, jt.Var) and (a.dtype == "float16" or a.dtype == "bfloat16"): + dtype = a.dtype + new_args.append(a.float32()) + need_cast = True + else: + new_args.append(a) + with jt.flag_scope(amp_level=0): + a = func(*new_args, **kw) + if need_cast and isinstance(a, jt.Var) and a.dtype == "float32": + a = a.cast(dtype) + return a + return wrapper + +def instance_norm(x, + running_mean = None, + running_var = None, + weight = 1, + bias = 0, + momentum = 0.1, + eps = 1e-5): + dims = list(range(2,x.ndim)) + xmean = jt.mean(x, dims=dims) + x2mean = jt.mean(x*x, dims=dims) + + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = weight / jt.sqrt(xvar+eps) + b = bias - xmean * w + return x * w.broadcast(x, dims) + b.broadcast(x, dims) + +class LayerNorm(Module): + def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None: + if isinstance(normalized_shape, int): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + self.weight = init.constant(normalized_shape, "float32", 1.0) if elementwise_affine else 1.0 + self.bias = init.constant(normalized_shape, "float32", 0.0) if elementwise_affine else 0.0 + + @fp32_guard + def execute(self, x): + dims = [-i for i in range(len(self.normalized_shape), 0, -1)] + xmean = jt.mean(x, dims=dims, keepdims=1) + x2mean = jt.mean(x*x, dims=dims, keepdims=1) + + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = self.weight / jt.sqrt(xvar+self.eps) + b = self.bias - xmean * w + return x * w + b + + +LayerNorm3d = LayerNorm2d = LayerNorm1d = LayerNorm + +@fp32_guard +def layer_norm(x, + normalized_shape, + weight = 1, + bias = 0, + eps: float = 1e-5, + elementwise_affine: bool = True): + dims = [-i for i in range(len(normalized_shape), 0, -1)] + xmean = jt.mean(x, dims=dims, keepdims=1) + x2mean = jt.mean(x*x, dims=dims, keepdims=1) + + xvar = (x2mean-xmean*xmean).maximum(0.0) + w = weight / jt.sqrt(xvar+eps) + b = bias - xmean * w + return x * w + b + +class GroupNorm(Module): + def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True): + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + + self.affine = affine + self.weight = init.constant((num_channels,), "float32", 1.0) if affine else 1.0 + self.bias = init.constant((num_channels,), "float32", 0.0) if affine else 0.0 + + def execute(self, x): + N = x.shape[0] + C = self.num_channels + # output_shape = (N,-1) + # TODO: 3d group norm + # if x.ndim==4: + # output_shape = x.shape + output_shape = x.shape + assert C % self.num_groups == 0 + x = x.reshape((N, self.num_groups, C//self.num_groups, -1)) + xmean = jt.mean(x, dims=[2,3]).reshape((N, self.num_groups, 1)) + x2mean = jt.mean(x*x, dims=[2,3]).reshape((N, self.num_groups, 1)) + xvar = (x2mean-xmean*xmean).maximum(0.0) + + if self.affine: + w = self.weight.reshape((1, self.num_groups, -1)) + b = self.bias.reshape((1, self.num_groups, -1)) + else: + w = 1 + b = 0 + w = w / jt.sqrt(xvar+self.eps) + b = b - xmean * w + x = x * w.broadcast(x, [3]) + b.broadcast(x, [3]) + return x.reshape(output_shape) + +def group_norm(x, + num_groups, + weight = 1, + bias = 0, + eps=1e-05): + N = x.shape[0] + C = x.shape[1] + output_shape = (N,-1) + # TODO: 3d group norm + if x.ndim==4: + output_shape = x.shape + assert C % num_groups == 0 + x = x.reshape((N, num_groups, C//num_groups, -1)) + xmean = jt.mean(x, dims=[2,3]).reshape((N, num_groups, 1)) + x2mean = jt.mean(x*x, dims=[2,3]).reshape((N, num_groups, 1)) + xvar = (x2mean-xmean*xmean).maximum(0.0) + + if isinstance(weight, jt.Var): + weight = weight.reshape((1, num_groups, -1)) + if isinstance(bias, jt.Var): + bias = bias.reshape((1, num_groups, -1)) + weight = weight / jt.sqrt(xvar+eps) + bias = bias - xmean * weight + x = x * weight.broadcast(x, [3]) + bias.broadcast(x, [3]) + return x.reshape(output_shape) + + +Relu = jt.make_module(relu) +ReLU = Relu +Leaky_relu = jt.make_module(leaky_relu, 2) +LeakyReLU = Leaky_relu +ReLU6 = jt.make_module(relu6) +Softmax = jt.make_module(softmax, 2) +GELU = jt.make_module(gelu) +SiLU = jt.make_module(silu) + +class Flatten(Module): + ''' Flattens the contiguous range of dimensions in a Var. + + :param start_dim: the first dimension to be flattened. Defaults: 1. + :type start_dim: int + + :param end_dim: the last dimension to be flattened. Defaults: -1. + :type end_dim: int + ''' + def __init__(self, start_dim=1, end_dim=-1): + self.start_dim = start_dim + self.end_dim = end_dim + + def execute(self, x) -> jt.Var: + return x.flatten(self.start_dim, self.end_dim) + + +from jittor.depthwise_conv import DepthwiseConv + +class Conv(Module): + ''' Applies a 2D convolution over an input signal composed of several input planes. + + :param in_channels: Number of channels in the input feature map + :type in_channels: int + + :param out_channels: Number of channels in the output feature map + :type out_channels: int + + :param kernel_size: Size of the convolving kernel + :type kernel_size: int or tuple + + :param stride: Stride of the convolution. Default: 1 + :type stride: int or tuple, optional + + :param padding: Padding added to all four sides of the input. Default: 0 + :type padding: int or tuple, optional + + :param dilation: Spacing between kernel elements. Default: 1 + :type dilation: int or tuple, optional + + :param groups: Number of blocked connections from input channels to output channels. Default: 1 + :type groups: int, optional + + :param bias: If True, adds a learnable bias to the output. Default: True + :type bias: bool, optional + + Example: + + >>> conv = nn.Conv2d(24, 32, 3) + >>> conv = nn.Conv2d(24, 32, (3,3)) + >>> conv = nn.Conv2d(24, 32, 3, stride=2, padding=1) + >>> conv = nn.Conv2d(24, 32, 3, dilation=(3, 1)) + >>> input = jt.randn(4, 24, 100, 100) + >>> output = conv(input) + ''' + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + if in_channels <= 0: + raise ValueError(f"in_channels must be greater than zero, got {in_channels}") + if out_channels <= 0: + raise ValueError(f"out_channels must be greater than zero, got {out_channels}") + if groups <= 0: + raise ValueError(f"groups must must be greater than zero, got {groups}") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + if isinstance(kernel_size, tuple): + for size in kernel_size: + if size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + else: + if kernel_size <= 0: + raise ValueError(f"kernel_size must be greater than zero, got {kernel_size}") + if isinstance(stride, tuple): + for size in stride: + if size <= 0: + raise ValueError(f"stride must be greater than zero, got {stride}") + else: + if stride <= 0: + raise ValueError(f"stride must be greater than zero, got {stride}") + if isinstance(padding, tuple): + for size in padding: + if size < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + else: + if padding < 0: + raise ValueError(f"padding must be nonnegative, got {padding}") + if isinstance(dilation, tuple): + for size in dilation: + if size <= 0: + raise ValueError(f"dilation must be greater than zero, got {dilation}") + else: + if dilation <= 0: + raise ValueError(f"dilation must be greater than zero, got {dilation}") + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, padding) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) + self.groups = groups + self.is_depthwise_conv = self.groups == self.out_channels and self.groups == self.in_channels + if self.is_depthwise_conv and jt.flags.use_cuda and jt.compiler.is_cuda: + self.depthwise_conv = DepthwiseConv(stride, padding, dilation) + Kh, Kw = self.kernel_size + + # self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out") + self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float") + if bias: + fan=1 + for i in self.weight.shape[1:]: + fan *= i + bound = 1 / math.sqrt(fan) + self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound) + else: + self.bias = None + + def execute(self, x): + if hasattr(self, 'depthwise_conv'): + y = self.depthwise_conv(x, self.weight) + if self.bias is not None: + b = self.bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y + elif self.groups == 1: + N,C,H,W = x.shape + Kh, Kw = self.kernel_size + assert C==self.in_channels + oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 + ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 + assert oh>0 and ow>0 + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 36): + xx = x.reindex([N,self.out_channels,C,oh,ow,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{self.stride[0]}-{self.padding[0]}+i5*{self.dilation[0]}', # Hid+Khid + f'i4*{self.stride[1]}-{self.padding[1]}+i6*{self.dilation[1]}', # Wid+KWid + ]) + ww = self.weight.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw + if self.bias is not None: + b = self.bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y + else: + N,C,H,W = x.shape + Kh, Kw = self.kernel_size + G = self.groups + CpG = C // G # channels per group + assert C==self.in_channels + oc = self.out_channels + oh = (H+self.padding[0]*2-Kh*self.dilation[0]+self.dilation[0]-1)//self.stride[0]+1 + ow = (W+self.padding[1]*2-Kw*self.dilation[1]+self.dilation[1]-1)//self.stride[1]+1 + assert oh>0 and ow>0 + xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid + f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid + ]) + # w: [oc, CpG, Kh, Kw] + ww = self.weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ + f'i1*{oc//G}+i2', + 'i3', + 'i6', + 'i7' + ]) + ww.compile_options = xx.compile_options = {"G":G,"C":C} + yy = xx*ww + y = yy.reindex_reduce('add', [N, oc, oh, ow], [ + 'i0', + f'i1*{oc//G}+i2', + 'i4', + 'i5' + ]) + if self.bias is not None: + b = self.bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y + +Conv2d = Conv + +class Conv1d(Module): + ''' Applies a 1D convolution over an input signal composed of several input planes. + + :param in_channels: Number of channels in the input feature map + :type in_channels: int + + :param out_channels: Number of channels in the output feature map + :type out_channels: int + + :param kernel_size: Size of the convolving kernel + :type kernel_size: int or tuple + + :param stride: Stride of the convolution. Default: 1 + :type stride: int or tuple, optional + + :param padding: Padding added to all four sides of the input. Default: 0 + :type padding: int or tuple, optional + + :param dilation: Spacing between kernel elements. Default: 1 + :type dilation: int or tuple, optional + + :param groups: Number of blocked connections from input channels to output channels. Default: 1 + :type groups: int, optional + + :param bias: If True, adds a learnable bias to the output. Default: True + :type bias: bool, optional + + Example: + + >>> conv = nn.Conv1d(24, 32, 3) + >>> conv = nn.Conv1d(24, 32, (3,3)) + >>> conv = nn.Conv1d(24, 32, 3, stride=2, padding=1) + >>> conv = nn.Conv1d(24, 32, 3, dilation=(3, 1)) + >>> input = jt.randn(4, 24, 100) + >>> output = conv(input) + ''' + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + assert in_channels > 0, 'in_channels must be positive' + assert out_channels > 0, 'out_channels must be positive' + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = (kernel_size, 1) + self.stride = (stride, 1) + self.padding = (padding, 0) + self.dilation = (dilation, 1) + self.groups = groups + self.bias = bias + if groups <= 0: + raise ValueError("groups must be a positive integer") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + # using list to escape module dfs + self._conv = [Conv(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, self.bias)] + self.weight = self._conv[0].weight.squeeze(-1) + self.bias = self._conv[0].bias + + def execute(self, x): + if x.dim() != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + N,C,D = x.shape + assert C==self.in_channels + self._conv[0].weight = self.weight.unsqueeze(-1) + x = x.unsqueeze(-1) + x = self._conv[0](x) + y = x.squeeze(-1) + return y + +class Conv3d(Module): + ''' Applies a 3D convolution over an input signal composed of several input planes. + + :param in_channels: Number of channels in the input feature map + :type in_channels: int + + :param out_channels: Number of channels in the output feature map + :type out_channels: int + + :param kernel_size: Size of the convolving kernel + :type kernel_size: int or tuple + + :param stride: Stride of the convolution. Default: 1 + :type stride: int or tuple, optional + + :param padding: Padding added to all four sides of the input. Default: 0 + :type padding: int or tuple, optional + + :param dilation: Spacing between kernel elements. Default: 1 + :type dilation: int or tuple, optional + + :param groups: Number of blocked connections from input channels to output channels. Default: 1 + :type groups: int, optional + + :param bias: If True, adds a learnable bias to the output. Default: True + :type bias: bool, optional + + Example: + + >>> conv = nn.Conv3d(24, 32, 3) + >>> conv = nn.Conv3d(24, 32, (3,3)) + >>> conv = nn.Conv3d(24, 32, 3, stride=2, padding=1) + >>> conv = nn.Conv3d(24, 32, 3, dilation=(3, 1)) + >>> input = jt.randn(4, 24, 50, 50, 50) + >>> output = conv(input) + ''' + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) + self.groups = groups + if groups <= 0: + raise ValueError("groups must be a positive integer") + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + Kh, Kw, Kd = self.kernel_size + self.groups = groups + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + + self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw, Kd], dtype="float") + if bias: + fan=1 + for i in self.weight.shape[1:]: + fan *= i + bound = 1 / math.sqrt(fan) + self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound) + else: + self.bias = None + + def execute(self, x): + return conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv1d_sp(Linear): + def __init__(self, inchannels, outchannels, kernel_size=1, bias=True): + assert inchannels > 0, 'in_channels must be positive' + assert outchannels > 0, 'out_channels must be positive' + super().__init__(inchannels, outchannels, bias=bias) + assert kernel_size == 1 + + def execute(self, x): + if x.dim() != 3: + raise ValueError("Input shape must be `(N, C, L)`!") + x = x.transpose(0, 2, 1) + x = super().execute(x) + x = x.transpose(0, 2, 1) + return x + +def conv2d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + ''' Applies a 2D convolution over an input signal composed of several input planes. + + :param x: the input image + :type x: jt.Var + + :param weight: the convolution kernel + :type weight: jt.Var + + :param bias: the bias after convolution + :type bias: jt,Var, optional + + :param stride: Stride of the convolution. Default: 1 + :type stride: int or tuple, optional + + :param padding: Padding added to all four sides of the input. Default: 0 + :type padding: int or tuple, optional + + :param dilation: Spacing between kernel elements. Default: 1 + :type dilation: int or tuple, optional + + :param groups: Number of blocked connections from input channels to output channels. Default: 1 + :type groups: int, optional + + Example: + + >>> x = jt.randn(4, 24, 100, 100) + >>> w = jt.randn(32, 24, 3, 3) + >>> y = nn.conv2d(x, w) + ''' + padding = _pair(padding) + stride = _pair(stride) + dilation = _pair(dilation) + out_channels = weight.shape[0] + if groups <= 0: + raise ValueError("groups must be a positive integer") + if groups == 1: + N,C,H,W = x.shape + Kh, Kw = weight.shape[-2:] + oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 + ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 + with jt.flag_scope(amp_reg = jt.flags.amp_reg | 36): + xx = x.reindex([N,out_channels,C,oh,ow,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride[0]}-{padding[0]}+i5*{dilation[0]}', # Hid+Khid + f'i4*{stride[1]}-{padding[1]}+i6*{dilation[1]}', # Wid+KWid + ]) + ww = weight.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw + if bias is not None: + b = bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y + else: + N,C,H,W = x.shape + Kh, Kw = weight.shape[-2:] + G = groups + CpG = C // G # channels per group + oc = out_channels + oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 + ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 + xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid + f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid + ]) + xx.compile_options = {"G":G} + # w: [oc, CpG, Kh, Kw] + ww = weight.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ + f'i1*{oc//G}+i2', + 'i3', + 'i6', + 'i7' + ]) + yy = xx*ww + y = yy.reindex_reduce('add', [N, oc, oh, ow], [ + 'i0', + f'i1*{oc//G}+i2', + 'i4', + 'i5' + ]) + if bias is not None: + b = bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y +conv = conv2d + +def conv3d(x, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): + ''' Applies a 3D convolution over an input signal composed of several input planes. + + :param x: the input volume + :type x: jt.Var + + :param weight: the convolution kernel + :type weight: jt.Var + + :param bias: the bias after convolution + :type bias: jt,Var, optional + + :param stride: Stride of the convolution. Default: 1 + :type stride: int or tuple, optional + + :param padding: Padding added to all four sides of the input. Default: 0 + :type padding: int or tuple, optional + + :param dilation: Spacing between kernel elements. Default: 1 + :type dilation: int or tuple, optional + + :param groups: Number of blocked connections from input channels to output channels. Default: 1 + :type groups: int, optional + + Example: + + >>> x = jt.randn(4, 24, 50, 50, 50) + >>> w = jt.randn(32, 24, 3, 3, 3) + >>> y = nn.conv2d(x, w) + ''' + padding = _triple(padding) + stride = _triple(stride) + dilation = _triple(dilation) + out_channels = weight.shape[0] + if groups <= 0: + raise ValueError("groups must be a positive integer") + if jt.flags.use_cuda and jt.cudnn: + y = jt.cudnn.ops.cudnn_conv3d(x, weight, *stride, *padding, *dilation, groups) + elif groups == 1: + N,C,D,H,W = x.shape + Kd, Kh, Kw = weight.shape[-3:] + od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1 + oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1 + ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1 + xx = x.reindex([N,out_channels,C,od,oh,ow,Kd,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid + f'i4*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid + f'i5*{stride[2]}-{padding[2]}+i8*{dilation[2]}', # Did+KDid + ]) + ww = weight.broadcast(xx.shape, [0,3,4,5]) + yy = xx*ww + y = yy.sum([2,6,7,8]) # Kc, Kh, Kw,Kd + else: + N,C,D,H,W = x.shape + Kd, Kh, Kw = weight.shape[-3:] + G = groups + CpG = C // G # channels per group + oc = out_channels + od = (D+padding[0]*2-Kd*dilation[0]+dilation[0]-1)//stride[0]+1 + oh = (H+padding[1]*2-Kh*dilation[1]+dilation[1]-1)//stride[1]+1 + ow = (W+padding[2]*2-Kw*dilation[2]+dilation[2]-1)//stride[2]+1 + xx = x.reindex([N,G,oc//G,CpG,od,oh,ow,Kd,Kh,Kw], [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{stride[0]}-{padding[0]}+i7*{dilation[0]}', # Hid+Khid + f'i5*{stride[1]}-{padding[1]}+i8*{dilation[1]}', # Wid+KWid + f'i6*{stride[2]}-{padding[2]}+i9*{dilation[2]}', # Did+KDid + ]) + xx.compile_options = {"G":G} + # w: [oc, CpG, Kh, Kw, Kd] + ww = weight.reindex([N, G, oc//G, CpG, oh, ow, od, Kh, Kw, Kd], [ + f'i1*{oc//G}+i2', + 'i3', + 'i7', + 'i8', + 'i9' + ]) + yy = xx*ww + y = yy.reindex_reduce('add', [N, oc, od, oh, ow], [ + 'i0', + f'i1*{oc//G}+i2', + 'i4', + 'i5', + 'i6' + ]) + + if bias is not None: + b = bias.broadcast(y.shape, [0,2,3,4]) + y = y + b + return y + +class ConvTranspose(Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ + padding=0, output_padding=0, groups=1, bias=True, dilation=1): + self.in_channels = in_channels + self.out_channels = out_channels + + # added + self.dilation = dilation + self.groups = groups + + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) + # added + self.padding = padding if isinstance(padding, tuple) else (padding, padding) + self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], + self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) + self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert self.stride[0] > 0 and self.stride[1] > 0,"stride must be positive" + assert self.padding[0] >= 0 and self.padding[1] >= 0,"padding must be non-negative" + assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ + self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ + "output padding must be smaller than max(stride, dilation)" + assert in_channels % groups == 0, 'in_channels must be divisible by groups' + assert out_channels % groups == 0, 'out_channels must be divisible by groups' + + self.weight = init.invariant_uniform((in_channels, out_channels//groups) + self.kernel_size, dtype="float") + if bias: + fan=1 + for i in self.weight.shape[1:]: + fan *= i + bound = 1 / math.sqrt(fan) + self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound) + else: + self.bias = None + + def execute(self, x): + if x.dim() != 4: + raise RuntimeError(f'Expected 4D (batched) input to conv_transpose2d, but got input of size: {x.shape}') + if self.groups == 1: + N,C,H,W = x.shape + i,o,h,w = self.weight.shape + assert C==i + stride_h, stride_w = self.stride + padding_h, padding_w = self.padding + dilation_h, dilation_w = self.dilation + + h_out = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h + w_out = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w + out_shape = (N, o, h_out, w_out) + shape = (N, i, o, H, W, h, w) + xx = x.broadcast(shape, (2, 5, 6)) # i,h,w + ww = self.weight.broadcast(shape, (0, 3, 4)) # N,H,W + y = (ww*xx).reindex_reduce("add", out_shape, [ + 'i0', # N + 'i2', # o + f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid + f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid + ]) + if self.bias is not None: + b = self.bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y + else: + N,C,H,W = x.shape + Kh, Kw = self.kernel_size + i,o,h,w = self.weight.shape + oc = self.out_channels + G = self.groups + CpG = C // G # channels per group + assert C==self.in_channels + stride_h, stride_w = self.stride + padding_h, padding_w = self.padding + dilation_h, dilation_w = self.dilation + + oh = (H-1) * stride_h + self.output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h + ow = (W-1) * stride_w + self.output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w + out_shape = (N, oc, oh, ow) + shape = [N,G,oc//G,CpG,oh,ow,Kh,Kw] + xx = x.reindex(shape, [ + 'i0', + f'i1*{oc//G}+i2', + 'i4', + 'i5' + ]) + ww = self.weight.reindex(shape, [ + f'i1*{oc//G}+i2', + 'i3', + 'i6', + 'i7' + ]) + ww.compile_options = xx.compile_options = {"G":G,"C":C} + y = (ww*xx).reindex_reduce("add", out_shape, [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{self.stride[0]}-{self.padding[0]}+i6*{self.dilation[0]}', # Hid+Khid + f'i5*{self.stride[1]}-{self.padding[1]}+i7*{self.dilation[1]}', # Wid+KWid + ]) + if self.bias is not None: + b = self.bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y +ConvTranspose2d = ConvTranspose + +class ConvTranspose3d(Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, \ + padding=0, output_padding=0, groups=1, bias=True, dilation=1): + self.in_channels = in_channels + self.out_channels = out_channels + + # added + self.dilation = dilation + self.group = groups + assert groups==1, "Group conv not supported yet." + + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size, kernel_size) + self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) + # added + self.padding = padding if isinstance(padding, tuple) else (padding, padding, padding) + self.real_padding = ( + self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], + self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1], + self.dilation[2] * (self.kernel_size[2] - 1) - self.padding[2]) + self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding) + assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ + self.output_padding[1] < max(self.stride[1], self.dilation[1]) and \ + self.output_padding[2] < max(self.stride[2], self.dilation[2]), \ + "output padding must be smaller than max(stride, dilation)" + + self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float") + if bias: + fan=1 + for i in self.weight.shape[1:]: + fan *= i + bound = 1 / math.sqrt(fan) + self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound) + else: + self.bias = None + + def execute(self, x): + return conv_transpose3d(x, self.weight, self.bias, self.stride, self.padding, self.output_padding, self.group, self.dilation) + +def conv_transpose(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + if groups == 1: + x = input + if x.dim() != 4: + raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {x.shape}') + N,C,H,W = x.shape + i,o,h,w = weight.shape + assert C==i + stride = stride if isinstance(stride, tuple) else (stride, stride) + if stride[0] <= 0 or stride[1] <= 0: + raise RuntimeError("non-positive stride is not supported") + dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) + # added + padding = padding if isinstance(padding, tuple) else (padding, padding) + output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert output_padding[0] < max(stride[0], dilation[0]) and \ + output_padding[1] < max(stride[1], dilation[1]), \ + "output padding must be smaller than max(stride, dilation)" + + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + + h_out = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h + w_out = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w + out_shape = (N, o, h_out, w_out) + shape = (N, i, o, H, W, h, w) + xx = x.broadcast(shape, (2, 5, 6)) # i,h,w + ww = weight.broadcast(shape, (0, 3, 4)) # N,H,W + y = (ww*xx).reindex_reduce("add", out_shape, [ + 'i0', # N + 'i2', # o + f'i3*{stride_h}-{padding_h}+i5*{dilation_h}', # Hid+Khid + f'i4*{stride_w}-{padding_w}+i6*{dilation_w}', # Wid+KWid + ]) + if isinstance(bias, jt.Var): + b = bias.broadcast(y.shape, [0,2,3]) + y = y + b + else: + assert not bias, "Bias should be none or jittor var" + return y + else: + if input.dim() != 4: + raise RuntimeError(f'Expected 4D input to conv_transpose, but got input of size: {input.shape}') + N,C,H,W = input.shape + i,o,h,w = weight.shape + G = groups + oc = o * G + CpG = C // G # channels per group + assert C % G == 0 + assert C==i, (C, i) + stride = stride if isinstance(stride, tuple) else (stride, stride) + if stride[0] <= 0 or stride[1] <= 0: + raise RuntimeError("non-positive stride is not supported") + dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) + # added + padding = padding if isinstance(padding, tuple) else (padding, padding) + output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert output_padding[0] < max(stride[0], dilation[0]) and \ + output_padding[1] < max(stride[1], dilation[1]), \ + "output padding must be smaller than max(stride, dilation)" + + stride_h, stride_w = stride + padding_h, padding_w = padding + dilation_h, dilation_w = dilation + + oh = (H-1) * stride_h + output_padding[0] - 2*padding_h + 1 + (h-1)*dilation_h + ow = (W-1) * stride_w + output_padding[1] - 2*padding_w + 1 + (w-1)*dilation_w + out_shape = (N, oc, oh, ow) + shape = [N,G,oc//G,CpG,oh,ow,h,w] + xx = input.reindex(shape, [ + 'i0', + f'i1*{oc//G}+i2', + 'i4', + 'i5' + ]) + ww = weight.reindex(shape, [ + f'i1*{oc//G}+i2', + 'i3', + 'i6', + 'i7' + ]) + ww.compile_options = xx.compile_options = {"G":G,"C":C} + y = (ww*xx).reindex_reduce("add", out_shape, [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid + f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid + ]) + if bias is not None: + b = bias.broadcast(y.shape, [0,2,3]) + y = y + b + return y +conv_transpose2d = conv_transpose + +def conv_transpose3d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): + x = input + if x.dim() != 5: + raise RuntimeError(f'Expected 5D input to conv_transpose3d, but got input of size: {x.shape}') + N,C,D,H,W = x.shape + i,o,d,h,w = weight.shape + assert C==i + assert groups==1, "Group conv not supported yet." + stride = stride if isinstance(stride, tuple) else (stride, stride, stride) + if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0: + raise RuntimeError("non-positive stride is not supported") + dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation, dilation) + # added + padding = padding if isinstance(padding, tuple) else (padding, padding, padding) + output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding, output_padding) + assert output_padding[0] < max(stride[0], dilation[0]) and \ + output_padding[1] < max(stride[1], dilation[1]) and \ + output_padding[2] < max(stride[2], dilation[2]), \ + "output padding must be smaller than max(stride, dilation)" + + stride_d, stride_h, stride_w = stride + padding_d, padding_h, padding_w = padding + dilation_d, dilation_h, dilation_w = dilation + + d_out = (D-1) * stride_d + output_padding[0] - 2*padding_d + 1 + (d-1)*dilation_d + h_out = (H-1) * stride_h + output_padding[1] - 2*padding_h + 1 + (h-1)*dilation_h + w_out = (W-1) * stride_w + output_padding[2] - 2*padding_w + 1 + (w-1)*dilation_w + out_shape = (N, o, d_out, h_out, w_out) + if jt.flags.use_cuda and jt.cudnn: + return jt.cudnn.ops.cudnn_conv3d_backward_x(weight, x, *out_shape[2:], *stride, *padding, *dilation, groups) + shape = (N, i, o, D, H, W, d, h, w) + xx = x.broadcast(shape, (2, 6, 7, 8)) # i,h,w + ww = weight.broadcast(shape, (0, 3, 4, 5)) # N,H,W + y = (ww*xx).reindex_reduce("add", out_shape, [ + 'i0', # N + 'i2', # o + f'i3*{stride_d}-{padding_d}+i6*{dilation_d}', # Did+Kdid + f'i4*{stride_h}-{padding_h}+i7*{dilation_h}', # Hid+Khid + f'i5*{stride_w}-{padding_w}+i8*{dilation_w}', # Wid+KWid + ]) + if isinstance(bias, jt.Var): + b = bias.broadcast(y.shape, [0,2,3,4]) + y = y + b + else: + assert not bias, "Bias should be none or jittor var" + return y + +conv_transpose2d = conv_transpose + +def pad(x,padding, mode='constant', value=0): + assert mode in ['constant','replicate','reflect','circular'],'only support constant,replicate,reflect,circular pad' + assert len(padding)%2==0 and len(padding)//2<=x.ndim + + padding = list(padding) + left = [0]*(x.ndim-len(padding)//2)+padding[::2][::-1] + right = [0]*(x.ndim-len(padding)//2)+padding[1::2][::-1] + + out_dims = [] + out_shape = [] + for i,n,l,r in zip(range(x.ndim),x.shape,left,right): + out_shape.append(n+l+r) + if mode == 'constant': + out_dims.append(f'i{i}-{l}') + elif mode == 'replicate': + out_dims.append(f"i{i}<{l} ? 0 : i{i} > {n+l-1} ? {n-1} : i{i}-{l}") + elif mode == 'reflect': + out_dims.append(f"i{i}<{l} ? {l}-i{i} : i{i} > {n+l-1} ? {2*(n-1)+l}-i{i} : i{i}-{l}") + elif mode == 'circular': + out_dims.append(f"i{i}<{l} ? {n-l}+i{i} : i{i} > {n+l-1} ? i{i}-{n+l} : i{i}-{l}") + + return x.reindex(out_shape,out_dims,overflow_value=value) + + +class ReflectionPad2d(Module): + def __init__(self, padding): + if padding < 0: + raise RuntimeError(f"padding must be > 0, but got {padding}") + self.padding = padding + if isinstance(self.padding, int): + self.pl = self.padding + self.pr = self.padding + self.pt = self.padding + self.pb = self.padding + elif isinstance(self.padding, tuple): + self.pl, self.pr, self.pt, self.pb = self.padding + else: + raise TypeError(f"ReflectionPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") + + def execute(self, x): + n,c,h,w = x.shape + assert (self.pl < w and self.pr < w), f"padding_left and padding_right should be smaller than input width" + assert (self.pt < h and self.pb < h), f"padding_top and padding_bottom should be smaller than input height" + oh=h+self.pt+self.pb + ow=w+self.pl+self.pr + l = self.pl + r = self.pl + w - 1 + t = self.pt + b = self.pt + h - 1 + return x.reindex([n,c,oh,ow], ["i0","i1", + f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}", + f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}", + ]) + +class ZeroPad2d(Module): + def __init__(self, padding): + self.padding = padding + if isinstance(self.padding, int): + self.pl = self.padding + self.pr = self.padding + self.pt = self.padding + self.pb = self.padding + elif isinstance(self.padding, (tuple,list)): + self.pl, self.pr, self.pt, self.pb = self.padding + else: + raise TypeError(f"ZeroPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") + + def execute(self, x): + if x.dim() != 4: + raise RuntimeError("Input shape must be `(N, C, H, W)`!") + n,c,h,w = x.shape + return x.reindex([n,c,h+self.pt+self.pb,w+self.pl+self.pr], ["i0","i1",f"i2-{self.pt}",f"i3-{self.pl}"]) + +class ConstantPad2d(Module): + def __init__(self, padding, value): + self.padding = padding + if isinstance(self.padding, int): + self.pl = self.padding + self.pr = self.padding + self.pt = self.padding + self.pb = self.padding + elif isinstance(self.padding, tuple): + self.pl, self.pr, self.pt, self.pb = self.padding + else: + raise TypeError(f"ConstantPad2d padding just support int or tuple, but found {type(padding)}") + self.value = value + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") + + def execute(self, x): + assert len(x.shape) >= 2 + shape = x.shape + tar_shape = shape[0:-2] + [shape[-2]+self.pt+self.pb,shape[-1]+self.pl+self.pr] + tar_dims = [] + for i in range(len(shape)-2): + tar_dims.append(f"i{i}") + tar_dims.append(f"i{i+1}-{self.pt}") + tar_dims.append(f"i{i+2}-{self.pl}") + return x.reindex(tar_shape, tar_dims, overflow_value=self.value) + +class ReplicationPad2d(Module): + def __init__(self, padding): + if padding < 0: + raise RuntimeError(f"padding must be > 0, but got {padding}") + self.padding = padding + if isinstance(self.padding, int): + self.pl = self.padding + self.pr = self.padding + self.pt = self.padding + self.pb = self.padding + elif isinstance(self.padding, tuple): + self.pl, self.pr, self.pt, self.pb = self.padding + else: + raise TypeError(f"ReplicationPad2d padding just support int or tuple, but found {type(padding)}") + if self.pl < 0 or self.pr < 0 or self.pt < 0 or self.pb < 0: + raise ValueError(f"padding must be non-negative") + + def execute(self, x): + if x.dim() != 4: + raise RuntimeError("Input shape must be `(N, C, H, W)`!") + n,c,h,w = x.shape + oh=h+self.pt+self.pb + ow=w+self.pl+self.pr + l = self.pl + r = self.pl + w - 1 + t = self.pt + b = self.pt + h - 1 + return x.reindex([n,c,oh,ow], ["i0","i1", + f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}", + f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}" + ]) + +class Embedding(Module): + ''' A simple lookup table that stores embeddings of a fixed dictionary and size. + + :param num: size of the dictionary of embeddings + :type num: int + + :param dim: the size of each embedding vector + :type dim: int + + Example: + >>> embedding = nn.Embedding(10, 3) + >>> x = jt.int32([1, 2, 3, 3]) + >>> embedding(x) + jt.Var([[ 1.1128596 0.19169547 0.706642] + [ 1.2047412 1.9668795 0.9932192] + [ 0.14941819 0.57047683 -1.3217674] + [ 0.14941819 0.57047683 -1.3217674]], dtype=float32) + ''' + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, dtype="float32"): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weight = jt.init.gauss([self.num_embeddings, self.embedding_dim], dtype) + if padding_idx is not None: + self.weight[padding_idx] = 0 + + def execute(self, x): + res = self.weight[x] + return res + +def embedding(input, weight): + return weight[input] + +class PixelShuffle(Module): + def __init__(self, upscale_factor): + assert upscale_factor > 0,f"upscale_factor must be greater than zero,got {upscale_factor}" + self.upscale_factor = upscale_factor + + def execute(self, x): + n,c,h,w = x.shape + r = self.upscale_factor + assert c%(r*r)==0, f"input channel needs to be divided by upscale_factor's square in PixelShuffle" + if r<=0: + raise RuntimeError(f"pixel_shuffle expects a positive upscale_factor, but got {r}") + return x.reindex([n,int(c/r**2),h*r,w*r], [ + "i0", + f"i1*{r*r}+i2%{r}*{r}+i3%{r}", + f"i2/{r}", + f"i3/{r}" + ]) + +class Tanh(Module): + def __init__(self): + super().__init__() + def execute(self, x) : + return x.tanh() + +class Sigmoid(Module): + def __init__(self): + super().__init__() + def execute(self, x) : + return x.sigmoid() + +def softplus(x,beta=1.0,threshold=20.0): + return 1 / beta * jt.log(1 + (beta * x).minimum(threshold).exp()) + \ + (x - threshold/beta).maximum(0.0) + +def hardtanh(x,min_val=-1,max_val=1): + return jt.clamp(x,min_v=min_val,max_v=max_val) + + +class Softplus(Module): + r''' + SoftPlus is a smooth approximation to the ReLU function and can be used to constrain the output of a machine to always be positive. + + Args: + + [in] beta (float): the beta value for the Softplus formulation. Default: 1. + + [in] threshold (float): values above this revert to a linear function. Default: 20. + ''' + def __init__(self, beta=1, threshold=20): + self.beta = beta + self.threshold = threshold + + def execute(self, x): + return softplus(x, self.beta, self.threshold) + +class Resize(Module): + def __init__(self, size, mode="nearest", align_corners=False): + super().__init__() + if isinstance(size,int): + if size <= 0: + raise ValueError(f"sizes must be positive, got {size}") + elif isinstance(size,tuple) or isinstance(size,list): + for item in size: + if item <= 0: + raise ValueError(f"sizes must be positive, got {item}") + else: + raise ValueError(f"size must be int or tuple") + self.size = size + self.mode = mode + self.align_corners = align_corners + def execute(self, x): + return resize(x, self.size, self.mode, self.align_corners) + + +def _bicubic(x, a, func): + # normal ver + if func == 1: + return (a+2)*(jt.abs(x)**3)-(a+3)*(x**2)+1 + if func == 2: + return a*(jt.abs(x)**3)-5*a*(x**2)+8*a*(jt.abs(x))-4*a + return 0 + + +def _interpolate(img, x, y, ids, mode): + if mode == "nearest": + return img.reindex([*ids, x.floor_int(), y.floor_int()]) + if mode == "bilinear": + fx, fy = x.floor_int(), y.floor_int() + cx, cy = fx + 1, fy + 1 + dx, dy = x - fx, y - fy + a = img.reindex_var([*ids, fx, fy]) + b = img.reindex_var([*ids, cx, fy]) + c = img.reindex_var([*ids, fx, cy]) + d = img.reindex_var([*ids, cx, cy]) + dnx, dny = 1 - dx, 1 - dy + ab = dx * b + dnx * a + cd = dx * d + dnx * c + o = ab * dny + cd * dy + return o + if mode=="bicubic": # ugly ver. + n,c,h,w = img.shape + fx, fy = x.floor_int(), y.floor_int() + dix, diy = x - fx, y - fy + ax, ay = _bicubic(dix+1,-0.75,2), _bicubic(diy+1,-0.75,2) + bx, by = _bicubic(dix,-0.75,1), _bicubic(diy,-0.75,1) + cx, cy = _bicubic(1-dix,-0.75,1), _bicubic(1-diy,-0.75,1) + dx, dy = _bicubic(2-dix,-0.75,2), _bicubic(2-diy,-0.75,2) + afx, afy = jt.maximum(jt.minimum(fx-1,h-1),0), jt.maximum(jt.minimum(fy-1,w-1),0) + bfx, bfy = jt.maximum(jt.minimum(fx,h-1),0), jt.maximum(jt.minimum(fy,w-1),0) + cfx, cfy = jt.maximum(jt.minimum(fx+1,h-1),0), jt.maximum(jt.minimum(fy+1,w-1),0) + dfx, dfy = jt.maximum(jt.minimum(fx+2,h-1),0), jt.maximum(jt.minimum(fy+2,w-1),0) + a = ax*(img.reindex_var([*ids,afx,afy])*ay+img.reindex_var([*ids,afx,bfy])*by+img.reindex_var([*ids,afx,cfy])*cy+img.reindex_var([*ids,afx,dfy])*dy) + b = bx*(img.reindex_var([*ids,bfx,afy])*ay+img.reindex_var([*ids,bfx,bfy])*by+img.reindex_var([*ids,bfx,cfy])*cy+img.reindex_var([*ids,bfx,dfy])*dy) + c = cx*(img.reindex_var([*ids,cfx,afy])*ay+img.reindex_var([*ids,cfx,bfy])*by+img.reindex_var([*ids,cfx,cfy])*cy+img.reindex_var([*ids,cfx,dfy])*dy) + d = dx*(img.reindex_var([*ids,dfx,afy])*ay+img.reindex_var([*ids,dfx,bfy])*by+img.reindex_var([*ids,dfx,cfy])*cy+img.reindex_var([*ids,dfx,dfy])*dy) + o = a + b + c + d + return o + raise (f"Not support interpolation mode: {mode}") + +# TODO: tf_mode to another function +def resize(img, size, mode="nearest", align_corners=False, tf_mode=False): + n, c, h, w = img.shape + H, W = size + nid, cid, hid, wid = jt.index((n, c, H, W)) + if align_corners: + x = hid * ((h - 1) / max(1, H - 1)) + y = wid * ((w - 1) / max(1, W - 1)) + elif mode == "bicubic": + x = (hid + 0.5) * (h / H) - 0.5 + y = (wid + 0.5) * (w / W) - 0.5 + elif mode == 'nearest': + x = hid * (h / H) + y = wid * (w / W) + elif mode == "area": + ''' + Area interpolation uses AdaptivePool2D to resize origin images. + ''' + stride = (h // H, w // W) + assert stride[0] > 0 and stride[1] > 0 + x, y = jt.meshgrid(jt.arange(0, H, 1), jt.arange(0, W, 1)) + startH = jt.floor(x*h/H).int32() + endH = jt.ceil((x+1)*h/H).int32() + maxH = int(jt.max(endH - startH).data) + startW = jt.floor(y*w/W).int32() + endW = jt.ceil((y+1)*w/W).int32() + maxW = int(jt.max(endW - startW).data) + pixel_count = (endH - startH) * (endW - startW) + adaptive_output = img.reindex([img.shape[0], img.shape[1], H, W, maxH, maxW], ["i0", "i1", "@e0(i2, i3) + i4", "@e2(i2, i3) + i5"], extras=[startH, endH, startW, endW], overflow_conditions=["i4 >= @e1(i2, i3) - @e0(i2, i3)", "i5 >= @e3(i2, i3) - @e2(i2, i3)"], overflow_value=0) + adaptive_output = adaptive_output.reduce("sum", [4,5]) / pixel_count[None, None, ...] + return adaptive_output + else: + if (tf_mode): + x = hid * (h / H) + if H > h: x = x.clamp(0, h - 1) + y = wid * (w / W) + if W > w: y = y.clamp(0, w - 1) + else: + x = hid * (h / H) + (h / H * 0.5 - 0.5) + if H > h: x = x.clamp(0, h - 1) + y = wid * (w / W) + (w / W * 0.5 - 0.5) + if W > w: y = y.clamp(0, w - 1) + return _interpolate(img, x, y, (nid, cid), mode) + +upsample = resize + + +def interpolate(X, size=None, scale_factor=None, mode='bilinear', align_corners=False, tf_mode=False): + if scale_factor is not None: + size = [int(X.shape[-2] * scale_factor), int(X.shape[-1] * scale_factor)] + if isinstance(size, int): + size = (size, size) + if scale_factor is not None and scale_factor > 1: + return upsample(X, size, mode, align_corners, tf_mode) + else: + return resize(X, size, mode, align_corners, tf_mode) + + +def grid_sample_v0(input, grid, mode='bilinear', padding_mode='zeros'): + r''' + Given an input and a flow-field grid, computes the output using input values and pixel locations from grid. + + grid specifies the sampling pixel locations normalized by the input spatial dimensions. Therefore, it should have most values in the range of [-1, 1]. For example, values x = -1, y = -1 is the left-top pixel of input, and values x = 1, y = 1 is the right-bottom pixel of input. + + Args: + + [in] input (var): the source input var, whose shape is (N, C, Hi, Wi) + + [in] grid (var): the pixel locations, whose shape is (N, Ho, Wo, 2) + + [in] mode (string): the interpolate way, default: bilinear. + + [in] padding_mode (string): the padding way, default: zeros. + + [out] output (var): the output var, whose shape is (N, C, Ho, Wo) + + Example: + + >>> x = jt.array([[[[1,2],[3,4]]]]) + >>> print(x) + [[[[1 2] + [3 4]]]] + + >>> grid = jt.array([[[[0.5, 0.5]]]]) + >>> print(x.shape, grid.shape) + [1,1,2,2,], [1,1,2,2,] + + >>> nn.grid_sample(x, grid) + [[[[3.25]]]] + ''' + assert padding_mode == 'zeros' + Ni, Ci, Hi, Wi = input.shape + No, Ho, Wo, D = grid.shape + assert D == 2 + assert Ni == No + assert len(input.shape) == 4 and len(grid.shape) + + nid, cid, hid, wid = jt.index((Ni, Ci, Ho, Wo)) + x = ((grid[:, :, :, 1].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Hi - 1) + y = ((grid[:, :, :, 0].unsqueeze(1).repeat([1, Ci, 1, 1]) + 1) / 2) * (Wi - 1) + return _interpolate(input, x, y, (nid, cid), mode) + + +def linspace_from_neg_one(grid,num_steps,align_corners): + if num_steps <= 1: + return jt.array([],dtype=grid.dtype) + # TODO: use jt.index + ra = np.linspace(-1,1,num_steps) + if not align_corners: + ra = ra*(num_steps-1)/num_steps + return jt.array(ra,dtype=grid.dtype) + +def make_base_grid_4D(theta,N,C,H,W,align_corners): + base_grid = jt.zeros((N, H, W, 3), dtype=theta.dtype) + base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners) + base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1) + base_grid[...,-1] = 1 + return base_grid + +def make_base_grid_5D(theta,N,C,D,H,W,align_corners): + base_grid = jt.zeros((N, D, H, W, 4), dtype=theta.dtype) + base_grid[...,0] = linspace_from_neg_one(theta, W, align_corners) + base_grid[...,1] = jt.unsqueeze(linspace_from_neg_one(theta, H, align_corners),-1) + base_grid[...,2] = jt.unsqueeze(jt.unsqueeze(linspace_from_neg_one(theta, D, align_corners),-1),-1) + base_grid[...,-1] = 1 + return base_grid + +def affine_grid_generator_4D(theta,N,C,H,W,align_corners): + base_grid = make_base_grid_4D(theta, N, C, H, W, align_corners) + grid = jt.nn.bmm(base_grid.reshape(N, H * W, 3),theta.transpose(0,2,1)) + return grid.reshape(N, H, W, 2) + +def affine_grid_generator_5D(theta,N,C,D,H,W,align_corners): + base_grid = make_base_grid_5D(theta, N, C, D, H, W, align_corners) + grid = jt.nn.bmm(base_grid.reshape(N, D * H * W, 4),theta.transpose(0,2,1)) + return grid.reshape(N, D, H, W, 3) + +def affine_grid(theta, size, align_corners=False): + assert str(theta.dtype) in ['float','float32','float64'] + assert min(size)>0 + assert len(size) in [4,5] + if len(size)== 4: + assert theta.ndim == 3 and theta.shape[-2] == 2 and theta.shape[-1] == 3 + return affine_grid_generator_4D(theta, size[0], size[1], size[2], size[3], align_corners) + elif len(size)==5: + assert theta.ndim == 3 and theta.shape[-2] == 3 and theta.shape[-1] == 4 + return affine_grid_generator_5D(theta, size[0], size[1], size[2], size[3], size[4], align_corners) + + +def grid_sampler_unnormalize(coord,size,align_corners): + if align_corners: + #unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1) / 2) * (size - 1) + else: + #unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1) * size - 1) / 2 + + +def clip_coordinates(x,clip_limit): + return jt.clamp(x,min_v=0,max_v=clip_limit-1) + +def reflect_coordinates(x,twice_low,twice_high): + if twice_low == twice_high: + return jt.zeros_like(x) + m = twice_low / 2 + span = (twice_high - twice_low) / 2 + x = (x - m).abs() + #`fmod` returns same sign as `in`, which is positive after the `fabs` above. + extra = x.mod(span) + flips = (x / span).floor_int() + result1 = extra+m + result2 = span-extra+m + con = flips%2==0 + not_con = flips%2!=0 + result1[not_con]=0.0 + result2[con]=0.0 + return result1+result2 + + +def grid_sampler_compute_source_index(coord,size,padding_mode,align_corners): + coord = grid_sampler_unnormalize(coord, size, align_corners) + if padding_mode == 'border': + #clip coordinates to image borders + coord = clip_coordinates(coord, size) + elif padding_mode == 'reflection': + #reflect coordinates by image borders + if align_corners: + coord = reflect_coordinates(coord, 0, 2*(size - 1)) + else: + coord = reflect_coordinates(coord, -1, 2*size - 1) + #clip coordinates to image borders + coord = clip_coordinates(coord, size) + return coord + + + +def grid_sampler_3d(X,grid,mode,padding_mode,align_corners): + N = X.shape[0] + C = X.shape[1] + inp_D = X.shape[2] + inp_H = X.shape[3] + inp_W = X.shape[4] + + D = grid.shape[1] + H = grid.shape[2] + W = grid.shape[3] + x = grid[:,:,:,:,0] + y = grid[:,:,:,:,1] + z = grid[:,:,:,:,2] + shape = [N,C,D,H,W] + cid = jt.index(shape, dim=1) + nid = jt.index(shape, dim=0) + + x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners) + y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners) + z = grid_sampler_compute_source_index(z,inp_D,padding_mode,align_corners) + xid = x.reindex(shape,['i0','i2','i3','i4']) + yid = y.reindex(shape,['i0','i2','i3','i4']) + zid = z.reindex(shape,['i0','i2','i3','i4']) + + if mode=='nearest': + return X.reindex([nid,cid,zid.round_int(),yid.round_int(),xid.round_int()]) + elif mode=='bilinear': + fx,fy,fz = xid.floor_int(),yid.floor_int(),zid.floor_int() + cx,cy,cz = fx+1,fy+1,fz+1 + dx,dy,dz = xid-fx,yid-fy,zid-fz + dnx,dny,dnz = cx-xid,cy-yid,cz-zid + a = X.reindex([nid,cid,fz,fy,fx]) + b = X.reindex([nid,cid,cz,fy,fx]) + c = X.reindex([nid,cid,fz,cy,fx]) + d = X.reindex([nid,cid,fz,fy,cx]) + e = X.reindex([nid,cid,fz,cy,cx]) + f = X.reindex([nid,cid,cz,fy,cx]) + g = X.reindex([nid,cid,cz,cy,fx]) + h = X.reindex([nid,cid,cz,cy,cx]) + o = a*dnx*dny*dnz+b*dnx*dny*dz+c*dnx*dy*dnz+d*dx*dny*dnz+e*dx*dy*dnz+f*dx*dny*dz+g*dnx*dy*dz+h*dx*dy*dz + return o + +def grid_sampler_2d(X,grid,mode,padding_mode,align_corners): + N = X.shape[0] + C = X.shape[1] + inp_H = X.shape[2] + inp_W = X.shape[3] + + H = grid.shape[1] + W = grid.shape[2] + x = grid[:,:,:,0] + y = grid[:,:,:,1] + shape = [N,C,H,W] + cid = jt.index(shape, dim=1) + nid = jt.index(shape, dim=0) + + x = grid_sampler_compute_source_index(x,inp_W,padding_mode,align_corners) + y = grid_sampler_compute_source_index(y,inp_H,padding_mode,align_corners) + xid = x.reindex(shape,['i0','i2','i3']) + yid = y.reindex(shape,['i0','i2','i3']) + + if mode=='nearest': + return X.reindex([nid,cid,yid.round_int(),xid.round_int()]) + elif mode=='bilinear': + #xid,yid = (xid+0.00001),(yid+0.00001) + fx,fy = (xid).floor_int(),(yid).floor_int() + cx,cy = fx+1,fy+1 + dx,dy = xid-fx,yid-fy + dnx,dny = cx-xid,cy-yid + + a = X.reindex([nid,cid,fy,fx],overflow_value=0.0) + b = X.reindex([nid,cid,cy,fx],overflow_value=0.0) + c = X.reindex([nid,cid,fy,cx],overflow_value=0.0) + d = X.reindex([nid,cid,cy,cx],overflow_value=0.0) + o = a*dnx*dny+b*dnx*dy+c*dx*dny+d*dx*dy + return o + + +def grid_sampler(X, grid, mode, padding_mode, align_corners): + assert X.dtype==grid.dtype + assert ((X.ndim==4 or X.ndim==5) and X.ndim==grid.ndim) + assert X.shape[0]==grid.shape[0] and grid.shape[-1]==X.ndim-2 + assert X.numel()>0 + if X.ndim == 4: + return grid_sampler_2d(X, grid, mode, padding_mode, align_corners) + else: + return grid_sampler_3d(X, grid, mode, padding_mode, align_corners) + + +def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False): + assert mode in ['bilinear','nearest'] + assert padding_mode in ['zeros','border','reflection'] + return grid_sampler(input, grid, mode, padding_mode, align_corners) + + +class Upsample(Module): + def __init__(self, scale_factor=None, mode='nearest'): + if isinstance(scale_factor, tuple): + self.scale_factor = tuple(float(factor) for factor in scale_factor) + else: + self.scale_factor = float(scale_factor) if scale_factor else None + self.mode = mode + + def execute(self, x): + if self.scale_factor is None: + raise ValueError("scale_factor should be defined") + else: + return upsample(x, + size=( + int(x.shape[2]*self.scale_factor[0]), + int(x.shape[3]*self.scale_factor[1])), + mode=self.mode) + +class UpsamplingBilinear2d(Upsample): + def __init__(self, scale_factor=None): + Upsample.__init__(self, scale_factor, 'bilinear') + +class UpsamplingNearest2d(Upsample): + def __init__(self, scale_factor=None): + Upsample.__init__(self, scale_factor, 'nearest') + +class Sequential(Module): + def __init__(self, *args): + self.layers = collections.OrderedDict() + for mod in args: + if isinstance(mod, collections.OrderedDict): + for k, m in mod.items(): + self.add_module(k, m) + elif isinstance(mod,list): + for m in mod: + self.append(m) + else: + self.append(mod) + def __getitem__(self, idx): + if isinstance(idx, slice) or idx not in self.layers: + return list(self.layers.values())[idx] + + return self.layers[idx] + def __iter__(self): + return self.layers.values().__iter__() + def keys(self): + return self.layers.keys() + def values(self): + return self.layers.values() + def items(self): + return self.layers.items() + def execute(self, x): + for k, layer in self.layers.items(): + x = layer(x) + return x + def dfs(self, parents, k, callback, callback_leave, recurse=True): + n_children = len(self.layers) + ret = callback(parents, k, self, n_children) + if ret == False: + return + parents.append(self) + if recurse: + for k,v in self.layers.items(): + if isinstance(v, Module): + v.dfs(parents, k, callback, callback_leave) + parents.pop() + if callback_leave: + callback_leave(parents, k, self, n_children) + def append(self, mod): + assert callable(mod), f"Module <{type(mod)}> is not callable" + assert not isinstance(mod, type), f"Module is not a type" + self.layers[str(len(self.layers))]=mod + def add_module(self, name, mod): + assert callable(mod), f"Module <{type(mod)}> is not callable" + assert not isinstance(mod, type), f"Module is not a type" + self.layers[str(name)]=mod + + def __len__(self): + return len(self.layers) + + def named_children(self,): + return list(self.layers.items()) + + def __setattr__(self, key, value) -> None: + if isinstance(key, str) and key.isdigit(): + if int(key) is not jittor var" + self.params[len(self.params)] = var + def add_param(self, name, var): + assert isinstance(var, jt.Var), f"argument <{type(var)}> is not jittor var" + self.params[name]=var + def __setitem__(self, name, var): + self.add_param(name, var) + + def __len__(self): + return len(self.params) + +ParameterDict = ParameterList + +def Parameter(data, requires_grad=True): + ''' The `Parameter` interface isn't needed in Jittor, this interface +does nothings and it is just used for compatible. + +A Jittor Var is a Parameter +when it is a member of Module, if you don't want a Jittor +Var menber is treated as a Parameter, just name it startswith +underscore `_`. + ''' + LOG.w(Parameter.__doc__) + data = data.clone() + data.requires_grad = requires_grad + return data + +def backward(v, *args, **kw): + ''' The `backward` variable interface doesn't exist in Jittor. +please use `optimizer.backward(loss)` or +`optimizer.step(loss)` instead. +For example, if your code looks like this:: + + optimizer.zero_grad() + loss.backward() + optimizer.step() + +It can be changed to this:: + + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + +Or more concise:: + + optimizer.step(loss) + +The step function will automatically zero grad and backward. + ''' + LOG.f(backward.__doc__) + +jt.Var.backward = backward + +def unfold(X, kernel_size, dilation=1, padding=0, stride=1): + assert X.ndim == 4 + if not isinstance(kernel_size, tuple): + kernel_size = (kernel_size, kernel_size) + assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive" + if not isinstance(dilation, tuple): + dilation = (dilation, dilation) + assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive" + if not isinstance(padding, tuple): + padding = (padding, padding) + assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative" + if not isinstance(stride, tuple): + stride = (stride, stride) + assert stride[0] > 0 and stride[1] > 0, "stride must be positive" + n, c, h, w = X.shape + shape = X.shape + area = kernel_size[0] * kernel_size[1] + block_nums = [] + for i in range(2, 4): + block_nums.append( + (shape[i] + 2 * padding[i - 2] - dilation[i - 2] * (kernel_size[i - 2] - 1) - 1) // stride[i - 2] + 1) + if padding[0] != 0 or padding[1] != 0: + X = X.reindex([n, c, h + padding[0] * 2, w + padding[1] * 2], + ["i0", "i1", f"i2-{padding[0]}", f"i3-{padding[1]}"]) + output = X.reindex([n, c * area, block_nums[0] * block_nums[1]], ["i0", f"i1/{area}", + f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}", + f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"]) + return output + + +def fold(X,output_size,kernel_size,dilation=1,padding=0,stride=1): + assert X.ndim==3 + assert output_size[0] > 0 and output_size[1] > 0, "output size must be positive." + if not isinstance(kernel_size,tuple): + kernel_size = (kernel_size,kernel_size) + assert kernel_size[0] > 0 and kernel_size[1] > 0, "kernel size must be positive" + if not isinstance(dilation,tuple): + dilation = (dilation,dilation) + assert dilation[0] > 0 and dilation[1] > 0, "dilation must be positive" + if not isinstance(padding,tuple): + padding = (padding,padding) + assert padding[0] >= 0 and padding[1] >= 0, "padding must be non-negative" + if not isinstance(stride,tuple): + stride = (stride,stride) + assert stride[0] > 0 and stride[1] > 0, "stride must be positive" + n,cl,num = X.shape + area = kernel_size[0] * kernel_size[1] + block_nums = [] + for i in range(2,4): + block_nums.append((output_size[i-2]+2*padding[i-2]-dilation[i-2]*(kernel_size[i-2]-1)-1) // stride[i-2]+1) + output = X.reindex_reduce("add",[n,cl // area,output_size[0]+2*padding[0],output_size[1]+2*padding[1]],["i0",f"i1/{area}",f"i2/{block_nums[1]}*{stride[0]}+(i1%{area})/{kernel_size[1]}*{dilation[0]}",f"i2%{block_nums[1]}*{stride[1]}+(i1%{area})%{kernel_size[1]}*{dilation[1]}"]) + return output[:,:,padding[0]:padding[0]+output_size[0],padding[1]:padding[1]+output_size[1]] + +ModuleList = Sequential + + +class LSTMCell(jt.Module): + ''' A long short-term memory (LSTM) cell. + + :param input_size: The number of expected features in the input + :type input_size: int + + :param hidden_size: The number of features in the hidden state + :type hidden_size: int + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + Example: + + >>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) + >>> input = jt.randn(2, 3, 10) # (time_steps, batch, input_size) + >>> hx = jt.randn(3, 20) # (batch, hidden_size) + >>> cx = jt.randn(3, 20) + >>> output = [] + >>> for i in range(input.shape[0]): + hx, cx = rnn(input[i], (hx, cx)) + output.append(hx) + >>> output = jt.stack(output, dim=0) + ''' + def __init__(self, input_size, hidden_size, bias=True): + super().__init__() + + self.hidden_size = hidden_size + self.bias = bias + + k = math.sqrt(1 / hidden_size) + self.weight_ih = init.uniform((4 * hidden_size, input_size), 'float32', -k, k) + self.weight_hh = init.uniform((4 * hidden_size, hidden_size), 'float32', -k, k) + + if bias: + self.bias_ih = init.uniform((4 * hidden_size,), 'float32', -k, k) + self.bias_hh = init.uniform((4 * hidden_size,), 'float32', -k, k) + + def execute(self, input, hx = None): + if hx is None: + zeros = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype) + h, c = zeros, zeros + else: + h, c = hx + + y = matmul_transpose(input, self.weight_ih) + matmul_transpose(h, self.weight_hh) + + if self.bias: + y = y + self.bias_ih + self.bias_hh + + i = y[:, :self.hidden_size].sigmoid() + f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid() + g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh() + o = y[:, 3 * self.hidden_size:].sigmoid() + + c = f * c + i * g + h = o * c.tanh() + + return h, c + + +class RNNCell(jt.Module): + ''' An Elman RNN cell with tanh or ReLU non-linearity. + + :param input_size: The number of expected features in the input + :type input_size: int + + :param hidden_size: The number of features in the hidden state + :type hidden_size: int + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'. + :type nonlinearity: str, optional + + Example: + + >>> rnn = nn.RNNCell(10, 20) + >>> input = jt.randn((6, 3, 10)) + >>> hx = jt.randn((3, 20)) + >>> output = [] + >>> for i in range(6): + hx = rnn(input[i], hx) + output.append(hx) + ''' + def __init__(self, input_size, hidden_size, bias=True, nonlinearity = "tanh"): + super().__init__() + + self.hidden_size = hidden_size + self.bias = bias + self.nonlinearity = nonlinearity + + k = math.sqrt(1 / hidden_size) + self.weight_ih = init.uniform((hidden_size, input_size), 'float32', -k, k) + self.weight_hh = init.uniform((hidden_size, hidden_size), 'float32', -k, k) + + if bias: + self.bias_ih = init.uniform((hidden_size,), 'float32', -k, k) + self.bias_hh = init.uniform((hidden_size,), 'float32', -k, k) + + def execute(self, input, hx = None): + if hx is None: + hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype) + + y = matmul_transpose(input, self.weight_ih)+matmul_transpose(hx, self.weight_hh) + + if self.bias: + y= y + self.bias_ih + self.bias_hh + + if self.nonlinearity == 'tanh': + y = y.tanh() + elif self.nonlinearity == 'relu': + y = relu(y) + else: + raise RuntimeError("Unknown nonlinearity: {}".format(self.nonlinearity)) + + return y + + +class GRUCell(jt.Module): + ''' A gated recurrent unit (GRU) cell. + + :param input_size: The number of expected features in the input + :type input_size: int + + :param hidden_size: The number of features in the hidden state + :type hidden_size: int + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + Example: + + >>> rnn = nn.GRUCell(10, 20) + >>> input = jt.randn((6, 3, 10)) + >>> hx = jt.randn((3, 20)) + >>> output = [] + >>> for i in range(6): + hx = rnn(input[i], hx) + output.append(hx) + ''' + def __init__(self, input_size, hidden_size, bias=True): + super().__init__() + + self.hidden_size = hidden_size + self.bias = bias + + k = math.sqrt(1 / hidden_size) + self.weight_ih = init.uniform((3*hidden_size, input_size), 'float32', -k, k) + self.weight_hh = init.uniform((3*hidden_size, hidden_size), 'float32', -k, k) + + if bias: + self.bias_ih = init.uniform((3*hidden_size,), 'float32', -k, k) + self.bias_hh = init.uniform((3*hidden_size,), 'float32', -k, k) + + def execute(self, input, hx = None): + if hx is None: + hx = jt.zeros((input.shape[0], self.hidden_size), dtype=input.dtype) + + gi = matmul_transpose(input, self.weight_ih) + gh = matmul_transpose(hx, self.weight_hh) + + if self.bias: + gi += self.bias_ih + gh += self.bias_hh + + i_r, i_i, i_n = gi.chunk(3, 1) + h_r, h_i, h_n = gh.chunk(3, 1) + + resetgate = jt.sigmoid(i_r + h_r) + inputgate = jt.sigmoid(i_i + h_i) + newgate = jt.tanh(i_n + resetgate * h_n) + hy = newgate + inputgate * (hx - newgate) + return hy + +class RNNBase(Module): + def __init__(self, mode: str, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, batch_first: bool = False, + dropout: float = 0, bidirectional: bool = False, + proj_size: int = 0, nonlinearity: str = None) -> None: + super().__init__() + + self.mode = mode + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = dropout + self.bidirectional = bidirectional + self.proj_size = proj_size + self.nonlinearity = nonlinearity + + if mode == 'LSTM': + gate_size = 4 * hidden_size + elif mode == 'GRU': + gate_size = 3 * hidden_size + elif mode == 'RNN': + gate_size = hidden_size + else: + raise ValueError("Unrecognized RNN mode: " + mode) + + num_directions = 1 + bidirectional + k = math.sqrt(1 / hidden_size) + + def build_unit(name, in_channels, out_channels=None): + if out_channels is not None: + shape = (in_channels, out_channels) + else: + shape = (in_channels,) + setattr(self, name, init.uniform(shape, 'float32', -k, k)) + if self.bidirectional: + setattr(self, name + '_reverse', init.uniform(shape, 'float32', -k, k)) + + for layer in range(num_layers): + if layer == 0: + build_unit(f'weight_ih_l{layer}', gate_size, input_size) + else: + if proj_size > 0: + build_unit(f'weight_ih_l{layer}', gate_size, num_directions * proj_size) + else: + build_unit(f'weight_ih_l{layer}', gate_size, num_directions * hidden_size) + + if proj_size > 0: + build_unit(f'weight_hh_l{layer}', gate_size, proj_size) + build_unit(f'weight_hr_l{layer}', proj_size, hidden_size) + else: + build_unit(f'weight_hh_l{layer}', gate_size, hidden_size) + + if bias: + build_unit(f'bias_ih_l{layer}', gate_size) + build_unit(f'bias_hh_l{layer}', gate_size) + + def _cudnn_flatten_weights(self, cudnn_mode): + def copy_to_flatten_weight(param_name, offset_idx, num_gates): + def copy_to(param_name, offset_idx, idx): + cur_offset = self._cudnn_weight_offset[offset_idx] + param = getattr(self, param_name) + param = param[self.hidden_size * idx: self.hidden_size * (idx + 1)] + ft_weight[cur_offset:cur_offset + param.numel()] = param.flatten() + + if self.bias: + for idx in range(num_gates): + copy_to('weight' + param_name, offset_idx + idx * 2, idx) + copy_to('bias' + param_name, offset_idx + idx * 2 + 1, idx) + return num_gates * 2 + else: + for idx in range(num_gates): + copy_to('weight' + param_name, offset_idx + idx, idx) + return num_gates + + if jt.flags.use_cuda and jt.cudnn and jt.compiler.is_cuda: + if getattr(self, '_cudnn_weight_size', None) is None: + offset_array = jt.cudnn.cudnn_rnn_weight_offset( + cudnn_mode, + self.input_size, + self.hidden_size, + self.num_layers, + self.proj_size, + self.bias, + self.bidirectional + ) + self._cudnn_weight_size = offset_array[0] + self._cudnn_weight_offset = offset_array[1:] + + num_gates = { + "RNN": 1, "LSTM": 4, "GRU": 3 + }[self.mode] + ft_weight = jt.zeros(self._cudnn_weight_size, dtype=jt.float32) + + cnt = 0 + for layer in range(self.num_layers): + suffix = '' + cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates) + cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates) + if self.bidirectional: + suffix = '_reverse' + cnt += copy_to_flatten_weight(f'_ih_l{layer}' + suffix, cnt, num_gates) + cnt += copy_to_flatten_weight(f'_hh_l{layer}' + suffix, cnt, num_gates) + return ft_weight + else: + raise RuntimeError("Not Cudnn found") + + @abstractmethod + def call_rnn_cell(self, input, hidden, suffix): + pass + + def call_rnn_sequence(self, input, hidden, suffix): + if 'reverse' in suffix: + input = input[::-1] + + output = [] + for s in range(input.shape[0]): + out, hidden = self.call_rnn_cell(input[s], hidden, suffix) + output.append(out) + + if 'reverse' in suffix: + output = output[::-1] + output = jt.stack(output, dim=0) + + return output, hidden + + def _execute_cudnn_rnn(self, input, hx): + cudnn_mode = { + ('RNN', 'tanh'): 'tanh', + ('RNN', 'relu'): 'relu', + ('LSTM', None): 'lstm', + ('GRU', None): 'gru' + }[(self.mode, self.nonlinearity)] + ft_weight = self._cudnn_flatten_weights(cudnn_mode) + + if self.mode == 'LSTM': + ret = jt.cudnn.ops.cudnn_rnn(input, hx[0], hx[1], ft_weight, + cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0, + self.dropout, self.bias, self.bidirectional, self.is_training() + ) + return ret[0], (ret[1], ret[2]) + else: + ret = jt.cudnn.ops.cudnn_rnn(input, hx, ft_weight, + cudnn_mode, self.input_size, self.hidden_size, self.num_layers, 0, + self.dropout, self.bias, self.bidirectional, self.is_training() + ) + return ret[0], ret[1] + + def execute(self, input, hx=None): + if self.batch_first: + input = input.permute(1, 0, 2) + + num_directions = 2 if self.bidirectional else 1 + + if hx is None: + if self.mode in ['RNN', 'GRU']: + hx = jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype) + elif self.mode == 'LSTM': + hx = (jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype), + jt.zeros((num_directions * self.num_layers, input.shape[1], self.hidden_size), dtype=input.dtype)) + + if jt.flags.use_cuda and jt.cudnn and self.proj_size == 0 and jt.compiler.is_cuda: + return self._execute_cudnn_rnn(input, hx) + else: + hidden_n = [] + + for l in range(self.num_layers): + output = [] + + if isinstance(hx, tuple): + hidden = [h[l * num_directions] for h in hx] + else: + hidden = hx[l * num_directions] + + output, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}') + hidden_n.append(_hidden) + + if self.bidirectional: + if isinstance(hx, tuple): + hidden = [h[l * num_directions + 1] for h in hx] + else: + hidden = hx[l * num_directions + 1] + + output_b, _hidden = self.call_rnn_sequence(input, hidden, f'l{l}_reverse') + output = jt.concat([output, output_b], dim=-1) + hidden_n.append(_hidden) + + if self.dropout > 0: + input = dropout(output, p=self.dropout) + else: + input = output + + if isinstance(hx, tuple): + hidden_n = tuple(jt.stack(hn, dim=0) for hn in zip(*hidden_n)) + else: + hidden_n = jt.stack(hidden_n, dim=0) + + return output, hidden_n + + +class RNN(RNNBase): + ''' Applies a multi-layer Elman RNN with tanh ReLU non-linearity to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together to form a stacked RNN, with the second RNN taking in outputs of the first RNN and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh' + :type nonlinearity: str, optional + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each RNN layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional RNN. Default: False + :type bidirectional: bool, optional + + Example: + >>> rnn = nn.RNN(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + ''' + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, + nonlinearity: str = 'tanh', bias: bool = True, batch_first: bool = False, + dropout: float = 0, bidirectional: bool = False) -> None: + super().__init__('RNN', input_size, hidden_size, num_layers=num_layers, + bias=bias, batch_first=batch_first, dropout=dropout, + bidirectional=bidirectional) + + if not nonlinearity in ['tanh', 'relu']: + raise ValueError('Unrecognized nonlinearity: ' + nonlinearity) + self.nonlinearity = nonlinearity + + def call_rnn_cell(self, input, hidden, suffix): + y = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}')) + y = y + matmul_transpose(hidden, getattr(self, f'weight_hh_{suffix}')) + + if self.bias: + y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}') + + if self.nonlinearity == 'tanh': + h = jt.tanh(y) + else: + h = jt.nn.relu(y) + + return h, h + + +class LSTM(RNNBase): + ''' Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two LSTMs together to form a stacked LSTM, with the second LSTM taking in outputs of the first LSTM and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each LSTM layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional LSTM. Default: False + :type bidirectional: bool, optional + + :param proj_size: If > 0, will use LSTM with projections of corresponding size. Default: 0 + :type proj_size: int, optional + + Example: + >>> rnn = nn.LSTM(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> c0 = jt.randn(2, 3, 20) + >>> output, (hn, cn) = rnn(input, (h0, c0)) + ''' + + def __init__(self, input_size, hidden_size, num_layers=1, bias=True, + batch_first=False, dropout=0, bidirectional=False, proj_size=0): + super().__init__('LSTM', input_size, hidden_size, num_layers=num_layers, + bias=bias, batch_first=batch_first, dropout=dropout, + bidirectional=bidirectional, proj_size=proj_size) + + def call_rnn_cell(self, input, hidden, suffix): + h, c = hidden + y = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}')) + y = y + matmul_transpose(h, getattr(self, f'weight_hh_{suffix}')) + + if self.bias: + y = y + getattr(self, f'bias_ih_{suffix}') + getattr(self, f'bias_hh_{suffix}') + + i = y[:, :self.hidden_size].sigmoid() + f = y[:, self.hidden_size : 2 * self.hidden_size].sigmoid() + g = y[:, 2 * self.hidden_size : 3 * self.hidden_size].tanh() + o = y[:, 3 * self.hidden_size:].sigmoid() + c = f * c + i * g + h = o * c.tanh() + + if self.proj_size > 0: + h = matmul_transpose(h, getattr(self, f'weight_hr_{suffix}')) + + return h, (h, c) + + +class GRU(RNNBase): + ''' Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence. + + :param input_size: The number of expected features in the input. + :type input_size: int + + :param hidden_size: The number of features in the hidden state. + :type hidden_size: int + + :param num_layers: Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two GRUs together to form a stacked GRU, with the second GRU taking in outputs of the first GRU and computing the final results. Default: 1 + :type num_layers: int, optinal + + :param bias: If False, then the layer does not use bias weights b_ih and b_hh. Default: True. + :type bias: bool, optional + + :param batch_first: If True, then the input and output tensors are provided as (batch, seq, feature). Default: False + :type bias: bool, optional + + :param dropout: If non-zero, introduces a Dropout layer on the outputs of each GRU layer except the last layer, with dropout probability equal to dropout. Default: 0 + :type dropout: float, optional + + :param bidirectional: If True, becomes a bidirectional GRU. Default: False + :type bidirectional: bool, optional + + Example: + >>> rnn = nn.GRU(10, 20, 2) + >>> input = jt.randn(5, 3, 10) + >>> h0 = jt.randn(2, 3, 20) + >>> output, hn = rnn(input, h0) + ''' + + def __init__(self, input_size: int, hidden_size: int, num_layers: int = 1, + bias: bool = True, batch_first: bool = False, dropout: float = 0, + bidirectional: bool = False) -> None: + super().__init__('GRU', input_size, hidden_size, num_layers=num_layers, + bias=bias, batch_first=batch_first, dropout=dropout, + bidirectional=bidirectional) + + def call_rnn_cell(self, input, hidden, suffix): + ih = matmul_transpose(input, getattr(self, f'weight_ih_{suffix}')) + hh = matmul_transpose(hidden, getattr(self, f'weight_hh_{suffix}')) + + if self.bias: + ih = ih + getattr(self, f'bias_ih_{suffix}') + hh = hh + getattr(self, f'bias_hh_{suffix}') + + hs = self.hidden_size + r = (ih[:, :hs] + hh[:, :hs]).sigmoid() + z = (ih[:, hs: 2 * hs] + hh[:, hs: 2 * hs]).sigmoid() + n = (ih[:, 2 * hs:] + r * hh[:, 2 * hs:]).tanh() + h = (1 - z) * n + z * hidden + + return h, h + +def bilinear(in1, in2, weight, bias): + w = weight.transpose((1,0,2)) + w = w.reshape((w.shape[0], -1)) + x = jt.matmul(in1, w) + x = x.reshape(x.shape[:-1]+[weight.shape[0], weight.shape[2]]) + y = in2.broadcast(x, (-2,)) + z = (x*y).sum(-1) + if bias is not None: + z += bias + return z + + +class Bilinear(Module): + ''' bilinear transformation $out = in1^T W in2 + bias$, Example:: + + m = nn.Bilinear(20, 30, 40) + input1 = jt.randn(128, 20) + input2 = jt.randn(128, 30) + output = m(input1, input2) + print(output.shape) + # [128, 40] + + ''' + def __init__(self, in1_features, in2_features, out_features, bias=True, dtype="float32"): + bound = 1 / math.sqrt(in1_features) + self.weight = jt.init.uniform([out_features, in1_features, in2_features], dtype, -bound, bound) + self.bias = bias + if bias: + self.bias = jt.init.uniform([out_features], dtype, -bound, bound) + + def execute(self, in1, in2): + return bilinear(in1, in2, self.weight, self.bias) + +#TODO: support FFT2D only now. +def _fft2(x, inverse=False): + assert(jt.flags.use_cuda==1) + assert(len(x.shape) == 4) + assert(x.shape[3] == 2) + y = jt.compile_extern.cufft_ops.cufft_fft(x, inverse) + if inverse: + y /= x.shape[1] * x.shape[2] + return y + +class ComplexNumber: + ''' Applys Complex number class. + + It's saved as jt.stack(real, imag, dim=-1) + + You can construct ComplexNumber with real part and imaginary part like ComplexNumber(real, imag) + or real part only with ComplexNumber(real) + or value after jt.stack with ComplexNumber(value, is_concat_value=True) + + add, sub, mul and truediv between ComplexNumber and ComplexNumber, jt.Var, int, float are implemented + + You can use 'shape', 'reshape' etc. as jt.Var + + Example: + >>> real = jt.array([[[1., -2., 3.]]]) + >>> imag = jt.array([[[0., 1., 6.]]]) + >>> a = ComplexNumber(real, imag) + >>> a + a + >>> a / a + >>> a.norm() # sqrt(real^2+imag^2) + >>> a.exp() # e^real(cos(imag)+isin(imag)) + >>> a.conj() # ComplexNumber(real, -imag) + >>> a.fft2() # cuda only now. len(real.shape) equals 3 + >>> a.ifft2() # cuda only now. len(real.shape) equals 3 + + >>> a = jt.array([[1,1],[1,-1]]) + >>> b = jt.array([[0,-1],[1,0]]) + >>> c = ComplexNumber(a,b) / jt.sqrt(3) + >>> c @ c.transpose().conj() + ComplexNumber(real=jt.Var([[0.99999994 0. ] + [0. 0.99999994]], dtype=float32), imag=jt.Var([[0. 0.] + [0. 0.]], dtype=float32)) + ''' + def __init__(self, real: jt.Var, imag: jt.Var=None, is_concat_value=False): + if is_concat_value: + assert real.shape[-1] == 2 + assert imag is None + self.value = real + elif imag is None: + self.value = jt.stack([real, jt.zeros_like(real)], dim=-1) + else: + assert real.shape == imag.shape + assert real.dtype == imag.dtype + self.value = jt.stack([real, imag], dim=-1) + + @property + def real(self): + return self.value[..., 0] + + @property + def imag(self): + return self.value[..., 1] + + @property + def shape(self): + return self.value.shape[:-1] + + def norm(self): + return jt.sqrt(jt.sqr(self.real) + jt.sqr(self.imag)) + + def stop_grad(self): + return ComplexNumber(self.value.stop_grad(), is_concat_value=True) + + def start_grad(self): + return ComplexNumber(self.value.start_grad(), is_concat_value=True) + + def detach(self): + return ComplexNumber(self.value.detach(), is_concat_value=True) + + def unsqueeze(self, dim=0): + return ComplexNumber(jt.unsqueeze(self.real, dim=dim), jt.unsqueeze(self.imag, dim=dim)) + + def squeeze(self, dim=0): + return ComplexNumber(jt.squeeze(self.real, dim=dim), jt.squeeze(self.imag, dim=dim)) + + def reshape(self, shape): + return ComplexNumber(jt.reshape(self.real, shape), jt.reshape(self.imag, shape)) + + def permute(self, *axes): + return ComplexNumber(jt.permute(self.real, *axes), jt.permute(self.imag, *axes)) + + def transpose(self, *axes): + return ComplexNumber(jt.transpose(self.real, *axes), jt.transpose(self.imag, *axes)) + + def broadcast(self, shape, dims): + return ComplexNumber(self.real.broadcast(shape, dims), self.imag.broadcast(shape, dims)) + + def sum(self, dims, keepdims: bool=False): + return ComplexNumber(self.real.sum(dims, keepdims=keepdims), self.imag.sum(dims, keepdims=keepdims)) + + def exp(self): + er = jt.exp(self.real) + return ComplexNumber(er * jt.cos(self.imag), er * jt.sin(self.imag)) + + def conj(self): + return ComplexNumber(self.real, -self.imag) + + def __add__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(self.real + other.real, self.imag + other.imag) + elif isinstance(other, (int, float)): + return ComplexNumber(self.real + other, self.imag) + else: + raise NotImplementedError + + def __radd__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(other.real + self.real, other.imag + self.imag) + elif isinstance(other, (int, float)): + return ComplexNumber(other + self.real, self.imag) + else: + raise NotImplementedError + + def __sub__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(self.real - other.real, self.imag - other.imag) + elif isinstance(other, (int, float)): + return ComplexNumber(self.real - other, self.imag) + else: + raise NotImplementedError + + def __rsub__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(other.real - self.real, other.imag - self.imag) + elif isinstance(other, (int, float)): + return ComplexNumber(other - self.real, self.imag) + else: + raise NotImplementedError + + def __mul__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(self.real * other.real - self.imag * other.imag, + self.real * other.imag + self.imag * other.real) + elif isinstance(other, (int, float)): + return ComplexNumber(self.value * other, is_concat_value=True) + else: + raise NotImplementedError + + def __rmul__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(other.real * self.real - other.imag * self.imag, + other.imag * self.real + other.real * self.imag) + elif isinstance(other, (int, float)): + return ComplexNumber(other * self.value, is_concat_value=True) + else: + raise NotImplementedError + + def __truediv__(self, other): + if isinstance(other, ComplexNumber): + norm = jt.sqr(other.real) + jt.sqr(other.imag) + return ComplexNumber((self.real * other.real + self.imag * other.imag) / norm, + (self.imag * other.real - self.real * other.imag) / norm) + elif isinstance(other, (int, float)): + return ComplexNumber(self.value / other, is_concat_value=True) + else: + raise NotImplementedError + + def __rtruediv__(self, other): + norm = jt.sqr(self.real) + jt.sqr(self.imag) + if isinstance(other, ComplexNumber): + return ComplexNumber((other.real * self.real + other.imag * self.imag) / norm, + (other.imag * self.real - other.real * self.imag) / norm) + elif isinstance(other, (int, float)): + return ComplexNumber(other * self.real / norm, - other * self.imag / norm) + else: + raise NotImplementedError + + def __matmul__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(self.real @ other.real - self.imag @ other.imag, + self.real @ other.imag + self.imag @ other.real) + else: + raise NotImplementedError + + def __imatmul__(self, other): + if isinstance(other, ComplexNumber): + return ComplexNumber(other.real @ self.real - other.imag @ self.imag, + other.imag @ self.real + other.real @ self.imag) + else: + raise NotImplementedError + + def __repr__(self) -> str: + return f'ComplexNumber(real={self.real.__repr__()}, imag={self.imag.__repr__()})' + + def fft2(self): + return ComplexNumber(_fft2(self.value, inverse=False), is_concat_value=True) + + def ifft2(self): + return ComplexNumber(_fft2(self.value, inverse=True), is_concat_value=True) + + +def one_hot(x: jt.Var, num_classes: int=-1) -> jt.Var: + ''' Returns the one_hot encoding of inputs. + + :param x: class values of any shape + :type x: jt.Var with bool or integer dtype + + :param num_classes: Total number of classes. If set to -1, the number of classes will be inferred as one greater than the largest class value in the input tensor. + :type num_classes: int, optional + + :return: a Var with one more dimension with 1 values at the index + of last dimension indicated by the input, and 0 everywhere else. + :rtype: jt.Var + + .. note:: + if the values in x are greater than num_class or less than 0, + the returned one_hot will be all zeros. + + Example: + >>> jt.nn.one_hot(jt.arange(5) % 3) + jt.Var([[1 0 0] + [0 1 0] + [0 0 1] + [1 0 0] + [0 1 0]], dtype=int32) + >>> jt.nn.one_hot(jt.arange(5) % 3, num_classes=5) + jt.Var([[1 0 0 0 0] + [0 1 0 0 0] + [0 0 1 0 0] + [1 0 0 0 0] + [0 1 0 0 0]], dtype=int32) + >>> jt.nn.one_hot(jt.arange(6).reshape(3,2) % 3) + jt.Var([[[1 0 0] + [0 1 0]] + + [[0 0 1] + [1 0 0]] + + [[0 1 0] + [0 0 1]]], dtype=int32) + ''' + + assert x.dtype in [jt.bool, jt.int8, jt.int16, jt.int32, jt.int64, jt.uint8, jt.uint16, jt.uint32, jt.uint64] + if num_classes == -1: + num_classes = x.max().item() + 1 + + N = len(x.shape) + indices = ["i"+str(i) for i in range(N)] + y = jt.ones_like(x).reindex( + x.shape + [num_classes], + indices, + extras=[x], + overflow_conditions=[f"i{N} != @e0({','.join(indices)})"], + overflow_value=0) + return y + + +class KLDivLoss(Module): + ''' Computes the Kullback-Leibler divergence loss. + ''' + + def __init__(self, reduction: str = 'mean', log_target: bool = False): + ''' + :param reduction: Specifies the reduction to apply to the output. Can be 'mean', 'sum', 'batchmean', or 'none'. Defaults to 'mean'. + :type reduction: str, optional + :param log_target: Specifies whether target is the log space. Defaults to False. + :type log_target: bool, optional + ''' + self.reduction = reduction + self.log_target = log_target + + def execute(self, input: jt.Var, target: jt.Var) -> jt.Var: + if not self.log_target: + loss_pointwise = target * (target.log() - input) + else: + loss_pointwise = target.exp() * (target - input) + + if self.reduction == "mean": + loss = loss_pointwise.mean() + elif self.reduction == "batchmean": + loss = loss_pointwise.sum() / input.size(0) + elif self.reduction == "sum": + loss = loss_pointwise.sum() + else: + loss = loss_pointwise + return loss + +class Mish(Module): + def __init__(self, inplace=False): + ''' +Applies the Mish function, element-wise. +reference: Mish - A Self Regularized Non-Monotonic Neural Activation Function. + ''' + pass + def execute(self, x): + return x * jt.tanh(jt.softplus(x)) + +def mish(x, inplace=False): + return x * jt.tanh(jt.nn.softplus(x)) + +def skip_init(module_cls, *args, **kw): + return module_cls(*args, **kw) diff --git a/python/jittor/notebook/60分钟快速入门Jittor/README.md b/python/jittor/notebook/60分钟快速入门Jittor/README.md new file mode 100644 index 00000000..9f6f5f9d --- /dev/null +++ b/python/jittor/notebook/60分钟快速入门Jittor/README.md @@ -0,0 +1,11 @@ +# 计图零基础入门教程(60分钟) + +``` +git clone https://github.com/Jittor/LearnJittorBasicIn60Min.git +cd LearnJittorBasicIn60Min +jupyter notebook +``` + +在线浏览地址: + +特别感谢教程作者:llt diff --git a/python/jittor/notebook/60分钟快速入门Jittor/mnist.png b/python/jittor/notebook/60分钟快速入门Jittor/mnist.png new file mode 100644 index 00000000..7fcc5659 Binary files /dev/null and b/python/jittor/notebook/60分钟快速入门Jittor/mnist.png differ diff --git a/python/jittor/notebook/ConditionGAN.src.md b/python/jittor/notebook/ConditionGAN.src.md new file mode 100644 index 00000000..e8b1300b --- /dev/null +++ b/python/jittor/notebook/ConditionGAN.src.md @@ -0,0 +1,170 @@ +# 使用Jittor实现Conditional GAN + +Generative Adversarial Nets(GAN)[1]提出了一种新的方法来训练生成模型。然而,GAN对于要生成的图片缺少控制。Conditional GAN(CGAN)[2]通过添加显式的条件或标签,来控制生成的图像。本教程讲解了CGAN的网络结构、损失函数设计、使用CGAN生成一串数字、从头训练CGAN、以及在mnist手写数字数据集上的训练结果。 + +## CGAN网络架构 + +通过在生成器generator和判别器discriminator中添加相同的额外信息y,GAN就可以扩展为一个conditional模型。y可以是任何形式的辅助信息,例如类别标签或者其他形式的数据。我们可以通过将y作为额外输入层,添加到生成器和判别器来完成条件控制。 + +在生成器generator中,除了y之外,还额外输入随机一维噪声z,为结果生成提供更多灵活性。 + +![](https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-13-22-47-cgan/network.jpg) + +## 损失函数 + +### GAN的损失函数 + +在解释CGAN的损失函数之前,首先介绍GAN的损失函数。下面是GAN的损失函数设计。 + +![](https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-13-22-47-cgan/gan-loss.png) + +对于判别器D,我们要训练最大化这个loss。如果D的输入是来自真实样本的数据x,则D的输出D(x)要尽可能地大,log(D(x))也会尽可能大。如果D的输入是来自G生成的假图片G(z),则D的输出D(G(z))应尽可能地小,从而log(1-D(G(z))会尽可能地大。这样可以达到max D的目的。 + +对于生成器G,我们要训练最小化这个loss。对于G生成的假图片G(z),我们希望尽可能地骗过D,让它觉得我们生成的图片就是真的图片,这样就达到了G“以假乱真”的目的。那么D的输出D(G(z))应尽可能地大,从而log(1-D(G(z))会尽可能地小。这样可以达到min G的目的。 + +D和G以这样的方式联合训练,最终达到G的生成能力越来越强,D的判别能力越来越强的目的。 + +### CGAN的损失函数 + +下面是CGAN的损失函数设计。 + +![](https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-13-22-47-cgan/loss.png) + + +很明显,CGAN的loss跟GAN的loss的区别就是多了条件限定y。D(x/y)代表在条件y下,x为真的概率。D(G(z/y))表示在条件y下,G生成的图片被D判别为真的概率。 + +## Jittor代码数字生成 + +首先,我们导入需要的包,并且设置好所需的超参数: + +```python +import jittor as jt +from jittor import nn +import numpy as np +import pylab as pl + +%matplotlib inline + +# 隐空间向量长度 +latent_dim = 100 +# 类别数量 +n_classes = 10 +# 图片大小 +img_size = 32 +# 图片通道数量 +channels = 1 +# 图片张量的形状 +img_shape = (channels, img_size, img_size) +``` + +第一步,定义生成器G。该生成器输入两个一维向量y和noise,生成一张图片。 + +```python +class Generator(nn.Module): + def __init__(self): + super(Generator, self).__init__() + self.label_emb = nn.Embedding(n_classes, n_classes) + + def block(in_feat, out_feat, normalize=True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2)) + return layers + self.model = nn.Sequential( + *block((latent_dim + n_classes), 128, normalize=False), + *block(128, 256), + *block(256, 512), + *block(512, 1024), + nn.Linear(1024, int(np.prod(img_shape))), + nn.Tanh()) + + def execute(self, noise, labels): + gen_input = jt.concat((self.label_emb(labels), noise), dim=1) + img = self.model(gen_input) + img = img.view((img.shape[0], *img_shape)) + return img +``` + +第二步,定义判别器D。D输入一张图片和对应的y,输出是真图片的概率。 + +```python +class Discriminator(nn.Module): + def __init__(self): + super(Discriminator, self).__init__() + self.label_embedding = nn.Embedding(n_classes, n_classes) + self.model = nn.Sequential( + nn.Linear((n_classes + int(np.prod(img_shape))), 512), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + nn.Dropout(0.4), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + nn.Dropout(0.4), + nn.LeakyReLU(0.2), + nn.Linear(512, 1)) + + def execute(self, img, labels): + d_in = jt.concat((img.view((img.shape[0], (- 1))), self.label_embedding(labels)), dim=1) + validity = self.model(d_in) + return validity +``` + +第三步,使用CGAN生成一串数字。 + +代码如下。您可以使用您训练好的模型来生成图片,也可以使用我们提供的预训练参数: 模型预训练参数下载:。 + +```python +# 下载提供的预训练参数 +!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/generator_last.pkl +!wget https://cg.cs.tsinghua.edu.cn/jittor/assets/build/discriminator_last.pkl +``` + +生成自定义的数字: + +```python +# 定义模型 +generator = Generator() +discriminator = Discriminator() +generator.eval() +discriminator.eval() + +# 加载参数 +generator.load('./generator_last.pkl') +discriminator.load('./discriminator_last.pkl') + +# 定义一串数字 +number = "201962517" +n_row = len(number) +z = jt.array(np.random.normal(0, 1, (n_row, latent_dim))).float32().stop_grad() +labels = jt.array(np.array([int(number[num]) for num in range(n_row)])).float32().stop_grad() +gen_imgs = generator(z,labels) + +pl.imshow(gen_imgs.data.transpose((1,2,0,3))[0].reshape((gen_imgs.shape[2], -1))) +``` + +## 从头训练Condition GAN + +从头训练 Condition GAN 的完整代码在, 让我们把他下载下来看看! + +```python +!wget https://raw.githubusercontent.com/Jittor/gan-jittor/master/models/cgan/cgan.py +!python3.7 ./cgan.py --help + +# 选择合适的batch size,运行试试 +# 运行命令: !python3.7 ./cgan.py --batch_size 64 +``` + +## MNIST数据集训练结果 + +下面展示了Jittor版CGAN在MNIST数据集的训练结果。下面分别是训练0 epoch和90 epoches的结果。 + +![](https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-13-22-47-cgan/0-epoch.png) + +![](https://cg.cs.tsinghua.edu.cn/jittor/images/tutorial/2020-5-13-22-47-cgan/90-epoch.png) + +## 参考文献 + +1. Goodfellow, Ian, et al. “Generative adversarial nets.” Advances in neural information processing systems. 2014. + +2. Mirza, Mehdi, and Simon Osindero. “Conditional generative adversarial nets.” arXiv preprint arXiv:1411.1784 (2014). \ No newline at end of file diff --git a/python/jittor/notebook/LSGAN.src.md b/python/jittor/notebook/LSGAN.src.md new file mode 100644 index 00000000..47b91f5f --- /dev/null +++ b/python/jittor/notebook/LSGAN.src.md @@ -0,0 +1,291 @@ +# 图像生成之LSGAN + +生成对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。GAN模型由生成器(Generator)和判别器(Discriminator)两个部分组成。在训练过程中,生成器的目标就是尽量生成真实的图片去欺骗判别器。而判别器的目标就是尽量把生成器生成的图片和真实的图片分别开来。这样,生成器和判别器构成了一个动态的“博弈过程”。许多相关的研究工作表明GAN能够产生效果非常真实的生成效果。 + +本教程使用Jittor框架实现了一种经典GAN模型LSGAN 。LSGAN将GAN的目标函数由交叉熵损失替换成最小二乘损失,以此拒绝了标准GAN生成的图片质量不高以及训练过程不稳定这两个缺陷。本教程通过LSGAN的实现介绍了Jittor数据加载、模型定义、模型训练的使用方法。 + +LSGAN论文: + +## 1.数据集准备 + +本教程使用两种数据集进行LSGAN的训练,分别是Jittor自带的数据集MNIST,和用户构建的数据集CelebA。 + +如果要使用CelebA数据集进行训练,可以通过以下链接下载CelebA数据集。 + +- CelebA 数据集: + +将下载的训练数据和验证数据分别存储在`./data/celebA_train/imgs/`和`./data/celebA_eval/imgs/`中 + +最终数据集的文件组织如下。 + +``` +# 文件组织 +根目录 +|----data + |----celebA_train + | |----imgs + |----celebA_eval + | |----imgs +``` + +## 2.模型定义 + +本教程使用LSGAN进行图像生成,其网络结构由生成器和别器。生成器网络输入一个`1024`维的向量,生成分辨率为`112*112`的图像;判别器网络输入`112*112`的图像,输出一个数字表示输入图像为真实图像的可信程度。 + +下面分别定义生成器和判别器 + +```python +import jittor as jt +from jittor import nn, Module, init +from jittor.dataset.mnist import MNIST +from jittor.dataset.dataset import ImageFolder +import jittor.transform as transform +import os +import numpy as np +import matplotlib.pyplot as plt + +# 如果有CUDA,则通过use_cuda设置在GPU上进行训练 +if jt.has_cuda: + jt.flags.use_cuda = 1 + +class generator(Module): + def __init__(self, dim=3): + super(generator, self).__init__() + self.fc = nn.Linear(1024, 7*7*256) + self.fc_bn = nn.BatchNorm(256) + self.deconv1 = nn.ConvTranspose(256, 256, 3, 2, 1, 1) + self.deconv1_bn = nn.BatchNorm(256) + self.deconv2 = nn.ConvTranspose(256, 256, 3, 1, 1) + self.deconv2_bn = nn.BatchNorm(256) + self.deconv3 = nn.ConvTranspose(256, 256, 3, 2, 1, 1) + self.deconv3_bn = nn.BatchNorm(256) + self.deconv4 = nn.ConvTranspose(256, 256, 3, 1, 1) + self.deconv4_bn = nn.BatchNorm(256) + self.deconv5 = nn.ConvTranspose(256, 128, 3, 2, 1, 1) + self.deconv5_bn = nn.BatchNorm(128) + self.deconv6 = nn.ConvTranspose(128, 64, 3, 2, 1, 1) + self.deconv6_bn = nn.BatchNorm(64) + self.deconv7 = nn.ConvTranspose(64 , dim, 3, 1, 1) + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + + def execute(self, input): + x = self.fc(input).reshape((input.shape[0], 256, 7, 7)) + x = self.relu(self.fc_bn(x)) + x = self.relu(self.deconv1_bn(self.deconv1(x))) + x = self.relu(self.deconv2_bn(self.deconv2(x))) + x = self.relu(self.deconv3_bn(self.deconv3(x))) + x = self.relu(self.deconv4_bn(self.deconv4(x))) + x = self.relu(self.deconv5_bn(self.deconv5(x))) + x = self.relu(self.deconv6_bn(self.deconv6(x))) + x = self.tanh(self.deconv7(x)) + return x + + +class discriminator(nn.Module): + def __init__(self, dim=3): + super(discriminator, self).__init__() + self.conv1 = nn.Conv(dim, 64, 5, 2, 2) + self.conv2 = nn.Conv(64, 128, 5, 2, 2) + self.conv2_bn = nn.BatchNorm(128) + self.conv3 = nn.Conv(128, 256, 5, 2, 2) + self.conv3_bn = nn.BatchNorm(256) + self.conv4 = nn.Conv(256, 512, 5, 2, 2) + self.conv4_bn = nn.BatchNorm(512) + self.fc = nn.Linear(512*7*7, 1) + self.leaky_relu = nn.Leaky_relu() + + def execute(self, input): + x = self.leaky_relu(self.conv1(input), 0.2) + x = self.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2) + x = self.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2) + x = self.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2) + x = x.reshape((x.shape[0], 512*7*7)) + x = self.fc(x) + return x +``` + +损失函数采用最小二乘损失函数。具体实现如下,`x`为生成器的输出值,`b`表示该图像是否希望被判别为真。 + +```python +def ls_loss(x, b): + mini_batch = x.shape[0] + y_real_ = jt.ones((mini_batch,)) + y_fake_ = jt.zeros((mini_batch,)) + if b: + return (x-y_real_).sqr().mean() + else: + return (x-y_fake_).sqr().mean() +``` + +## 3.模型训练 + +参数设定如下: + +```python +# 使用 MNIST 或者 CelebA数据集进行训练 +task = "MNIST" +# task = "CelebA" +# 批大小 +batch_size = 128 +# 学习率 +lr = 0.0002 +# 训练轮数 +train_epoch = 20 if task=="MNIST" else 50 +# 训练图像标准大小 +img_size = 112 +# Adam优化器参数 +betas = (0.5,0.999) +# 数据集图像通道数,MNIST为1,CelebA为3 +dim = 1 if task=="MNIST" else 3 +# 结果图片存储路径 +save_path = "./results_img" +``` + +分别声明生成器和判别器,并使用Adam作为优化器。 + +```python +G = generator (dim) +D = discriminator (dim) +G_optim = nn.Adam(G.parameters(), lr, betas=betas) +D_optim = nn.Adam(D.parameters(), lr, betas=betas) +``` + +jittor自带有MNIST数据集。使用`jittor.transform`可以进行数据归一化及数据增强,这里本教程通过`transform`将图片归一化到指定区间,并resize到标准大小`112*112`。。通过`set_attrs`函数可以修改数据集的相关参数,如`batch_size`、`shuffle`及`transform`等。 + +如果使用自己构建CelebA数据集进行训练,可以通过通用数据加载器`jittor.dataset.dataset.ImageFolder`,输入数据集路径即可构建用户数据集。 + +构建数据集代码如下: + +```python +if task=="MNIST": + transform = transform.Compose([ + transform.Resize(size=img_size), + transform.Gray(), + transform.ImageNormalize(mean=[0.5], std=[0.5]), + ]) + train_loader = MNIST(train=True, transform=transform).set_attrs(batch_size=batch_size, shuffle=True) + eval_loader = MNIST(train=False, transform = transform).set_attrs(batch_size=batch_size, shuffle=True) +elif task=="CelebA": + transform = transform.Compose([ + transform.Resize(size=img_size), + transform.ImageNormalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + ]) + train_dir = './data/celebA_train' + train_loader = ImageFolder(train_dir).set_attrs(transform=transform, batch_size=batch_size, shuffle=True) + eval_dir = './data/celebA_eval' + eval_loader = ImageFolder(eval_dir).set_attrs(transform=transform, batch_size=batch_size, shuffle=True) +``` + +训练和验证代码如下: + +```python +def train(epoch): + for batch_idx, (x_, target) in enumerate(train_loader): + mini_batch = x_.shape[0] + # train discriminator + D_result = D(x_) + D_real_loss = ls_loss(D_result, True) + z_ = jt.init.gauss((mini_batch, 1024), 'float') + G_result = G(z_) + D_result_ = D(G_result) + D_fake_loss = ls_loss(D_result_, False) + D_train_loss = D_real_loss + D_fake_loss + D_train_loss.sync() + D_optim.step(D_train_loss) + + # train generator + z_ = jt.init.gauss((mini_batch, 1024), 'float') + G_result = G(z_) + D_result = D(G_result) + G_train_loss = ls_loss(D_result, True) + G_train_loss.sync() + G_optim.step(G_train_loss) + if (batch_idx%100==0): + print("train: batch_idx",batch_idx,"epoch",epoch) + print(' D training loss =', D_train_loss.data.mean()) + print(' G training loss =', G_train_loss.data.mean()) + +def validate(epoch): + D_losses = [] + G_losses = [] + G.eval() + D.eval() + for batch_idx, (x_, target) in enumerate(eval_loader): + mini_batch = x_.shape[0] + + # calculation discriminator loss + D_result = D(x_) + D_real_loss = ls_loss(D_result, True) + z_ = jt.init.gauss((mini_batch, 1024), 'float') + G_result = G(z_) + D_result_ = D(G_result) + D_fake_loss = ls_loss(D_result_, False) + D_train_loss = D_real_loss + D_fake_loss + D_losses.append(D_train_loss.data.mean()) + + # calculation generator loss + z_ = jt.init.gauss((mini_batch, 1024), 'float') + G_result = G(z_) + D_result = D(G_result) + G_train_loss = ls_loss(D_result, True) + G_losses.append(G_train_loss.data.mean()) + G.train() + D.train() + print("validate: epoch",epoch) + print(' D validate loss =', np.array(D_losses).mean()) + print(' G validate loss =', np.array(G_losses).mean()) +``` + +使用每个epoch的生成器通过固定向量生成图片,将图片显示并存储在`./results_img/`中 + +```python +if not os.path.exists(save_path): + os.mkdir(save_path) +fixed_z_ = jt.init.gauss((5 * 5, 1024), 'float') +def save_result(num_epoch, G , path = 'result.png'): + """Use the current generator to generate 5*5 pictures and store them. + + Args: + num_epoch(int): current epoch + G(generator): current generator + path(string): storage path of result image + """ + + z_ = fixed_z_ + G.eval() + test_images = G(z_) + G.train() + size_figure_grid = 5 + fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5)) + for i in range(size_figure_grid): + for j in range(size_figure_grid): + ax[i, j].get_xaxis().set_visible(False) + ax[i, j].get_yaxis().set_visible(False) + + for k in range(5*5): + i = k // 5 + j = k % 5 + ax[i, j].cla() + if task=="MNIST": + ax[i, j].imshow((test_images[k, 0].data+1)/2, cmap='gray') + else: + ax[i, j].imshow((test_images[k].data.transpose(1, 2, 0)+1)/2) + + label = 'Epoch {0}'.format(num_epoch) + fig.text(0.5, 0.04, label, ha='center') + plt.savefig(path) + plt.show() +``` + +现在,让我们训练一番试试! + +```python +for epoch in range(train_epoch): + print ('number of epochs', epoch) + train(epoch) + validate(epoch) + result_img_path = './results_img/' + task + str(epoch) + '.png' + save_result(epoch, G, path=result_img_path) +``` + diff --git a/python/jittor/notebook/__main__.py b/python/jittor/notebook/__main__.py new file mode 100644 index 00000000..0196b413 --- /dev/null +++ b/python/jittor/notebook/__main__.py @@ -0,0 +1,9 @@ +from .md_to_ipynb import dirname, notebook_dir +import os +import sys +import shutil +from distutils.dir_util import copy_tree + +copy_tree(dirname, notebook_dir) +os.chdir(notebook_dir) +os.system(f"{sys.executable} -m jupyter notebook") \ No newline at end of file diff --git a/python/jittor/notebook/basics.src.md b/python/jittor/notebook/basics.src.md new file mode 100644 index 00000000..4036cb31 --- /dev/null +++ b/python/jittor/notebook/basics.src.md @@ -0,0 +1,57 @@ +# Basics: Op, Var + +# 基本概念:Op, Var + +To train your model with jittor, there are only two main concept you need to know: + +要使用jittor训练模型,您需要了解两个主要概念: + +* Var: basic data type of jittor +* Var:Jittor的基本数据类型 +* Operations: Jittor'op is simular with numpy +* Operations:Jittor的算子与numpy类似 + +## Var +First, let's get started with Var. Var is the basic data type of jittor. Computation process in Jittor is asynchronous for optimization. If you want to access the data, `Var.data` can be used for synchronous data accessing. + +首先,让我们开始使用Var。Var是jittor的基本数据类型,为了运算更加高效Jittor中的计算过程是异步的。 如果要访问数据,可以使用`Var.data`进行同步数据访问。 + +``` +import jittor as jt +a = jt.float32([1,2,3]) +print (a) +print (a.data) +# Output: float32[3,] +# Output: [ 1. 2. 3.] +``` +## Op +Jittor'op is simular with numpy. Let's try some operations. We create Var `a` and `b` via operation `jt.float32`, and add them. Printing those variables shows they have the same shape and dtype. + + Jittor的算子与numpy类似。 让我们尝试一些操作, 我们通过操作jt.float32创建Var `a`和`b`,并将它们相加。 输出这些变量相关信息,可以看出它们具有相同的形状和类型。 + +``` +import jittor as jt +a = jt.float32([1,2,3]) +b = jt.float32([4,5,6]) +c = a+b +print(a,b,c) +``` + +Beside that, All the operators we used `jt.xxx(Var, ...)` have alias `Var.xxx(...)`. For example: + +除此之外,我们使用的所有算子`jt.xxx(Var,...)`都具有别名`Var.xxx(...)`。 例如: + +``` +c.max() # alias of jt.max(a) +c.add(a) # alias of jt.add(c, a) +c.min(keepdims=True) # alias of jt.min(c, keepdims=True) +``` + +if you want to know all the operation which Jittor supports. try `help(jt.ops)`. All the operation you found in `jt.ops.xxx`, can be used via alias `jt.xxx`. + +如果您想知道Jittor支持的所有操作,可以运行`help(jt.ops)`。 您在`jt.ops.xxx`中找到的所有操作都可以通过别名`jt.xxx`。 + +``` +help(jt.ops) +``` + diff --git a/python/jittor/notebook/custom_op.src.md b/python/jittor/notebook/custom_op.src.md new file mode 100644 index 00000000..662bc119 --- /dev/null +++ b/python/jittor/notebook/custom_op.src.md @@ -0,0 +1,109 @@ +# Custom Op: write your operator with C++ and CUDA and JIT compile it + +# 自定义算子:使用C ++和CUDA编写您的算子,并其进行即时编译 + +> NOTE: This tutorial is still working in progress + +In this tutorial, we will show: + +1. how to write your operator with C++ and CUDA and JIT compile it +2. execute your custom operation + +If you want to implement a very simple op with few lines of code, please use code op, please see `help(jt.code)`. +custom_op is used for implement a complicated op. The capabilities of custom_op and built-in operations are exactly the same. + +> 注意:本教程仍在持续更新中 + +在本教程中,我们将展示: + +1. 如何用C ++和CUDA编写您的算子并对其进行即时编译 +2. 运行您的自定义算子 + +如果您想用几行代码来实现一个非常简单的算子,请使用code运算,请参阅`help(jt.code)`. +custom_op用于实现复杂的算子。 custom_op和内置运算的功能完全相同。 + +```python +import jittor as jt + +header =""" +#pragma once +#include "op.h" + +namespace jittor { + +struct CustomOp : Op { + Var* output; + CustomOp(NanoVector shape, NanoString dtype=ns_float32); + + const char* name() const override { return "custom"; } + DECLARE_jit_run; +}; + +} // jittor +""" + +src = """ +#include "var.h" +#include "custom_op.h" + +namespace jittor { +#ifndef JIT +CustomOp::CustomOp(NanoVector shape, NanoString dtype) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 1); + output = create_output(shape, dtype); +} + +void CustomOp::jit_prepare(JK& jk) { + add_jit_define(jk, "T", output->dtype()); +} + +#else // JIT +#ifdef JIT_cpu +void CustomOp::jit_run() { + index_t num = output->num; + auto* __restrict__ x = output->ptr(); + for (index_t i=0; inum; + auto* __restrict__ x = output->ptr(); + int blockSize = 256; + int numBlocks = (num + blockSize - 1) / blockSize; + kernel<<>>(num, x); +} +#endif // JIT_cpu +#endif // JIT + +} // jittor +""" + +my_op = jt.compile_custom_op(header, src, "custom", warp=False) +``` + +Let's check the result of this op. + +让我们查看一下这个运算的结果。 + +```python +# run cpu version +jt.flags.use_cuda = 0 +a = my_op([3,4,5], 'float').fetch_sync() +assert (a.flatten() == range(3*4*5)).all() + +if jt.compiler.has_cuda: + # run cuda version + jt.flags.use_cuda = 1 + a = my_op([3,4,5], 'float').fetch_sync() + assert (-a.flatten() == range(3*4*5)).all() +``` \ No newline at end of file diff --git a/python/jittor/notebook/example.src.md b/python/jittor/notebook/example.src.md new file mode 100644 index 00000000..bef27c40 --- /dev/null +++ b/python/jittor/notebook/example.src.md @@ -0,0 +1,68 @@ +# Example: Model definition and training + +# 示例:模型定义与训练 + +The following example shows how to model a two-layer neural network step by step and train from scratch In a few lines of Python code. + +以下示例展示了如何逐步搭建两层神经网络模型,并通过几行Python代码从头开始进行模型训练。 + +``` +import jittor as jt +import numpy as np +from jittor import nn, Module, init +``` + +The following code defines our model, which is a two-layer neural network. The size of hidden layer is 10. and the activation function is relu. + +以下代码定义了我们的模型,该模型是一个两层神经网络。 隐藏层的大小为10,激活函数为relu。 + +``` +### model define + +class Model(Module): + def __init__(self): + self.layer1 = nn.Linear(1, 10) + self.relu = nn.ReLU() + self.layer2 = nn.Linear(10, 1) + def execute (self,x) : + x = self.layer1(x) + x = self.relu(x) + x = self.layer2(x) + return x +``` + +At last, this model is trained from scratch. A simple gradient descent is used, and the loss function is L2 distance. The training process is asynchronous for efficiency. jittor calculates the gradients and applies graph- and operator-level optimizations via **unify IR graph** and **jit analyzer**. +In this example, multiple optimizations can be used, including: **operator fusion**, the activation function and loss function can be fused into the first and second linear layers; Three meta-operators in matrix multiplication could also be fused. **Parallelism**, it can improve performance of compute-intensive operations on modern multi-core CPUs and GPUs. The operator fusion is a graph-level optimization, and parallelism can be achieved in both graph-level and operator-level. + +最后,从头开始训练该模型。 优化器使用简单的梯度下降,损失函数为L2距离。 为提高效率训练过程是异步的。 jittor通过**统一计算图**和**即时分析器**计算梯度,并进行计算图级和算子级的优化。 + +在该示例中,Jittor使用了多个优化,包括:**算子融合**,激活函数和损失函数可以融合到第一和第二全连接层中; 矩阵乘法中的三元算子也可以融合。 **并行化**,它可以提高现代多核CPU和GPU上计算密集型运算的性能。 算子融合是一种计算图级优化,而并行化则同时作用于图形级和算子级的优化。 + +``` +np.random.seed(0) +jt.set_seed(3) +n = 1000 +batch_size = 50 +base_lr = 0.05 +# we need to stop grad of global value to prevent memory leak +lr = jt.float32(base_lr).name("lr").stop_grad() + +def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + +model = Model() +learning_rate = 0.1 +optim = nn.SGD (model.parameters(), learning_rate) + +for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x) + loss = jt.sqr(pred_y - y) + loss_mean = loss.mean() + optim.step (loss_mean) + print(f"step {i}, loss = {loss_mean.data.sum()}") + +assert loss_mean.data < 0.005 +``` diff --git a/python/jittor/notebook/figs/mop.svg b/python/jittor/notebook/figs/mop.svg new file mode 100644 index 00000000..06e73304 --- /dev/null +++ b/python/jittor/notebook/figs/mop.svg @@ -0,0 +1,936 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + image/svg+xml + + + + + + + + DL Models + + Common DL Operators + + Conv + + Conv + + Norm + + + + + Pool + + + + + + Meta-Operators + + + + + Reindex + Reindex Reduce + Element-Wise + + Broadcast + + Pad + + Slice + + + Backward + + Backward + + + + + + + + + Reduce + + Product + + Sum + + + + + + + + + Unary + + Binary + + Ternary + + + + + + + + + diff --git a/python/jittor/notebook/md_to_ipynb.py b/python/jittor/notebook/md_to_ipynb.py new file mode 100644 index 00000000..175b50f8 --- /dev/null +++ b/python/jittor/notebook/md_to_ipynb.py @@ -0,0 +1,67 @@ +#!python3 +import os, json +import jittor_utils as jit_utils +notebook_dir = os.path.join(jit_utils.home(), ".cache","jittor","notebook") +if not os.path.isdir(notebook_dir): + os.mkdir(notebook_dir) +dirname = os.path.dirname(__file__) +all_md = [] +for r, _, f in os.walk(dirname): + for fname in f: + if not fname.endswith(".md"): continue + all_md.append(os.path.join(r, fname)) +for mdname in all_md: + with open(os.path.join(dirname, mdname), "r", encoding="utf-8") as f: + src = f.read() + blocks = [] + for i, b in enumerate(src.split("```")): + b = b.strip() + is_markdown_block = i%2==0 + if not is_markdown_block and not b.startswith("python"): + is_markdown_block = True + b = "```\n"+b+"\n```" + if is_markdown_block: + # in a markdown block + if len(blocks)%2==0: + # prev code block + blocks.append(b) + else: + # prev markdown block + blocks[-1] += "\n\n" + b + else: + # in a code block + if b.startswith("python"): + b = b[6:].strip() + # prev markdown block + assert len(blocks)%2==1 + blocks.append(b) + cells = [] + for i, b in enumerate(blocks): + b = b.strip() + if len(b)==0: continue + b = b.split("\n") + for j in range(len(b)-1): + b[j] += '\n' + cell = { + "source": b, + "metadata": {}, + } + if i%2==0: + cell["cell_type"] = "markdown" + else: + cell["cell_type"] = "code" + cell["outputs"] = [] + cell["execution_count"] = 0 + cells.append(cell) + ipynb = { + "cells":cells, + "nbformat": 4, + "nbformat_minor": 2, + "metadata": { + }, + } + ipynb_name = os.path.basename(mdname[:-2])+"ipynb" + ipynb_name = os.path.join(notebook_dir, ipynb_name) + print(mdname, len(src), len(blocks), len(cells), "--->", ipynb_name) + with open(ipynb_name, "w", encoding='utf8') as f: + f.write(json.dumps(ipynb)) \ No newline at end of file diff --git a/python/jittor/notebook/meta_op.src.md b/python/jittor/notebook/meta_op.src.md new file mode 100644 index 00000000..d7b861be --- /dev/null +++ b/python/jittor/notebook/meta_op.src.md @@ -0,0 +1,256 @@ +# Meta-operator: Implement your own convolution with Meta-operator + +# 元算子:通过元算子实现自己的卷积层 + +Meta-operator is a key concept of jittor, The hierarchical architecture of meta-operators is shown below. + +The meta-operators are consist of reindex, reindex-reduce and element-wise operators. Reindex and reindex-reduce operators are both unary operators. The reindex operator is a one-to-many mapping between its input and output. And the reindex-reduce operator is a many-to-one mapping. Broadcast, pad and slice operators are common reindex operators. And reduce, product and sum are common reindex-reduce operators. Element-wise operator is the third component of meta-operators. Compared to the first two, element-wise operators may contain multiple inputs. But all the input and output shapes of element-wise operators must be the same. And they are one-to-one mapped. For example, the addition of two variables is a binary element-wise operator. + +元算子是jittor的关键概念,元算子的层次结构如下所示。 + +元算子由重索引算子,重索引化简算子和元素级算子组成。重索引算子,重索引化简算子都是一元算子。 重索引算子是其输入和输出之间的一对多映射。重索引简化算子是多对一映射。广播,填补, 切分算子是常见的重新索引算子。 而化简,累乘,累加算子是常见的索引化简算子。 元素级算子是元算子的第三部分,与前两个相比,元素算级子可能包含多个输入。 但是元素级算子的所有输入和输出形状必须相同,它们是一对一映射的。 例如,两个变量的加法是一个二进制的逐元素算子。 + +> ![](./figs/mop.svg) +> The hierarchical architecture of meta-operators. The meta-operators are consist of reindex, reindex-reduce and element-wise operators. Reindex and reindex-reduce are each other's backward operators. The backward operators of element-wise operators are itself. Those meta-operators are fused into common DL operations, and these DL operators further constitute the model. +> +> 元算子的层级结构。元算子包含三类算子,重索引算子,重索引化简算子,元素级算子。元算 +> 子的反向传播算子还是元算子。元算子可以组成常用的深度学习算子。而这些深度学习算子又 +> 可以进一步组成深度学习模型。 + +In the previous [example](example.ipynb), we have demonstrated how to implement matrix multiplication via three meta-operators: + +在第一个[示例](example.ipynb)中,我们演示了如何通过三个元算子实现矩阵乘法: + +``` +def matmul(a, b): + (n, m), k = a.shape, b.shape[-1] + a = a.broadcast([n,m,k], dims=[2]) + b = b.broadcast([n,m,k], dims=[0]) + return (a*b).sum(dim=1) +``` + +In this tutorial, we will show how to implement your own convolution with meta-operator. + +First, let's implement a naive Python convolution: + +在本教程中,我们将展示如何使用元算子实现自己的卷积。 + +首先,让我们实现一个朴素的Python卷积: + +``` +import numpy as np +import os +def conv_naive(x, w): + N,H,W,C = x.shape + + Kh, Kw, _C, Kc = w.shape + assert C==_C, (x.shape, w.shape) + y = np.zeros([N,H-Kh+1,W-Kw+1,Kc]) + for i0 in range(N): + for i1 in range(H-Kh+1): # dimension error + for i2 in range(W-Kw+1): + for i3 in range(Kh): + for i4 in range(Kw): + for i5 in range(C): + for i6 in range(Kc): + if i1-i3<0 or i2-i4<0 or i1-i3>=H or i2-i4>=W: continue + y[i0, i1, i2, i6] += x[i0, i1 + i3, i2 + i4, i5] * w[i3,i4,i5,i6] + return y +``` + +Then, let's download a cat image, and run `conv_naive` with a simple horizontal filte. + +然后,让我们下载一个猫的图像,并使用`conv_naive`实现一个简单的水平滤波器。 + +``` +# %matplotlib inline +import pylab as pl +img_path="/tmp/cat.jpg" +if not os.path.isfile(img_path): + !wget -O - 'https://upload.wikimedia.org/wikipedia/commons/thumb/4/4f/Felis_silvestris_catus_lying_on_rice_straw.jpg/220px-Felis_silvestris_catus_lying_on_rice_straw.jpg' > $img_path +img = pl.imread(img_path) +pl.subplot(121) +pl.imshow(img) +kernel = np.array([ + [-1, -1, -1], + [0, 0, 0], + [1, 1, 1], +]) +pl.subplot(122) +x = img[np.newaxis,:,:,:1].astype("float32") +w = kernel[:,:,np.newaxis,np.newaxis].astype("float32") +y = conv_naive(x, w) +print (x.shape, y.shape) # shape exists confusion +pl.imshow(y[0,:,:,0]) +``` +It looks good, our `naive_conv` works well. Let's replace our naive implementation with jittor. + +看起来不错,我们的`naive_conv`运作良好。现在让我们用jittor替换我们的朴素实现。 + +``` +import jittor as jt + +def conv(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [ + 'i0', # Nid + 'i1+i3', # Hid+Khid + 'i2+i4', # Wid+KWid + 'i5', # Cid| + ]) + ww = w.broadcast_var(xx) + yy = xx*ww + y = yy.sum([3,4,5]) # Kh, Kw, c + return y + +# Let's disable tuner. This will cause jittor not to use mkl for convolution +jt.flags.enable_tuner = 0 + +jx = jt.array(x) +jw = jt.array(w) +jy = conv(jx, jw).fetch_sync() +print (jx.shape, jy.shape) +pl.imshow(jy[0,:,:,0]) +``` + +They looks the same. How about the performance? + +他们的结果看起来一样。那么它们的性能如何? + +``` +%time y = conv_naive(x, w) +%time jy = conv(jx, jw).fetch_sync() +``` + +The jittor implementation is much faster. So why this two implementation are equivalent in math, and why jittor's implementation is faster? We will explain step by step: + +First, let's take a look at the help document of `jt.reindex`. + +可以看出jittor的实现要快得多。 那么,为什么这两个实现在数学上等效,而jittor的实现运行速度更快? 我们将逐步进行解释: + +首先,让我们看一下`jt.reindex`的帮助文档。 + +``` +help(jt.reindex) +``` + +Following the document, we can expand the reindex operation for better understanding: + +遵循该文档,我们可以扩展重索引操作以便更好地理解: + +``` +py +xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [ + 'i0', # Nid + 'i1+i3', # Hid+Khid + 'i2+i4', # Wid+KWid + 'i5', # Cid +]) +ww = w.broadcast_var(xx) +yy = xx*ww +y = yy.sum([3,4,5]) # Kh, Kw, C +``` + +**After expansion:** + +扩展后: + +``` +py +shape = [N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc] +# expansion of x.reindex +xx = np.zeros(shape, x.dtype) +for i0 in range(shape[0]): + for i1 in range(shape[1]): + for i2 in range(shape[2]): + for i3 in range(shape[3]): + for i4 in range(shape[4]): + for i5 in range(shape[5]): + for i6 in range(shape[6]): + if is_overflow(i0,i1,i2,i3,i4,i5,i6): + xx[i0,i1,...,in] = 0 + else: + xx[i0,i1,i2,i3,i4,i5,i6] = x[i0,i1+i3,i2+i4,i5] + +# expansion of w.broadcast_var(xx) +ww = np.zeros(shape, x.dtype) +for i0 in range(shape[0]): + for i1 in range(shape[1]): + for i2 in range(shape[2]): + for i3 in range(shape[3]): + for i4 in range(shape[4]): + for i5 in range(shape[5]): + for i6 in range(shape[6]): + ww[i0,i1,i2,i3,i4,i5,i6] = w[i3,i4,i5,i6] +# expansion of xx*ww +yy = np.zeros(shape, x.dtype) +for i0 in range(shape[0]): + for i1 in range(shape[1]): + for i2 in range(shape[2]): + for i3 in range(shape[3]): + for i4 in range(shape[4]): + for i5 in range(shape[5]): + for i6 in range(shape[6]): + yy[i0,i1,i2,i3,i4,i5,i6] = xx[i0,i1,i2,i3,i4,i5,i6] * ww[i0,i1,i2,i3,i4,i5,i6] +# expansion of yy.sum([3,4,5]) +shape2 = [N,H-Kh+1,W-Kw+1,Kc] +y = np.zeros(shape2, x.dtype) +for i0 in range(shape[0]): + for i1 in range(shape[1]): + for i2 in range(shape[2]): + for i3 in range(shape[3]): + for i4 in range(shape[4]): + for i5 in range(shape[5]): + for i6 in range(shape[6]): + y[i0,i1,i2,i6] += yy[i0,i1,i2,i3,i4,i5,i6] +``` + +**After loop fusion:** + +循环融合后: + +``` +py +shape2 = [N,H-Kh+1,W-Kw+1,Kc] +y = np.zeros(shape2, x.dtype) +for i0 in range(shape[0]): + for i1 in range(shape[1]): + for i2 in range(shape[2]): + for i3 in range(shape[3]): + for i4 in range(shape[4]): + for i5 in range(shape[5]): + for i6 in range(shape[6]): + if not is_overflow(i0,i1,i2,i3,i4,i5,i6): + y[i0,i1,i2,i6] += x[i0,i1+i3,i2+i4,i5] * w[i3,i4,i5,i6] +``` + +This is the trick of meta-operator, It can fused multiple operator into a complicated operation, including many variation of convolution (e.g. group conv, seperate conv,...). + +jittor will try to optimize the fused operator as fast as possible. Let's try some optimizations(compile the shapes as constants into the kernel), and show the underlying c++ kernel. + +这是就元算子的优化技巧,它可以将多个算子融合为一个复杂的融合算子,包括许多卷积的变化(例如group conv,separate conv等)。 + +jittor会尝试将融合算子优化得尽可能快。 让我们尝试一些优化(将形状作为常量编译到内核中),并编译到底层的c++内核代码中。 + + +``` +jt.flags.compile_options={"compile_shapes":1} +with jt.profile_scope() as report: + jy = conv(jx, jw).fetch_sync() +jt.flags.compile_options={} + +print(f"Time: {float(report[1][4])/1e6}ms") + +with open(report[1][1], 'r') as f: + print(f.read()) +``` + +Even faster than the previous implementation! From the output we can look at the function definition of func0. This is the main code of our convolution kernel, which is generated Just-in-time. Because the compiler knows the shapes of the kernel and more optimizations are used. + +比之前的实现还要更快! 从输出中我们可以看一看`func0`的函数定义,这是我们卷积内核的主要代码,该内核代码是即时生成的。因为编译器知道内核的形状,所以使用了更多的优化方法。 + +在这个教程中,Jittor简单演示了元算子的使用,并不是正真的性能测试,所以使用了比较小的数据规模进行测试,如果需要性能测试,请打开`jt.flags.enable_tuner = 1`,会启动使用专门的硬件库加速。 + +In this tutorial, Jittor simply demonstrated the use of meta-operators, which is not a performance test. If you need a performance test, `jt.flags.enable_tuner = 1` will try to use the dedicated hardware library. diff --git a/python/jittor/notebook/profiler.src.md b/python/jittor/notebook/profiler.src.md new file mode 100644 index 00000000..8b1c799f --- /dev/null +++ b/python/jittor/notebook/profiler.src.md @@ -0,0 +1,20 @@ +# Profiler: Profiling your model + +# 性能分析器:分析您的模型 + +> NOTE: This tutorial is still working in progress + +In this tutorial, we will show: +1. how to profiling your model and check the elapsed time of each operation +2. profiling the cache hit rate + +> 注意:本教程仍在持续更新中 + +在本教程中,我们将展示: + +1. 如何分析模型并检查每个运算的耗时 +2. 分析缓存命中率 + +```python +import jittor as jt +``` \ No newline at end of file diff --git a/python/jittor/optim.py b/python/jittor/optim.py new file mode 100644 index 00000000..14c1ef01 --- /dev/null +++ b/python/jittor/optim.py @@ -0,0 +1,631 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Guoye Yang <498731903@qq.com> +# Wenyang Zhou <576825820@qq.com> +# Meng-Hao Guo +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np + +class Optimizer(object): + """ Basic class of Optimizer. + + Example:: + + optimizer = nn.SGD(model.parameters(), lr) + optimizer.step(loss) + """ + def __init__(self, params, lr, param_sync_iter=10000): + self.param_groups = [] + self.lr = lr + self.param_sync_iter = param_sync_iter + + assert len(params) > 0, "Length of parameters should not be zero" + if not isinstance(params[0], dict): + params = [{'params': params}] + for pg in params: + assert isinstance(pg, dict) + self.param_groups.append(pg) + self.n_step = 0 + # __zero_grad is a value for fast determ the grad is zero or not + # so we can omit 0+x + self.__zero_grad = True + self._grad_map = {} + + def add_param_group(self, group): + self.param_groups.append(group) + + def clip_grad_norm(self, max_norm:float, norm_type:int=2): + r"""Clips gradient norm of this optimizer. + The norm is computed over all gradients together. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (int): 1-norm or 2-norm + + Example:: + + a = jt.ones(2) + opt = jt.optim.SGD([a], 0.1) + + loss = a*a + opt.zero_grad() + opt.backward(loss) + + print(opt.param_groups[0]['grads'][0].norm()) # output: 2.83 + opt.clip_grad_norm(0.01, 2) + print(opt.param_groups[0]['grads'][0].norm()) # output: 0.01 + + opt.step() + + """ + if self.__zero_grad: return + grads = [] + for pg in self.param_groups: + for p, g in zip(pg["params"], pg["grads"]): + if p.is_stop_grad(): continue + grads.append(g.flatten()) + if len(grads) == 0: return + total_norm = jt.norm(jt.concat(grads), norm_type) + clip_coef = jt.minimum(max_norm / (total_norm + 1e-6), 1.0) + for pg in self.param_groups: + for p, g in zip(pg["params"], pg["grads"]): + if p.is_stop_grad(): continue + g.update(g*clip_coef) + + @property + def defaults(self): + exclude = set(("defaults", "pre_step", "step")) + return { k:v for k, v in self.__dict__.items() + if k[0] != '_' and k not in exclude and not callable(v) } + + def state_dict(self): + state = {"defaults": self.defaults} + return state + + def load_state_dict(self, state): + + def dfs(x): + if isinstance(x, list): + for i in range(len(x)): + x[i] = dfs(x[i]) + elif isinstance(x, dict): + for k in x: + x[k] = dfs(x[k]) + elif isinstance(x, np.ndarray): + return jt.array(x).stop_grad() + elif isinstance(x, jt.Var): + return x.stop_grad() + return x + + exclude = set(("param_groups", "params")) + for k, v in state["defaults"].items(): + if k not in exclude: + setattr(self, k, dfs(v)) + param_groups = dfs(state["defaults"].get('param_groups', None)) + if param_groups is not None: + exclude = set(("params",)) + for i in range(len(param_groups)): + for k, v in param_groups[i].items(): + if k not in exclude: + self.param_groups[i][k] = v + + + + def zero_grad(self): + self.__zero_grad = True + + def backward(self, loss, retain_graph=False): + ''' + optimize.backward(loss) is used for accumulate multiple step, + it can be used as following: + + Origin source code :: + + n_iter = 10000 + batch_size = 100 + ... + for i in range(n_iter): + ... + loss = calc_loss() + optimizer.step(loss) + + Accumulation version :: + + n_iter = 10000 + batch_size = 100 + accumulation_steps = 10 + n_iter *= accumulation_steps + batch_size //= accumulation_steps + ... + for i in range(n_iter): + ... + loss = calc_loss() + # if loss is a mean across batch, we need to divide accumulation_steps + optimizer.backward(loss / accumulation_steps) + if (i+1) % accumulation_steps == 0: + optimizer.step() + + + ''' + # clean prev grads + params = [] + params_has_grad = [] + for pg in self.param_groups: + for p in pg['params']: + params.append(p) + if not p.is_stop_grad(): + params_has_grad.append(p) + + # sync prev params + jt.sync(params_has_grad) + + # get gradient + grads = jt.grad(loss, params_has_grad, retain_graph) + + # sync grads and model if in mpi + if jt.in_mpi: + dep = [] + def add_dep(v): + nonlocal dep + v._add_dependency(dep) + dep = [v] + + for g in grads: + g.assign(g.mpi_all_reduce("mean")) + add_dep(g._input(0)) + if self.n_step % self.param_sync_iter == 0: + for p in params: + p.assign(p.mpi_broadcast()) + add_dep(p) + self.n_step += 1 + + # set up grads in param_groups + pid = 0 + for pg in self.param_groups: + if "grads" not in pg: + pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ] + pg_grads = pg["grads"] + for i, p in enumerate(pg['params']): + if not p.is_stop_grad(): + # accumulate grad and stop grad of grad + g = grads[pid].stop_grad() + if not self.__zero_grad: + g = g + pg_grads[i] + pg_grads[i].update(g) + pid += 1 + self.__zero_grad = False + + def pre_step(self, loss, retain_graph=False): + """ something should be done before step, such as calc gradients, mpi sync, and so on. + + Example:: + + class MyOptimizer(Optimizer): + def step(self, loss): + self.pre_step(loss) + ... + self.post_step() + """ + if loss is not None: + self.backward(loss, retain_graph) + jt.flags.node_order = 1 + + def post_step(self): + """ something should be done before step, such as zero grad, and so on. + + Example:: + + class MyOptimizer(Optimizer): + def step(self, loss): + self.pre_step(loss) + ... + self.post_step() + """ + jt.flags.node_order = 0 + self.zero_grad() + + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph) + for pg in self.param_groups: + lr = pg.get("lr", self.lr) + for p, g in zip(pg["params"], pg["grads"]): + if p.is_stop_grad(): continue + p.update(p - g * lr) + self.post_step() + + def _build_grad_map(self): + _grad_map = {} + for pg in self.param_groups: + for p, g in zip(pg["params"], pg["grads"]): + _grad_map[id(p)] = g + self._grad_map = _grad_map + + def find_grad(self, v:jt.Var) -> jt.Var: + if id(v) not in self._grad_map: + self._build_grad_map() + if id(v) not in self._grad_map: + raise RuntimeError("This variable is not managed by this optimizer") + return self._grad_map[id(v)] + +def opt_grad(v:jt.Var, opt:Optimizer): + ''' Get grad of certain variable in optimizer, Example:: + + + model = Model() + optimizer = SGD(model.parameters()) + ... + optimizer.backward(loss) + + for p in model.parameters(): + grad = p.opt_grad(optimizer) + ''' + return opt.find_grad(v) + +jt.Var.opt_grad = opt_grad + +class SGD(Optimizer): + """ SGD Optimizer. + + Example:: + + optimizer = nn.SGD(model.parameters(), lr, momentum=0.9) + optimizer.step(loss) + """ + def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False): + super().__init__(params, lr) + self.momentum = momentum + self.weight_decay = weight_decay + self.dampening = dampening + self.nesterov = nesterov + + # initialize required arguments + for pg in self.param_groups: + values = pg["values"] = [] + for p in pg["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + + def add_param_group(self, group): + values = group["values"] = [] + for p in group["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + self.param_groups.append(group) + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph=False) + jt.flags.node_order = 1 + for pg in self.param_groups: + # get arguments from each param_groups + lr = pg.get("lr", self.lr) + momentum = pg.get("momentum", self.momentum) + weight_decay = pg.get("weight_decay", self.weight_decay) + dampening = pg.get("dampening", self.dampening) + nesterov = pg.get("nesterov", self.nesterov) + + # optimize main body + for p, g, v in zip(pg["params"], pg["grads"], pg["values"]): + if p.is_stop_grad(): continue + dp = p * weight_decay + g + v.update(momentum * v + dp * (1 - dampening)) + if nesterov: + p.update(p - (dp + momentum * v) * lr) + else: + p.update(p - v * lr) + self.post_step() + +class RMSprop(Optimizer): + """ RMSprop Optimizer. + Args: + params(list): parameters of model. + lr(float): learning rate. + eps(float): term added to the denominator to avoid division by zero, default 1e-8. + alpha(float): smoothing constant, default 0.99. + + Example: + optimizer = nn.RMSprop(model.parameters(), lr) + optimizer.step(loss) + """ + def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99): + super().__init__(params, lr) + self.eps = eps + self.alpha = alpha + + # initialize required arguments for each param_groups + for pg in self.param_groups: + values = pg["values"] = [] + for p in pg["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + + def add_param_group(self, group): + values = group["values"] = [] + for p in group["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + self.param_groups.append(group) + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph) + for pg in self.param_groups: + # get arguments from each param_groups + lr = pg.get("lr", self.lr) + eps = pg.get("eps", self.eps) + alpha = pg.get("alpha", self.alpha) + for p, g, v in zip(pg["params"], pg["grads"], pg["values"]): + if p.is_stop_grad(): continue + v.update(alpha * v + (1-alpha) * g * g) + p.update(p - lr * g / (jt.sqrt(v) + eps)) + self.post_step() + +class Adam(Optimizer): + """ Adam Optimizer. + + Example:: + + optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999)) + optimizer.step(loss) + """ + def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0): + super().__init__(params, lr) + self.eps = eps + self.betas = betas + self.weight_decay = weight_decay + # assert weight_decay==0, "weight_decay is not supported yet" + + # initialize required arguments for each param_groups + for pg in self.param_groups: + values = pg["values"] = [] + m = pg["m"] = [] + for p in pg["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, p.dtype).stop_grad()) + + def add_param_group(self, group): + values = group["values"] = [] + m = group["m"] = [] + for p in group["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, p.dtype).stop_grad()) + self.param_groups.append(group) + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph) + n = float(self.n_step) + jt.flags.node_order = 1 + for pg in self.param_groups: + # get arguments from each param_groups + lr = pg.get("lr", self.lr) + eps = pg.get("eps", self.eps) + weight_decay = pg.get("weight_decay", self.weight_decay) + b0, b1 = pg.get("betas", self.betas) + for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]): + if p.is_stop_grad(): continue + g = p * weight_decay + g + m.update(b0 * m + (1-b0) * g) + v.update(b1 * v + (1-b1) * g * g) + step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n) + p.update(p - m * step_size / (jt.sqrt(v) + eps)) + self.post_step() + + +class AdamW(Optimizer): + """ AdamW Optimizer. + + Example:: + + optimizer = nn.AdamW(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999)) + optimizer.step(loss) + """ + def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0): + super().__init__(params, lr) + self.eps = eps + self.betas = betas + self.weight_decay = weight_decay + # assert weight_decay==0, "weight_decay is not supported yet" + + # initialize required arguments for each param_groups + for pg in self.param_groups: + values = pg["values"] = [] + m = pg["m"] = [] + for p in pg["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, p.dtype).stop_grad()) + + def add_param_group(self, group): + values = group["values"] = [] + m = group["m"] = [] + for p in group["params"]: + values.append(jt.zeros(p.shape, p.dtype).stop_grad()) + m.append(jt.zeros(p.shape, p.dtype).stop_grad()) + self.param_groups.append(group) + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph) + n = float(self.n_step) + for pg in self.param_groups: + # get arguments from each param_groups + lr = pg.get("lr", self.lr) + eps = pg.get("eps", self.eps) + weight_decay = pg.get("weight_decay", self.weight_decay) + b0, b1 = pg.get("betas", self.betas) + for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]): + if p.is_stop_grad(): continue + p.update(p * (1 - lr * weight_decay)) + bias_correction1 = 1 - b0 ** n + bias_correction2 = 1 - b1 ** n + m.update(b0 * m + (1-b0) * g) #exp_avg + v.update(b1 * v + (1-b1) * g * g) #exp_avg_sq + denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps + step_size = lr / bias_correction1 + p.update(p - step_size * m / denom) + self.post_step() + + +class Adan(Optimizer): + """ Adan Optimizer. + Adan was proposed in + Adan: Adaptive Nesterov Momentum Algorithm for + Faster Optimizing Deep Models[J].arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + Adan is an efficient optimizer for most DNN frameworks: + - About 2x fewer computational load than SOTAs + - Robust to training setting and batch size + - Easy to Plug-and-play + + Arguments: + params (iterable): iterable of parameters to optimize or + dicts defining parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float, flot], optional): coefficients used for + first- and second-order moments. (default: (0.98, 0.92, 0.99)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): decoupled weight decay + (L2 penalty) (default: 0) + max_grad_norm (float, optional): value used to clip + global grad norm (default: 0.0 no clip) + """ + def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99), + eps=1e-8, weight_decay=0.0, max_grad_norm=0.0): + super().__init__(params, lr) + self.betas = betas + self.eps = eps + self.weight_decay = weight_decay + self.max_grad_norm = max_grad_norm + + for pg in self.param_groups: + pg["m"] = [] + pg["v"] = [] + pg["d"] = [] + pg["pre_grad"] = [] + for p in pg["params"]: + pg["m"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + pg["v"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + pg["d"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + pg["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + + + def add_param_group(self, group): + group["m"] = [] + group["v"] = [] + group["d"] = [] + group["pre_grad"] = [] + for p in group["params"]: + group["m"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + group["v"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + group["d"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + group["pre_grad"].append(jt.zeros(p.shape, p.dtype).stop_grad()) + self.param_groups.append(group) + + def step(self, loss=None, retain_graph=False): + self.pre_step(loss, retain_graph) + n = float(self.n_step) + for pg in self.param_groups: + lr = pg.get("lr", self.lr) + betas = pg.get("betas", self.betas) + eps = pg.get("eps", self.eps) + weight_decay = pg.get("weight_decay", self.weight_decay) + max_grad_norm = pg.get("max_grad_norm", self.max_grad_norm) + if max_grad_norm>0: self.clip_grad_norm(max_grad_norm) + beta1, beta2, beta3 = betas + + bias_correction1 = 1 - beta1 ** n + bias_correction2 = 1 - beta2 ** n + bias_correction3_sqrt = jt.sqrt(1 - beta3 ** n) + + + step_size_diff = lr * beta2 * bias_correction3_sqrt / bias_correction2 + step_size = lr * bias_correction3_sqrt / bias_correction1 + eps_bias_sqrt = eps * bias_correction3_sqrt + + for p, g, m, v, d, pre_g in zip(pg["params"], + pg["grads"], + pg["m"], + pg["v"], + pg["d"], + pg["pre_grad"]): + if p.is_stop_grad(): continue + + if self.n_step>0: + pre_g.update(g - pre_g) # Update pre_g as grad_diff + + + m.update(beta1 * m + (1 - beta1) * g) + d.update(beta2 * d + (1 - beta2) * pre_g) # Use pre_g as grad_diff + + pre_g.update(jt.multiply(pre_g, beta2) + g) # Update pre_g as update (g + beta2 * grad_diff) + + v.update(beta3 * v + (1 - beta3) * pre_g * pre_g) # Use pre_g as update + + p.update(p - (step_size * m + step_size_diff * d) / (jt.sqrt(v) + eps_bias_sqrt)) + p.update(p / (1 + lr * weight_decay)) + + pre_g.update(g) # Update pre_g for the next iteration + self.post_step() + + +class LRScheduler: + def __init__(self,optimizer, last_epoch=-1): + assert isinstance(optimizer,Optimizer) + self.optimizer = optimizer + + if last_epoch==-1: + for gp in optimizer.param_groups: + gp.setdefault('initial_lr',gp.get('lr',optimizer.lr)) + else: + for gp in optimizer.param_groups: + assert 'initial_lr' in gp + + self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + self.last_epoch = last_epoch + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def get_lr(self): + raise NotImplementedError + + def get_last_lr(self): + return self._last_lr + + def step(self,epoch=None): + self._step_count += 1 + + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + self.last_epoch = epoch + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + +class LambdaLR(LRScheduler): + + def __init__(self, optimizer, lr_lambda, last_epoch=-1): + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda))) + + self.lr_lambdas = list(lr_lambda) + + super(LambdaLR, self).__init__(optimizer, last_epoch) + + + + def get_lr(self): + return [base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] diff --git a/python/jittor/other/code_softmax.py b/python/jittor/other/code_softmax.py new file mode 100644 index 00000000..1acf73d8 --- /dev/null +++ b/python/jittor/other/code_softmax.py @@ -0,0 +1,149 @@ +import jittor as jt +from jittor import nn + +def can_softmax_v1(a, dim): + if not jt.flags.use_cuda: + return False + if dim != -1 and dim != len(a.shape)-1: + return False + if a.shape[len(a.shape)-1] > 10000: + return False + return True + +def softmax_v1(a, log=False): + assert can_softmax_v1(a, -1) + length = a.shape[-1] + # tnum = 1024 + tnum = 500 if length % 500 == 0 else 512 + tnum = 125 if length % 125 == 0 else 128 + # tnum = 125 + # tnum = 1000 if length % 1000 == 0 else 1024 + # tnum = 250 + per_thread = (length-1) // tnum + 1 + ILP = 1 + for ilp in [8,4,2]: + if length % tnum == 0 and per_thread % ilp == 0: + ILP = ilp + per_thread //= ILP + break + for_loop = f""" + #pragma unroll + for (int i=0; i<{per_thread}; i++) + """ + if length % tnum != 0: + for_loop += f"if ((i*{tnum}+threadIdx.x)*{ILP} < len)\n" + + class CodeSoftmax(jt.Function): + def execute(self, x): + self.save_vars = jt.code(x.shape, x.dtype, [x], cuda_header=f''' +#include <{jt.compile_extern.cub_home}cub/cub.cuh> +#include +''', cuda_src=f''' +__global__ void kernel(in0_type* x, out0_type* y, int len) {{ + typedef cub::BlockReduce BlockReduce; + constexpr int need_log = {int(log)}; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int id = blockIdx.x * len; + in0_type v[{per_thread}][{ILP}]; + {for_loop} + vload(v[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]); + // v[i] = x[id+i*{tnum}+threadIdx.x]; + float v1 = -1e30; + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + v1 = ::max(v1, float(v[i][j])); + }} + __shared__ float vmax; + auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max()); + if (threadIdx.x == 0) + vmax = tmp; + __syncthreads(); + + v1 = 0; + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + if (need_log) {{ + v[i][j] = float(v[i][j]) - vmax; + v1 += expf(float(v[i][j])); + }} else {{ + v[i][j] = expf(float(v[i][j]) - vmax); + v1 += float(v[i][j]); + }} + }} + + tmp = BlockReduce(temp_storage).Sum(v1); + __shared__ float vsum; + if (threadIdx.x == 0) + vsum = tmp; + __syncthreads(); + + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + if (need_log) + v[i][j] = v[i][j] - @expand_op(log,@in0_type,vsum); + else + v[i][j] = float(v[i][j])/vsum; + }} + {for_loop} + vload(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]); +}} +int len = in0->shape[in0->shape.size()-1]; +int bnum = in0->numel() / len; +cudaGetLastError(); +kernel<<>>(in0_p, out0_p, len); +CHECK(0 == cudaGetLastError()); +''') + return self.save_vars + + def grad(self, grad_x): + x = self.save_vars + return jt.code(x.shape, x.dtype, [x, grad_x], cuda_header=f''' +#include <{jt.compile_extern.cub_home}cub/cub.cuh> +#include +''', + cuda_src=f""" +__global__ void kernel(in0_type* x, in1_type* y, out0_type* z, int len) {{ + int id = blockIdx.x * len; + in0_type vx[{per_thread}][{ILP}]; + in0_type vy[{per_thread}][{ILP}]; + {for_loop} {{ + vload(vx[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]); + vload(vy[i], &y[id+(i*{tnum}+threadIdx.x)*{ILP}]); + }} + float v1 = 0; + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) + v1 += {"float(vy[i][j]);" if log else "float(vx[i][j]*vy[i][j]);"} + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + auto tmp = BlockReduce(temp_storage).Sum(v1); + __shared__ float reduce_var; + if (threadIdx.x == 0) + reduce_var = tmp; + __syncthreads(); + + {for_loop} + #pragma unroll + for (int j=0; j<{ILP}; j++) + vx[i][j] = { + "vy[i][j] - in0_type(expf(vx[i][j]) * reduce_var);" if log + else "vx[i][j] * (vy[i][j] - in0_type(reduce_var));" + } + + {for_loop} + vload(&z[id+(i*{tnum}+threadIdx.x)*{ILP}], + vx[i]); +}} +int len = in0->shape[in0->shape.size()-1]; +int bnum = in0->numel() / len; +cudaGetLastError(); +kernel<<>>(in0_p, in1_p, out0_p, len); +CHECK(0 == cudaGetLastError()); +""") + return CodeSoftmax()(a) \ No newline at end of file diff --git a/python/jittor/pool.py b/python/jittor/pool.py new file mode 100644 index 00000000..7e9e808a --- /dev/null +++ b/python/jittor/pool.py @@ -0,0 +1,681 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Wenyang Zhou <576825820@qq.com> +# Meng-Hao Guo +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import init, Module +import numpy as np +import math + +pool_use_code_op = True + +class Pool(Module): + def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"): + assert dilation == None + assert return_indices == None or op == "maximum" + self.return_indices = return_indices + self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) + self.op = op + stride = stride if stride else kernel_size + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, padding) + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad and padding != 0 + for item in self.kernel_size: + if item <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {item}") + for item in self.stride: + if item <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {item}") + for item in self.padding: + if item < 0: + raise RuntimeError(f"padding must be non-negative, but got {item}") + + def execute(self, x): + N,C,H,W = x.shape + if H <= self.kernel_size[0] or W <= self.kernel_size[1]: + raise RuntimeError(f"size of var should be larger than kernel_size") + if self.ceil_mode == False: + h = (H+self.padding[0]*2-self.kernel_size[0])//self.stride[0]+1 + w = (W+self.padding[1]*2-self.kernel_size[1])//self.stride[1]+1 + use_code_op = self.op in ['maximum', 'minimum'] + # some second order avg_pool is require, so we don't use code op here + else: + h = (H+self.padding[0]*2-self.kernel_size[0] + self.stride[0] - 1)//self.stride[0]+1 + w = (W+self.padding[1]*2-self.kernel_size[1] + self.stride[1] - 1)//self.stride[1]+1 + use_code_op = self.op in ['maximum', 'minimum', 'mean'] + + if use_code_op and pool_use_code_op: + if self.op == 'mean': + if self.count_include_pad: + count = f"int count = {self.kernel_size[0]*self.kernel_size[1]};" + else: + count = "int count = (k2_ - k2) * (k3_ - k3);" + count += "float32 rcount = 1.0f / count;" + else: + count = "" + forward_body = f''' + int k3 = i3*{self.stride[1]}-{self.padding[1]}; + int k2 = i2*{self.stride[0]}-{self.padding[0]}; + int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3); + int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2); + k3 = max(0, k3); + k2 = max(0, k2); + {count} + ''' + if not self.return_indices: + forward_body += f''' + @out(i0, i1, i2, i3) = @expand_op(init_{self.op}, @out_type); + for (int p = k2; p < k2_; ++p) + for (int q = k3; q < k3_; ++q) + @out(i0, i1, i2, i3) = @expand_op({self.op}, @out_type, @out(i0, i1, i2, i3), @out_type, @in0(i0, i1, p, q), @in0_type); + ''' + else: + forward_body += f''' + auto out_value = @expand_op(init_{self.op}, @out_type); + int out_index = -1; + for (int p = k2; p < k2_; ++p) + for (int q = k3; q < k3_; ++q) + if (out_value < @in0(i0, i1, p, q)) {{ + out_value = @in0(i0, i1, p, q); + out_index = p * in0_shape3 + q; + }} + @out(i0, i1, i2, i3) = out_value; + @out1(i0, i1, i2, i3) = out_index; + ''' + backward_body = f''' + int k3 = i3*{self.stride[1]}-{self.padding[1]}; + int k2 = i2*{self.stride[0]}-{self.padding[0]}; + int k3_ = min(k3 + {self.kernel_size[1]}, in0_shape3); + int k2_ = min(k2 + {self.kernel_size[0]}, in0_shape2); + k3 = max(0, k3); + k2 = max(0, k2); + {count} + int bo=1; + for (int p = k2; p < k2_ && bo; ++p) + for (int q = k3; q < k3_ && bo; ++q) {{ + {"atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3)/count);" + if self.op == "mean" else + f"""if (@pout(i0,i1,i2,i3) == @in0(i0,i1,p,q)) {{ + atomicAdd(&@out(i0,i1,p,q), @dout(i0,i1,i2,i3)), + bo=0; + }}"""} + }} + ''' + if self.return_indices: + return_shapes = [[N,C,h,w]] * 2 + return_dtypes = [x.dtype, 'int32'] + else: + return_shapes = [N,C,h,w] + return_dtypes = x.dtype + out = jt.code(return_shapes, return_dtypes, [x], + cuda_header=""" + #include + """, + cuda_src=f''' + __global__ static void kernel1(@ARGS_DEF) {{ + @PRECALC + int p3 = threadIdx.x; + int s3 = blockDim.x; + int p2 = threadIdx.y + blockIdx.x * blockDim.y; + int s2 = blockDim.y * gridDim.x; + int i1 = blockIdx.y; + int i0 = blockIdx.z; + for (int i3 = p3; i3 < out_shape3; i3 += s3) + for (int i2 = p2; i2 < out_shape2; i2 += s2) + {{ {forward_body} }} + }} + int tx = std::min(1024, out_shape3); + int ty = std::min(1024 / tx, out_shape2); + int bx = (out_shape2 - 1) / ty + 1; + int by = out_shape1; + int bz = out_shape0; + dim3 s1(bx, by, bz); + dim3 s2(tx, ty); + kernel1<<>>(@ARGS); + ''', + cuda_grad_src=[f''' + __global__ static void kernel3(@ARGS_DEF) {{ + @PRECALC + int p3 = threadIdx.x; + int s3 = blockDim.x; + int p2 = threadIdx.y + blockIdx.x * blockDim.y; + int s2 = blockDim.y * gridDim.x; + int i1 = blockIdx.y; + int i0 = blockIdx.z; + for (int i3 = p3; i3 < pout_shape3; i3 += s3) + for (int i2 = p2; i2 < pout_shape2; i2 += s2) + {{ {backward_body} }} + }} + cudaMemsetAsync(out_p, 0, out->size); + int tx = std::min(1024, pout_shape3); + int ty = std::min(1024 / tx, pout_shape2); + int bx = (pout_shape2 - 1) / ty + 1; + int by = pout_shape1; + int bz = pout_shape0; + dim3 s1_(bx, by, bz); + dim3 s2_(tx, ty); + kernel3<<>>(@ARGS); + '''], + cpu_header='', + cpu_src=f''' + using namespace std; + for (int i0=0; i0size); + #define atomicAdd(a,b) (*a) += b + + for (int i0=0; i0 + """, + cuda_src=f''' + __global__ static void kernel1(@ARGS_DEF) {{ + @PRECALC + int p4 = threadIdx.x; + int s4 = blockDim.x; + int p3 = threadIdx.y; + int s3 = blockDim.y; + int p2 = threadIdx.z + blockIdx.x * blockDim.z; + int s2 = blockDim.z * gridDim.x; + int i1 = blockIdx.y; + int i0 = blockIdx.z; + for (int i4 = p4; i4 < out_shape4; i4 += s4) + for (int i3 = p3; i3 < out_shape3; i3 += s3) + for (int i2 = p2; i2 < out_shape2; i2 += s2) + {{ {forward_body} }} + }} + int tx = std::min(1024, out_shape4); + int ty = std::min(1024 / tx, out_shape3); + int tz = std::min(1024 / tx / ty, out_shape2); + int bx = (out_shape2 - 1) / tz + 1; + int by = out_shape1; + int bz = out_shape0; + dim3 s1(bx, by, bz); + dim3 s2(tx, ty, tz); + kernel1<<>>(@ARGS); + ''', + cuda_grad_src=[f''' + __global__ static void kernel3(@ARGS_DEF) {{ + @PRECALC + int p4 = threadIdx.x; + int s4 = blockDim.x; + int p3 = threadIdx.y; + int s3 = blockDim.y; + int p2 = threadIdx.z + blockIdx.x * blockDim.z; + int s2 = blockDim.z * gridDim.x; + int i1 = blockIdx.y; + int i0 = blockIdx.z; + for (int i4 = p4; i4 < out_shape4; i4 += s4) + for (int i3 = p3; i3 < out_shape3; i3 += s3) + for (int i2 = p2; i2 < out_shape2; i2 += s2) + {{ {backward_body} }} + }} + cudaMemsetAsync(out_p, 0, out->size); + int tx = std::min(1024, pout_shape4); + int ty = std::min(1024 / tx, pout_shape3); + int tz = std::min(1024 / tx / ty, pout_shape2); + int bx = (pout_shape2 - 1) / tz + 1; + int by = pout_shape1; + int bz = pout_shape0; + dim3 s1(bx, by, bz); + dim3 s2(tx, ty, tz); + kernel3<<>>(@ARGS); + '''], + cpu_header='', + cpu_src=f''' + using namespace std; + for (int i0=0; i0size); + #define atomicAdd(a,b) (*a) += b + + for (int i0=0; i0>> import jittor as jt + >>> from jittor import nn + + >>> pool = nn.MaxPool2d(2, stride=2, return_indices=True) + >>> unpool = nn.MaxUnpool2d(2, stride=2) + >>> input = jt.array([[[[ 1., 2, 3, 4,0], + [ 5, 6, 7, 8,0], + [ 9, 10, 11, 12,0], + [13, 14, 15, 16,0], + [0, 0, 0, 0, 0]]]]) + >>> output, indices = pool(input) + >>> unpool(output, indices, output_size=input.shape) + jt.array([[[[ 0., 0., 0., 0., 0.], + [ 0., 6., 0., 8., 0.], + [ 0., 0., 0., 0., 0.], + [ 0., 14., 0., 16., 0.], + [ 0., 0., 0., 0., 0.]]]]) + ''' + def __init__(self, kernel_size, stride=None): + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, int): + stride = (stride, stride) + if stride is None: stride = kernel_size + self.kernel_size = kernel_size + self.stride = stride + if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") + if self.stride[0] <= 0 or self.stride[1] <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {stride}") + + def execute(self, x, id, output_size=None): + b, c, ph, pw = x.shape + kh, kw = self.kernel_size + sh, sw = self.stride + if output_size: + h, w = output_size[-2:] + else: + h, w = ph * sh, pw * sw + if self.stride == self.kernel_size: + x = x.reindex(shape=[b, c, h, w], + indexes=['i0', 'i1', f'i2/{kh}', f'i3/{kw}'], + extras=[id], + overflow_conditions=[ + f'(i2*yshape3+i3) != @e0(i0,i1,i2/{kh},i3/{kw})'], + overflow_value=0) + else: + x = x.reindex_reduce( + op="add", + shape=[b, c, h, w], + indexes=['i0', 'i1', + f'@e0(i0,i1,i2,i3)/xshape3', + f'@e0(i0,i1,i2,i3)%xshape3'], + extras=[id], + ) + return x + +class MaxUnpool3d(Module): + ''' MaxUnpool3d is the invert version of MaxPool3d with indices. + It takes the output index of MaxPool3d as input. + The element will be zero if it is not the max pooled value. + ''' + def __init__(self, kernel_size, stride=None): + if stride is None: stride = kernel_size + kernel_size = _triple(kernel_size) + stride = _triple(stride) + self.kernel_size = kernel_size + self.stride = stride + if self.kernel_size[0] <= 0 or self.kernel_size[1] <= 0 or self.kernel_size[2] <= 0: + raise RuntimeError(f"kernel_size must be greater than zero, but got {kernel_size}") + if self.stride[0] <= 0 or self.stride[1] <= 0 or self.stride[2] <= 0: + raise RuntimeError(f"stride must be greater than zero, but got {stride}") + + def execute(self, x, id, output_size=None): + b, c, pd, ph, pw = x.shape + kd, kh, kw = self.kernel_size + sd, sh, sw = self.stride + if output_size: + d, h, w = output_size[-3:] + else: + d, h, w = pd * sd, ph * sh, pw * sw + if self.stride == self.kernel_size: + x = x.reindex(shape=[b, c, d, h, w], + indexes=['i0', 'i1', f'i2/{kd}', f'i3/{kh}', f'i4/{kw}'], + extras=[id], + overflow_conditions=[ + f'(i2*yshape3*yshape4+i3*yshape4+i4) != @e0(i0,i1,i2/{kd},i3/{kh},i4/{kw})'], + overflow_value=0) + else: + x = x.reindex_reduce( + op="add", + shape=[b, c, d, h, w], + indexes=['i0', 'i1', + f'@e0(i0,i1,i2,i3,i4)/(xshape4*xshape3)', + f'@e0(i0,i1,i2,i3,i4)/xshape4%xshape3', + f'@e0(i0,i1,i2,i3,i4)%xshape4'], + extras=[id], + ) + return x diff --git a/python/jittor/pyjt_compiler.py b/python/jittor/pyjt_compiler.py new file mode 100644 index 00000000..c003d864 --- /dev/null +++ b/python/jittor/pyjt_compiler.py @@ -0,0 +1,938 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import re +import os +from jittor_utils import LOG, run_cmd, simple_timer +import json +from collections import OrderedDict +import glob + +def parse_attrs(s): + '''parse @attrs(..., x=y) syntax''' + attrs = {} + if s is None: return attrs + for a in s.split(','): + a = a.strip() + if len(a)==0: continue + if '=' in a: + k, v = a.split('=') + attrs[k] = v + else: + attrs[a] = 1 + return attrs + + +pytype_map = { + "const char*": ["PyUnicode_AsUTF8", "PyUnicode_FromString", "PyUnicode_CheckExact"], + "int": ["PyLong_AsLong", "PyLong_FromLong", "PyLong_CheckExact"], + "int64": ["PyLong_AsLongLong", "PyLong_FromLongLong", "PyLong_CheckExact"], + "uint": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"], + "uint8": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"], + "uint16": ["PyLong_AsUnsignedLong", "PyLong_FromUnsignedLong", "PyLong_CheckExact"], + "uint64": ["PyLong_AsUnsignedLongLong", "PyLong_FromUnsignedLongLong", "PyLong_CheckExact"], + "void": ["...", "GET_PY_NONE", "..."], + "PyObject*": ["","",""], +} +def get_pytype_map(T, i): + assert T != "" + if T in pytype_map: + return pytype_map[T][i] + return ["from_py_object", "to_py_object", "is_type"][i]+"<"+T+">" + +binary_number_slots = { + "__add__": "nb_add", + "__sub__": "nb_subtract", + "__mul__": "nb_multiply", + "__mod__": "nb_remainder", + "__divmod__": "nb_divmod", + "__pow__": "nb_power", + "__lshift__": "nb_lshift", + "__rshift__": "nb_rshift", + "__and__": "nb_and", + "__xor__": "nb_xor", + "__or__": "nb_or", + "__floordiv__": "nb_floor_divide", + "__truediv__": "nb_true_divide", + "__matmul__": "nb_matrix_multiply", +} + +for k,v in list(binary_number_slots.items()): + # __add__: nb_add ----> __iadd: nb_inplace_add + binary_number_slots["__i"+k[2:]] = "nb_inplace"+v[2:] + +unary_number_slots = { + "__neg__": "nb_negative", + "__abs__": "nb_absolute", +} + +def split_args(s): + # split args xxx,xxx, xx, xx + s = s.strip() + if s=="": return [] + prev = -1 + presum = 0 + args = [] + for i in range(len(s)): + if s[i]=='<': + presum += 1 + elif s[i]=='>': + presum -= 1 + if presum==0 and s[i]==',': + args.append(s[prev+1:i]) + prev = i + args.append(s[prev+1:]) + return args + +def get_def_code(df, scope_name, pyname, self_as_arg0=False): + is_fast_call = not pyname.startswith("__") + no_need_convert = pyname == "__getitem__" + args = df["args"] + # n==1 && PyXXX__CheckExact(args[0]) && ... + max_args = len(args) + min_args = max_args + for tid, a in enumerate(args): + if a[2] != "": + min_args = tid + break + arg_names = [ f"args[{i}]" for i in range(len(args))] + if self_as_arg0: + max_args -= 1 + min_args -= 1 + arg_names = ["self"] + arg_names[:-1] + kw_args_id = [] + for aid, arg in enumerate(args): + if "VarHolder*" != arg[0] and is_fast_call: + kw_args_id.append(aid) + func_quick_check_runable = "" + func_quick_check_size = f"n<={max_args} && n>={min_args}" + if len(kw_args_id): + func_quick_check_size = f"n+(kw?Py_SIZE(kw):0)<={max_args} && n+(kw?Py_SIZE(kw):0)>={min_args}" + fill_with_default = "" + func_args_convert = "" + func_call = df["func_name"]+"(" + pytypes = [ get_pytype_map(a[0],0) for a in args ] + holder_dec_array = [] + holder_set_array = [] + for tid, tpc in enumerate(pytypes): + check = get_pytype_map(args[tid][0],2) + default_arg = args[tid][2] + jtp = args[tid][0] + holder_dec = "" + holder_set = "" + if jtp == "VarHolder*": + holder_dec = f"unique_ptr arg{tid}_holder" + holder_set = f", arg{tid}_holder" + if jtp == "VarSlices": + holder_dec = f"vector> arg{tid}_holder" + holder_set = f", arg{tid}_holder" + holder_dec_array.append(holder_dec) + holder_set_array.append(holder_set) + if len(default_arg): + func_args_convert += f""" + {holder_dec}; + {jtp} arg{tid}; + if (n>{tid-self_as_arg0}) {{ + CHECK(({check}({arg_names[tid]}))); + arg{tid} = {tpc}({arg_names[tid]}{holder_set}); + arg_filled |= 1ull << {tid}; + }} + """ + fill_with_default += f""" + if (!(arg_filled & (1ull<<{tid}))) {{ + arg{tid} = {default_arg}; + }} + """ + else: + func_quick_check_runable += f" && {check}({arg_names[tid]})" + func_args_convert += f""" + {holder_dec}; + {jtp} arg{tid} = {tpc}({arg_names[tid]}{holder_set}); + """ + if tid: func_call += "," + if args[tid][3].endswith("&&"): + func_call += f"move(arg{tid})" + else: + func_call += f"arg{tid}" + if pyname == "__richcmp__": + for rname in [ "__lt__", "__le__", "__gt__", + "__ge__", "__eq__", "__ne__"]: + if rname in df["attrs"]: + func_quick_check_runable += " && op==Py_"+rname[2:-2].upper() + # fill args with keyword arguments + fill_with_kw = "" + if is_fast_call and len(kw_args_id): + fill_with_kw = f""" + if (kw) {{ + auto kw_n = Py_SIZE(kw); + for (int i=0; ixxx if is class def + if df["is_scope_def"]: + if df["is_static"]: + func_call = f"{scope_name}::" + func_call + else: + func_call = f"(GET_RAW_PTR({scope_name},self))->" + func_call + if pyname == "__init__": + # XXX->xxx(...) ---> new XXX xxx(...) + assert "->" in func_call, func_call + func_call = "new " + func_call.replace("->", " ") + if no_need_convert: + func_quick_check_runable = "" + func_args_convert = "" + fill_with_kw = fill_with_default = "" + return ( + func_quick_check_size + func_quick_check_runable, + func_args_convert, + fill_with_kw+fill_with_default, + func_call, + has_return + ) + +hash_to_key_map = {} + +def get_hash(s): + mask = (1<<32)-1 + v=0 + mul = 1 + for c in s: + v += mul * ord(c) + mul *= 55 + v &= mask + mul &= mask + if v in hash_to_key_map: + assert hash_to_key_map[v] == s, \ + f"hash conflict {hash_to_key_map[v]} {s} {hash_to_key_map}" + hash_to_key_map[v] = s + return v + +def get_hash_condition(s): + if s == "keepdims": + return f"khash == {get_hash(s)}u || khash == {get_hash('keepdim')}u" + return f"khash == {get_hash(s)}u" + +reg = re.compile( + '(/\\*(.*?)\\*/\\s*)?(//\\s*@pyjt\\(([^\\n]*)\\)\\s*)' + # ^^^^^^^^^^^^^^^^^ ^^^^ ^^^^ + # doc string $1 pyjt args $3 + + + '(//\\s*@attrs\\(([^\\n]*)\\)\\s*)?' + # ^^^^^ ^^^^^^^ + # attrs args $5 +, re.DOTALL) + +def generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename, h, class_info): + # func_head is a string like: + # (PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* + lib_name = os.path.basename(h).split("_")[0] + # TODO: fix/add var help + if target_scope_name == "Var": target_scope_name = None + if target_scope_name: + if target_scope_name == "flags": + help_name = "flags" + else: + help_name = ""+target_scope_name+'.'+name + else: + help_name = name + if lib_name in ["mpi", "nccl", "cudnn", "curand" "cufft", "cublas", "mkl"]: + help_name = lib_name+'.'+help_name + help_cmd = f"help(jt.{help_name})" + + LOG.vvv("gen err from func_head", func_head) + args = func_head[1:].split(")")[0].split(",") + error_code = f" << \"Wrong inputs arguments, Please refer to examples({help_cmd}).\"" + error_code += r' << "\n\nTypes of your inputs are:\n"' + for arg in args: + arg = arg.strip() + if arg.startswith("PyObject* "): + t, n = arg.split(' ') + if n == "args" or n == "_args": + error_code += f" << PyTupleArgPrinter{{{n}, \"args\"}} " + elif n == "kw": + error_code += f" << PyKwArgPrinter{{{n}}} " + else: + error_code += f" << PyArgPrinter{{{n}, \"{n}\"}} " + elif arg.startswith("PyObject** "): + t, n = arg.split(' ') + error_code += f" << PyFastCallArgPrinter{{{n}, n, kw}} " + break + else: + LOG.vvv("Unhandled arg", arg) + LOG.vvv("gen err from func_head", func_head, " -> ", error_code) + return error_code + +def compile_src(src, h, basename): + res = list(reg.finditer(src, re.S)) + if len(res)==0: return + class_ranges = None + class_name = None + class_info = None + submodule_name = None + submodule_ranges = None + submodule_info = None + defs = [] + LOG.vv(("find in", h)) + for x in res: + LOG.vvv((x, x.groups())) + g = x.groups() + doc = g[1] + pyjt = g[3] + attrs = g[5] + esplit = lambda x: [] if x==None else \ + [ a.strip() for a in x.split(",") if len(a.strip()) ] + attrs = parse_attrs(attrs) + pynames = esplit(pyjt) + end = x.end() + def find_bc(i): + while src[i] not in "({;": + i += 1 + j = i+1 + if src[i]==';': + return i, j + presum = 1 + while True: + if src[j] in "({[": + presum += 1 + elif src[j] in ")}]": + presum -= 1 + if presum==0: + s = src[i]+src[j] + assert s in ("()","{}","()"), "braces not match "+s + return i, j + j += 1 + # // @pyjt(DType) + # struct DType { + # ^ --> a + # ..... + # } <--- b + # or + # // @pyjt(hash) + # inline uint hash(const char* input) + # ^ --> a ^ --> b + a, b = find_bc(end) + is_property = 0 + if src[a] == ';': + # This case + # class XXX { + # // @pyjt(property) + # T property; + # } + is_property = 1 + if src[a] == '{': + assert len(pynames)==1 + if "submodule" in attrs: + assert submodule_ranges==None + submodule_ranges = (a, b) + submodule_name = src[end:a-1].strip().split()[-1] + submodule_info = { + "pynames": pynames, + "attrs": attrs + } + continue + assert class_ranges==None + class_ranges = (a, b) + class_name = src[end:a-1].strip().split()[-1] + class_info = { + "pynames": pynames, + "attrs": attrs + } + continue + is_scope_def = False + is_static = False + scope_name = "" + if class_ranges != None: + if class_ranges[0] < a and a < class_ranges[1]: + is_scope_def = True + scope_name = class_name + if submodule_ranges != None: + if submodule_ranges[0] < a and a < submodule_ranges[1]: + is_scope_def = True + scope_name = submodule_name + is_static = True + dec = src[end:b+1].strip() + arr = src[end:a].strip().split() + func_name = arr[-1] + + is_constructor = False + if is_scope_def and func_name==class_name: + is_constructor = True + + args = [] + for arg in split_args(src[a+1:b]): + if arg=="": continue + default = "" + if "=" in arg: + arg, default = arg.split('=') + default = default + arg = arg.strip() + name = arg.split(' ')[-1] + tp = arg[:-len(name)] + tp = tp.strip() + prev_tp = tp + # const string& ----> string + if tp.startswith("const") and tp.endswith("&"): + tp = tp[5:-1].strip() + # T&& -> T + if tp.endswith("&&"): + tp = tp[:-2].strip() + # ArrayArgs& -> ArrayArgs + if tp.endswith("&"): + tp = tp[:-1].strip() + args.append((tp, name.strip(), default.strip(), prev_tp)) + return_t = "" + for a in arr[:-1]: + if a in ["", "inline", "constexpr"]: continue + if a == "static": + is_static = True + continue + if return_t != "": return_t += " " + return_t += a + + if is_scope_def and class_info and "submodule" in class_info["attrs"]: + is_static = True + + for pid, pyname in enumerate(pynames): + for rname in [ "__lt__", "__le__", "__gt__", + "__ge__", "__eq__", "__ne__"]: + if pyname.endswith(rname): + attrs[rname] = 1 + pynames[pid] = pyname.replace(rname, "__richcmp__") + + def_info = { + "is_scope_def": is_scope_def, + "is_constructor": is_constructor, + "is_static": is_static, + "is_property": is_property, + "func_name": func_name, + "args": args, # [(type,name,defaut), ...] + "return_t": return_t, # return type + "dec": dec, # full string of xxx(A a, B b) + "pynames": pynames, # names in @pyjt(...) + "attrs": attrs, # attrs in @attrs(...) + "doc": doc, + "scope_name": scope_name, + } + if is_property: + # This case + # class XXX { + # // @pyjt(property) + # T property; + # } + assert is_scope_def and not is_static + def_info["is_property"] = 1 + def_info["pynames"] = ["__get__"+n for n in pynames] + assert return_t != "void" + defs.append(dict(def_info)) + def_info["pynames"] = ["__set__"+n for n in pynames] + assert len(args) == 0 + def_info["args"] = [(def_info["return_t"], func_name, "", "")] + def_info["return_t"] = "void" + defs.append(dict(def_info)) + continue + else: + defs.append(def_info) + LOG.vvv(lambda: json.dumps(def_info, indent=4)) + # deal with defs + if len(defs) == 0: return + # include_name = h[4:] # remove "src/" prefix + include_name = h + code = [] + class_defs_code = [] + class_getsets_code = [] + class_gets = OrderedDict() + class_sets = OrderedDict() + class_slots_code = [] + submodule_defs_code = [] + def_targets = OrderedDict() + has_attr_dict = class_name in ["VarHolder"] + for df in defs: + for name in df["pynames"]: + if df["is_scope_def"] and '.' not in name: + if df["scope_name"] == class_name: + name = class_info["pynames"][0] + '.' + name + else: + name = submodule_info["pynames"][0] + '.' + name + if name not in def_targets: + def_targets[name] = [] + def_targets[name].append(df) + for name in def_targets: + dfs = def_targets[name] + target_scope_name = None + LOG.vv(name) + if "." in name: + target_scope_name, name = name.split(".") + # array for each df: + arr_func_quick_check_runable = [] + arr_func_args_convert = [] + arr_fill_with_default = [] + arr_func_call = [] + arr_has_return = [] + self_as_arg0 = False + for df in dfs: + self_as_arg0 = class_info and \ + target_scope_name == class_info["pynames"][0] and \ + df["scope_name"] == submodule_name \ + and not name.startswith("__") + res = get_def_code(df, df["scope_name"], name, bool(self_as_arg0)) + arr_func_quick_check_runable.append(res[0]) + arr_func_args_convert.append(res[1]) + arr_fill_with_default.append(res[2]) + arr_func_call.append(res[3]) + arr_has_return.append(res[4]) + + slot_name = None + func_cast = "" + func_fill = "" + before_return = "" + if name == "__init__": + slot_name = "tp_init" + func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> int" + func_fill = """ + int64 n = Py_SIZE(_args); + auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0); + (void)n, (void)args; + // TODO: support kw + CHECK(kw==0); + """ + if has_attr_dict: + func_fill += f"((PyObject**)(((char*)self) + sizeof(PyObject) + sizeof({class_name})))[0] = PyDict_New(); " + + elif name == "__repr__": + slot_name = "tp_repr" + func_head = "(PyObject* self) -> PyObject*" + func_fill = "int64 n = 0; (void)n;" + + elif name.startswith("__get__"): + slot_name = "tp_gets" + name = name[len("__get__"):] + func_head = "(PyObject* self, void*) -> PyObject*" + func_fill = "int64 n = 0; (void)n;" + + elif name.startswith("__set__"): + slot_name = "tp_sets" + name = name[len("__set__"):] + func_head = "(PyObject* self, PyObject* arg, void*) -> int" + func_fill = """ + int64 n=1; + PyObject** args = &arg; + (void)n, (void)args; + """ + + elif name == "__call__": + slot_name = "tp_call" + func_head = "(PyObject* self, PyObject* _args, PyObject* kw) -> PyObject*" + func_fill = """ + int64 n = Py_SIZE(_args); + auto args = (PyObject**)&PyTuple_GET_ITEM(_args, 0); + (void)n, (void)args; + // TODO: support kw + CHECK(kw==0); + """ + + elif name == "__dealloc__": + slot_name = "tp_dealloc" + func_head = "(PyObject* self) -> void" + func_fill = "int64 n = 0" + before_return = "Py_TYPE(self)->tp_free((PyObject *) self);" + if has_attr_dict: + before_return = f"Py_XDECREF(((PyObject**)(((char*)self) + sizeof(PyObject) + sizeof({class_name})))[0]);" + before_return + + elif name in binary_number_slots: + slot_name = "tp_as_number->"+binary_number_slots[name] + func_head = "(PyObject* self, PyObject* b) -> PyObject*" + if name.endswith("pow__"): + func_head = "(PyObject* self, PyObject* b, PyObject*) -> PyObject*" + func_fill = """ + int64 n = 2; + PyObject* args[] = {self, b}; + (void)n, (void)args; + """ + + elif name in unary_number_slots: + slot_name = "tp_as_number->"+unary_number_slots[name] + func_head = "(PyObject* self) -> PyObject*" + func_fill = """ + int64 n = 1; + PyObject* args[] = {self}; + (void)n, (void)args; + """ + + elif name == "__str__": + slot_name = "tp_str" + func_head = "(PyObject* self) -> PyObject*" + func_fill = """ + int64 n = 0; + PyObject* args[] = {self}; + (void)n, (void)args; + """ + + elif name == "__richcmp__": + slot_name = "tp_richcompare" + func_head = "(PyObject* self, PyObject* b, int op) -> PyObject*" + func_fill = """ + int64 n = 2; + PyObject* args[] = {self, b}; + (void)n, (void)args; + """ + + elif name == "__len__": + slot_name = "tp_as_sequence->sq_length" + func_head = "(PyObject* self) -> Py_ssize_t" + func_fill = """ + int64 n = 0; + (void)n; + """ + + elif name == "__map_len__": + slot_name = "tp_as_mapping->mp_length" + func_head = "(PyObject* self) -> Py_ssize_t" + func_fill = """ + int64 n = 0; + (void)n; + """ + + elif name == "__getitem__": + slot_name = "tp_as_sequence->sq_item" + func_head = "(PyObject* self, Py_ssize_t arg0) -> PyObject*" + func_fill = f""" + int64 n = 1; + (void)n; + if (arg0 >= GET_RAW_PTR({dfs[0]["scope_name"]},self)->size()) {{ + PyErr_SetString(PyExc_IndexError, ""); + return (PyObject*)nullptr; + }} + """ + + elif name == "__map_getitem__": + slot_name = "tp_as_mapping->mp_subscript" + func_head = "(PyObject* self, PyObject* arg0) -> PyObject*" + func_fill = f""" + int64 n = 1; + PyObject* args[] = {{arg0}}; + (void)n; + """ + + elif name.startswith("__"): + LOG.f(f"Not support slot {name}") + continue + + else: + func_head = "(PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject*" + func_cast = f"(PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))" + # if not return, return py_none + arr_has_return = [ True for _ in arr_has_return ] + + arr_func_return = [] + doc_all = "" + decs = "The function declarations are:\n" + for did, has_return in enumerate(arr_has_return): + df = dfs[did] + func_call = arr_func_call[did] + if df["doc"] and not (did > 0 and df["doc"] == dfs[did - 1]["doc"]): + doc_all += "Document:\n" + doc_all += df["doc"]+'\n' + doc_all += "Declaration:\n" + doc_all += df["dec"]+'\n\n' + decs += " " + df["dec"]+'\n' + if has_return: + assert "-> int" not in func_head + if "-> PyObject*" in func_head: + if "return_self" in df["attrs"]: + arr_func_return.append( + f"return (({func_call}), Py_INCREF(self), self)") + else: + arr_func_return.append( + f"return {get_pytype_map(df['return_t'],1)}(({func_call}))") + func_return_failed = "return nullptr" + else: + arr_func_return.append( + f"return ({func_call});") + func_return_failed = "return -1" + else: + if "-> int" in func_head: + arr_func_return.append(f"return ({func_call},0)") + func_return_failed = "return -1" + else: + assert "-> void" in func_head, func_head + arr_func_return.append(f"{func_call};{before_return}return") + func_return_failed = "return" + # generate error msg when not a valid call + error_log_code = generate_error_code_from_func_header(func_head, target_scope_name, name, dfs, basename ,h, class_info) + func = f""" + {func_cast}[]{func_head} {{ + try {{ + {func_fill}; + uint64 arg_filled=0; + (void)arg_filled; + {"".join([f''' + if ({arr_func_quick_check_runable[did]}) {{ + {arr_func_args_convert[did]}; + {arr_fill_with_default[did]}; + {arr_func_return[did]}; + }} + ''' + for did in range(len(arr_func_return)) + ])} + LOGf << "Not a valid call."; + }} catch (const std::exception& e) {{ + if (!PyErr_Occurred()) {{ + std::stringstream ss; + if (check_async_executor_error(e, ss)) {{ + PyErr_Format(PyExc_RuntimeError, + "%s", + ss.str().c_str() + ); + }} else {{ + ss {error_log_code}; + PyErr_Format(PyExc_RuntimeError, + "%s\\n%s\\nFailed reason:%s", + ss.str().c_str(), + R""({decs})"", + e.what() + ); + }} + }} + }} + {func_return_failed}; + }} + """ + + if slot_name: + if slot_name=="tp_gets": + class_gets[name] = { + "func": func, + "doc": doc_all + } + continue + if slot_name=="tp_sets": + class_sets[name] = { + "func": func, + "doc": "" + } + continue + class_slots_code.append(f""" + tp.{slot_name} = {func}; + """) + continue + need_static = "" + if df["is_scope_def"] and df["is_static"] and \ + df["scope_name"] == class_name and \ + "submodule" not in class_info["attrs"]: + need_static = " | METH_STATIC" + func = (f""" + {{ R""({name})"", + {func}, + METH_FASTCALL | METH_KEYWORDS{need_static}, + R""({doc_all})"" + }}""") + if df["is_scope_def"]: + if df["scope_name"] == class_name or \ + (class_info and \ + target_scope_name == class_info["pynames"][0]): + class_defs_code.append(func) + else: + submodule_defs_code.append(func) + else: + code.append(func) + prop_names = list(set(class_gets.keys()).union(class_sets.keys())) + prop_names = sorted(prop_names) + for prop_name in prop_names: + get_func = "NULL" + set_func = "NULL" + doc = "" + if prop_name in class_gets: + get_func = class_gets[prop_name]["func"] + if class_gets[prop_name]["doc"]: + doc += class_gets[prop_name]["doc"] + if prop_name in class_sets: + set_func = class_sets[prop_name]["func"] + if class_sets[prop_name]["doc"]: + doc += class_sets[prop_name]["doc"] + class_getsets_code.append(f""" + {{"{prop_name}", {get_func}, {set_func}, R""({doc})""}} + """) + code.append("{0,0,0,0}") + class_defs_code.append("{0,0,0,0}") + class_getsets_code.append("{0,0,0,0}") + submodule_defs_code.append("{0,0,0,0}") + core_name = "jittor_core" + if class_info and "attrs" in class_info and "core_name" in class_info["attrs"]: + core_name = class_info["attrs"]["core_name"] + if submodule_info and "attrs" in submodule_info and "core_name" in submodule_info["attrs"]: + core_name = submodule_info["attrs"]["core_name"] + has_map = class_name in ["VarHolder", "NanoVector"] + has_seq = class_name in ["VarHolder", "NanoVector"] + # add extra include to avoid compile error + src_code = "" + if include_name.endswith("var_slices.h"): + src_code += '#include "var_holder.h"\n' + src_code += f""" + #include "utils/seh.h" + #include "pyjt/py_converter.h" + #include "pyjt/py_arg_printer.h" + #include "common.h" + #include "{include_name}" + + namespace jittor {{ + + { + "" if class_name is None else + f"PyHeapTypeObject Pyjt{class_name};" if "heaptype" in class_info["attrs"] else + f"PyTypeObject Pyjt{class_name};" + } + + void pyjt_def_{basename}(PyObject* m) {{ + static PyMethodDef defs[] = {{ + {",".join(code)} + }}; + ASSERT(PyModule_AddFunctions(m, defs)==0); + { + f''' + static PyMethodDef class_defs[] = {{ + {",".join(class_defs_code)} + }}; + static PyGetSetDef class_getsets[] = {{ + {",".join(class_getsets_code)} + }}; + + static PyNumberMethods number_methods = {{0}}; + {f"auto& htp =Pyjt{class_name}; auto& tp = htp.ht_type;" + if "heaptype" in class_info["attrs"] else + f"auto& tp = Pyjt{class_name};"} + tp.tp_as_number = &number_methods; + + {f"static PyMappingMethods class_map_defs = {{0}};" if has_map else ""} + {f"tp.tp_as_mapping = &class_map_defs;" if has_map else ""} + + {f"static PySequenceMethods class_seq_defs = {{0}};" if has_seq else ""} + {f"tp.tp_as_sequence = &class_seq_defs;" if has_seq else ""} + + tp.tp_name = "{core_name}.{class_info["pynames"][0]}"; + tp.tp_basicsize = GET_OBJ_SIZE({class_name}); + {f"tp.tp_dictoffset = tp.tp_basicsize; tp.tp_basicsize += sizeof(PyObject*); " if has_attr_dict else ""} + tp.tp_new = PyType_GenericNew; + tp.tp_flags = Py_TPFLAGS_DEFAULT; + {"tp.tp_flags |= Py_TPFLAGS_HEAPTYPE; htp.ht_name = htp.ht_qualname = to_py_object(tp.tp_name);" + if "heaptype" in class_info["attrs"] else ""} + tp.tp_methods = &class_defs[0]; + tp.tp_getset = &class_getsets[0]; + {"".join(class_slots_code)}; + ASSERT(0==PyType_Ready(&tp)) << (PyErr_Print(), 0); + Py_INCREF(&tp); + ASSERT(0==PyModule_AddObject(m, "{class_info["pynames"][0]}", (PyObject*)&tp)); + ''' if class_name is not None else "" + } + {f''' + + // sub module def + static PyMethodDef submodule_defs[] = {{ + {",".join(submodule_defs_code)} + }}; + auto sub = PyImport_AddModule("{core_name}.{submodule_info["pynames"][0]}"); + ASSERT(PyModule_AddFunctions(sub, submodule_defs)==0); + ASSERT(sub); + ASSERT(0==PyModule_AddObject(m, "{submodule_info["pynames"][0]}", sub)); + ''' if submodule_name is not None else "" + } + + }} + + }} + """ + return src_code + +def compile_single(head_file_name, src_file_name, src=None): + basename = os.path.basename(head_file_name).split(".")[0] + if src==None: + with open(head_file_name, 'r', encoding='utf8') as f: + src = f.read() + code = compile_src(src, head_file_name, basename) + if not code: return False + LOG.vvv("write to", src_file_name) + LOG.vvvv(code) + with open(src_file_name, 'w', encoding='utf8') as f: + f.write(code) + return True + +def compile(cache_path, jittor_path): + headers1 = glob.glob(jittor_path+"/src/**/*.h", recursive=True) + headers2 = glob.glob(cache_path+"/gen/**/*.h", recursive=True) + headers = headers1 + headers2 + basenames = [] + pyjt_names = [] + for h in headers: + with open(h, 'r', encoding='utf8') as f: + src = f.read() + + bh = os.path.basename(h) + # jit_op_maker.h merge compile with var_holder.h + if bh == "var_holder.h": continue + if bh == "jit_op_maker.h": + with open(os.path.join(jittor_path, "src", "var_holder.h"), "r", encoding='utf8') as f: + src = f.read() + src + basename = bh.split(".")[0] + fname = "pyjt_"+basename+".cc" + fname = os.path.join(cache_path, "gen", fname) + check = compile_single(h, fname, src) + + if not check: continue + + basenames.append(basename) + pyjt_names.append(fname) + + code = f""" + #include "pyjt/py_converter.h" + #include "common.h" + + namespace jittor {{ + + { " ".join([f"extern void pyjt_def_{n}(PyObject* m);" for n in basenames])} + + void pyjt_def_all(PyObject* m) {{ + { " ".join([f"pyjt_def_{n}(m);" for n in basenames])} + }} + + }} + """ + fname = os.path.join(cache_path, "gen", "pyjt_all.cc") + LOG.vvv(("write to", fname)) + LOG.vvvv(code) + with open(fname, "w", encoding='utf8') as f: + f.write(code) + pyjt_names.append(fname) + return pyjt_names diff --git a/python/jittor/script/Dockerfile_cuda11 b/python/jittor/script/Dockerfile_cuda11 new file mode 100644 index 00000000..8e395da4 --- /dev/null +++ b/python/jittor/script/Dockerfile_cuda11 @@ -0,0 +1,50 @@ +# docker build commands +ARG FROM_IMAGE=nvidia/cuda:11.1.1-cudnn8-devel-ubuntu20.04 + +FROM ${FROM_IMAGE} + +RUN rm /etc/apt/sources.list.d/cuda.list + +RUN apt update && apt install ca-certificates -y + +# change tsinghua mirror +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse" > /etc/apt/sources.list + +RUN apt update && apt install wget \ + python3 python3-dev python3-pip \ + g++ build-essential -y + +WORKDIR /usr/src/jittor + +ENV PYTHONIOENCODING utf8 +ENV DEBIAN_FRONTEND noninteractive + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install \ + numpy \ + tqdm \ + pillow \ + astunparse \ + notebook + +RUN pip3 install matplotlib + +RUN apt install openmpi-bin openmpi-common libopenmpi-dev -y + +RUN pip3 install jittor --timeout 100 && python3 -m jittor.test.test_example + +RUN pip3 uninstall jittor -y + +COPY . . + +RUN pip3 install . --timeout 100 + +RUN python3 -m jittor.test.test_example + +CMD python3 -m jittor.notebook --allow-root --ip=0.0.0.0 \ No newline at end of file diff --git a/python/jittor/script/build_aarch64_mkl.sh b/python/jittor/script/build_aarch64_mkl.sh new file mode 100644 index 00000000..3061ea34 --- /dev/null +++ b/python/jittor/script/build_aarch64_mkl.sh @@ -0,0 +1,31 @@ +# wget https://github.com/oneapi-src/oneDNN/archive/refs/tags/v2.2.zip +# extract zip +# cd to root folder + +mkdir -p build +cd build +make clean +export CC=aarch64-linux-gnu-gcc-8 +export CXX=aarch64-linux-gnu-g++-8 +cmake .. \ + -DCMAKE_SYSTEM_NAME=Linux \ + -DCMAKE_SYSTEM_PROCESSOR=AARCH64 \ + -DCMAKE_LIBRARY_PATH=/usr/aarch64-linux-gnu/lib \ + -DCMAKE_BUILD_TYPE=Release + # -DCMAKE_SHARED_LINKER_FLAGS=' -lm ' \ +make -j8 + +name=dnnl_lnx_2.2.0_cpu_gomp_aarch64 +mkdir -p $name +cp -r ../include ./$name/ +mkdir -p ./$name/lib +cp ./src/libmkldnn.so ./$name/lib/libmkldnn.so +cp -r ../examples ./$name/ +cp ./include/oneapi/dnnl/* ./$name/include/oneapi/dnnl/ + +tar -acvf $name.tgz ./$name/ + +rsync -avPu $name.tgz jittor-web:Documents/jittor-blog/assets/ +ssh jittor-web Documents/jittor-blog.git/hooks/post-update +echo "https://cg.cs.tsinghua.edu.cn/jittor/assets/$name.tgz" +md5sum $name.tgz \ No newline at end of file diff --git a/python/jittor/script/converter_server.sh b/python/jittor/script/converter_server.sh new file mode 100644 index 00000000..7768d32e --- /dev/null +++ b/python/jittor/script/converter_server.sh @@ -0,0 +1,14 @@ +cat > /tmp/converter_server.dockerfile <<\EOF +FROM jittor/jittor + +RUN python3.7 -m pip install flask +RUN apt update && apt install git -y +EOF + +docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile . + +# docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server" +while true; do + timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:58187:5000 -v /etc/letsencrypt/:/https jittor/converter_server bash -c "python3.7 -m pip install -U jittor && python3.7 -m jittor.test.test_core && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/live/randonl.me/fullchain.pem --key=/https/live/randonl.me/privkey.pem --host=0.0.0.0" + sleep 10 +done \ No newline at end of file diff --git a/python/jittor/script/inference_perf.py b/python/jittor/script/inference_perf.py new file mode 100644 index 00000000..be9a03fc --- /dev/null +++ b/python/jittor/script/inference_perf.py @@ -0,0 +1,123 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import numpy as np +import jittor as jt +import torch +import time +import jittor.models as jtmodels +import torchvision.models as tcmodels +import os + +jt.flags.use_cuda = 1 +torch.backends.cudnn.deterministic = False +torch.backends.cudnn.benchmark = True +jt.cudnn.set_algorithm_cache_size(10000) + +threshold = 1e-3 + +models = [ + # 'squeezenet1_0', + 'squeezenet1_1', + 'alexnet', + # 'resnet18', + # 'resnet34', + 'resnet50', + # 'resnet101', + 'resnet152', + 'resnext50_32x4d', + 'resnext101_32x8d', + 'vgg11', + # 'vgg11_bn', + # 'vgg13', + # 'vgg13_bn', + # 'vgg16', + # 'vgg16_bn', + # 'vgg19', + # 'vgg19_bn', + 'wide_resnet50_2', + 'wide_resnet101_2', +] + +def to_cuda(x): + if jt.has_cuda: + return x.cuda() + return x + +def test_allmodels(bs=1): + # Define numpy input image + test_img = np.random.random((bs,3,224,224)).astype('float32') + # Define pytorch & jittor input image + pytorch_test_img = to_cuda(torch.Tensor(test_img)) + jittor_test_img = jt.array(test_img) + for model in models: + if model == "inception_v3": + test_img = np.random.random((bs,3,300,300)).astype('float32') + pytorch_test_img = to_cuda(torch.Tensor(test_img)) + jittor_test_img = jt.array(test_img) + + jittor_test_img.stop_grad() + pytorch_test_img.requires_grad = False + + # Define pytorch & jittor model + pytorch_model = to_cuda(tcmodels.__dict__[model]()) + jittor_model = jtmodels.__dict__[model]() + # Set eval to avoid dropout layer + pytorch_model.eval() + jittor_model.eval() + # Jittor loads pytorch parameters to ensure forward alignment + jittor_model.load_parameters(pytorch_model.state_dict()) + + total = 512 + warmup = max(2, total // bs // 8) + rerun = max(2, total // bs) + + print("=" * 20 + model + "=" * 20) + + # Jittor warms up + for i in range(warmup): + jittor_result = jittor_model(jittor_test_img) + jt.sync_all(True) + # Test jittor and once forward time + sta = time.time() + for i in range(rerun): + jittor_result = jittor_model(jittor_test_img) + jittor_result.sync() + jt.sync_all(True) + end = time.time() + print(f"- Jittor {model} forward average time cost: {round((time.time() - sta) / rerun,5)}, Batch Size: {bs}, FPS: {round(bs * rerun / (end - sta),2)}") + + # pytorch warmup + for i in range(warmup): + pytorch_result = pytorch_model(pytorch_test_img) + # Test pytorch and once forward time + torch.cuda.synchronize() + sta = time.time() + for i in range(rerun): + pytorch_result = pytorch_model(pytorch_test_img) + torch.cuda.synchronize() + end = time.time() + print(f"- Pytorch {model} forward average time cost: {round((end - sta) / rerun,5)}, Batch Size: {bs}, FPS: {round(bs * rerun / (end - sta),2)}") + + # Judge pytorch & jittor forward relative error. If the differece is lower than threshold, this test passes. + x = pytorch_result.detach().cpu().numpy() + 1 + y = jittor_result.numpy() + 1 + relative_error = abs(x - y) / abs(y) + diff = relative_error.mean() + assert diff < threshold, f"[*] {model} forward fails..., Relative Error: {diff}" + print(f"[*] {model} forword passes with Relative Error {diff}") + torch.cuda.empty_cache() + jt.clean() + jt.gc() + + +with torch.no_grad(): + for bs in [1,2,4,8,16,32,64,128]: + # for bs in [128]: + test_allmodels(bs) \ No newline at end of file diff --git a/python/jittor/script/install.sh b/python/jittor/script/install.sh new file mode 100755 index 00000000..d919e397 --- /dev/null +++ b/python/jittor/script/install.sh @@ -0,0 +1,70 @@ +#!/bin/bash +# Single line install script +# wget -O - https://raw.githubusercontent.com/Jittor/jittor/master/script/install.sh | with_clang=1 with_cuda=1 bash +set -ex + +if [ "$is_docker" = "1" ]; then +tee /etc/apt/sources.list < /tmp/llvm.sh +sudo bash /tmp/llvm.sh 8 +sudo apt-get install libc++-8-dev libc++abi-8-dev -y +sudo apt-get install libomp-8-dev -y +export cc_path="clang-8" +fi + +if [ "$with_gcc" = "1" ]; then +sudo apt install g++ build-essential libomp-dev -y +export cc_path="g++" +fi + +if [ "$with_icc" = "1" ]; then +export cc_path="icc" +fi + +# Step 2: Install Python and dependency + +if [ "$py_version" = "" ]; then +py_version="3.7" +fi +sudo add-apt-repository ppa:deadsnakes/ppa -y +sudo apt-get update +sudo apt install python$py_version python$py_version-dev -y +# python3.8 need this +# sudo apt install python3.8-distutils +wget -O - https://bootstrap.pypa.io/get-pip.py | sudo -H python$py_version + +# Step 3: Run jittor + +sudo python$py_version -m pip install git+https://github.com/Jittor/jittor.git + +if [ "$with_cuda" = "1" ]; then +export nvcc_path="/usr/local/cuda/bin/nvcc" +fi + +# run a simple test +python$py_version -m jittor.test.test_example + +if [ "$with_cuda" = "1" ]; then +python$py_version -m jittor.test.test_cuda +fi + +set +x +echo "jittor test is passed. Please export the following enviroments value" +echo "---------------------------------------" +echo "export cc_path=$cc_path" +echo "export nvcc_path=$nvcc_path" +echo "---------------------------------------" \ No newline at end of file diff --git a/python/jittor/script/install_llvm.sh b/python/jittor/script/install_llvm.sh new file mode 100644 index 00000000..c367c893 --- /dev/null +++ b/python/jittor/script/install_llvm.sh @@ -0,0 +1,62 @@ +#!/bin/bash +################################################################################ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +################################################################################ +# +# This script will install the llvm toolchain on the different +# Debian and Ubuntu versions + +set -eux + +# read optional command line argument +LLVM_VERSION=8 +if [ "$#" -eq 1 ]; then + LLVM_VERSION=$1 +fi + +DISTRO=$(lsb_release -is) +VERSION=$(lsb_release -sr) +DIST_VERSION="${DISTRO}_${VERSION}" + +if [[ $EUID -ne 0 ]]; then + echo "This script must be run as root!" + exit 1 +fi + +declare -A LLVM_VERSION_PATTERNS +LLVM_VERSION_PATTERNS[8]="-8" +LLVM_VERSION_PATTERNS[9]="-9" +LLVM_VERSION_PATTERNS[10]="" + +if [ ! ${LLVM_VERSION_PATTERNS[$LLVM_VERSION]+_} ]; then + echo "This script does not support LLVM version $LLVM_VERSION" + exit 3 +fi + +LLVM_VERSION_STRING=${LLVM_VERSION_PATTERNS[$LLVM_VERSION]} + +# find the right repository name for the distro and version +case "$DIST_VERSION" in + Debian_9* ) REPO_NAME="deb http://apt.llvm.org/stretch/ llvm-toolchain-stretch$LLVM_VERSION_STRING main" ;; + Debian_10* ) REPO_NAME="deb http://apt.llvm.org/buster/ llvm-toolchain-buster$LLVM_VERSION_STRING main" ;; + Debian_unstable ) REPO_NAME="deb http://apt.llvm.org/unstable/ llvm-toolchain$LLVM_VERSION_STRING main" ;; + Debian_testing ) REPO_NAME="deb http://apt.llvm.org/unstable/ llvm-toolchain$LLVM_VERSION_STRING main" ;; + Ubuntu_16.04 ) REPO_NAME="deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial$LLVM_VERSION_STRING main" ;; + Ubuntu_18.04 ) REPO_NAME="deb http://apt.llvm.org/bionic/ llvm-toolchain-bionic$LLVM_VERSION_STRING main" ;; + Ubuntu_18.10 ) REPO_NAME="deb http://apt.llvm.org/cosmic/ llvm-toolchain-cosmic$LLVM_VERSION_STRING main" ;; + Ubuntu_19.04 ) REPO_NAME="deb http://apt.llvm.org/disco/ llvm-toolchain-disco$LLVM_VERSION_STRING main" ;; + * ) + echo "Distribution '$DISTRO' in version '$VERSION' is not supported by this script (${DIST_VERSION})." + exit 2 +esac + + +cat /etc/apt/sources.list +# install everything +wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | apt-key add - +add-apt-repository "${REPO_NAME}" +cat /etc/apt/sources.list +apt-get update +apt-get install -y clang-$LLVM_VERSION lldb-$LLVM_VERSION lld-$LLVM_VERSION clangd-$LLVM_VERSION diff --git a/python/jittor/script/install_mkl.sh b/python/jittor/script/install_mkl.sh new file mode 100755 index 00000000..8ea47e77 --- /dev/null +++ b/python/jittor/script/install_mkl.sh @@ -0,0 +1,23 @@ +#!/bin/bash +set -xe +if [ "$cache_path" = "" ]; then + bpath=$(dirname "${BASH_SOURCE[0]}") + cd $bpath + cd ../extern/mkl +else + cd $cache_path +fi +filename="mkldnn_lnx_1.0.2_cpu_gomp.tgz" +dirname="mkldnn_lnx_1.0.2_cpu_gomp" +if [ ! -f $filename ]; then + wget https://github.com/intel/mkl-dnn/releases/download/v1.0.2/$filename +fi +if [ ! -d $dirname ]; then + tar zxvf $filename +fi + +if [ ! -f $dirname/examples/test ]; then + echo "compile mkldnn example and test" + cd $dirname/examples + g++ -std=c++14 cpu_cnn_inference_f32.cpp -Ofast -lmkldnn -I ../include -L ../lib -o test && LD_LIBRARY_PATH=../lib/ ./test +fi \ No newline at end of file diff --git a/python/jittor/script/make_doc.py b/python/jittor/script/make_doc.py new file mode 100644 index 00000000..2c2ffcc4 --- /dev/null +++ b/python/jittor/script/make_doc.py @@ -0,0 +1,34 @@ +import os + +def fix_config(in_name, out_name, src_path, out_path): + data = open(in_name, 'r', encoding='utf8').readlines() + out = [] + for d in data: + if d.startswith('INPUT ='): + d = f'INPUT ={src_path}\n' + elif d.startswith('OUTPUT_DIRECTORY ='): + d = f'OUTPUT_DIRECTORY ={out_path}\n' + out.append(d) + f = open(out_name, 'w', encoding='utf8') + f.writelines(out) + +jt_path = os.getcwd() +cache_path = f"{os.environ['HOME']}/.cache/jittor" + +os.system(f"rm -rf {cache_path}/docxygen/jittor") +os.system(f"mkdir -p {cache_path}/docxygen/jittor") +os.chdir(f"{cache_path}/docxygen") +# copy jittor src code +os.system(f"cp -r {jt_path}/src {cache_path}/docxygen/jittor") +os.system(f"cp -r {jt_path}/python {cache_path}/docxygen/jittor") +os.system(f"cp -r {jt_path}/notebook {cache_path}/docxygen/jittor") +os.system(f"cp {jt_path}/README.src.md {cache_path}/docxygen/jittor") +#download doxygen & config file +if not os.path.exists('doxygen-1.8.17'): + os.system("wget -O doxygen.tar.gz https://cloud.tsinghua.edu.cn/f/dfa8f16ab00c4fa6b158/?dl=1") + os.system("wget -O Doxyfile https://cloud.tsinghua.edu.cn/f/caf3c3aa518248d5ad73/?dl=1") + os.system("tar -xzvf doxygen.tar.gz") +#run docxygen +fix_config(f'{cache_path}/docxygen/Doxyfile', f'{cache_path}/docxygen/doxygen-1.8.17/bin/Doxyfile', f'{cache_path}/docxygen/jittor', f'{cache_path}/docxygen') +os.chdir(f"{cache_path}/docxygen/doxygen-1.8.17/bin") +os.system(f'./doxygen Doxyfile') diff --git a/python/jittor/script/tmpi b/python/jittor/script/tmpi new file mode 100755 index 00000000..28b8e521 --- /dev/null +++ b/python/jittor/script/tmpi @@ -0,0 +1,117 @@ +#!/bin/bash + +# Copyright 2013 Benedikt Morbach +# Distributed under the terms of the GNU General Public License v2 + +# runs multiple MPI processes as a grid in a new tmux window and multiplexes keyboard input to all of them + +additional_vars=( LD_LIBRARY_PATH LD_PRELOAD ) +export "${additional_vars[@]}" + +usage() { + echo 'tmpi: Run multiple MPI processes as a grid in a new tmux window and multiplex keyboard input to all of them.' + echo '' + echo 'Usage:' + echo ' tmpi [number] [command]' + echo '' + echo 'You need to pass at least two arguments.' + echo 'The first argument is the number of processes to use, every argument after that is the commandline to run.' + echo 'If you call this script from outside tmux and your command contains important whitespace then you need to appy two levels of quoting to preserve it.' + echo '' + echo 'LD_LIBRARY_PATH and LD_PRELOAD are passed through, so you can run it like this:' + echo 'LD_LIBRARY_PATH="${PWD}/.libs:${LD_LIBRARY_PATH}" tmpi 16 gdb -q bin/.libs/example' + echo '' + echo 'The new window is set to remain on exit and has to be closed manually. ("C-b + k" by default)' +} + +check_tools() { + tools=( tmux mpirun ) + + for tool in "${tools[@]}"; do + if ! which ${tool}; then + echo "You need to install ${tool} to run this script." + fi + done +} + +if [[ ${#} -lt 2 ]]; then + usage + + exit 1 +fi + +if [[ -z ${TMUX} ]]; then + # it seems we aren't in a tmux session. + # start a new one so that our window doesn't end up in some other session and we have to search it. + # actually start a new server with '-L' to ensure that our environment carries over. + socket=$(mktemp --dry-run tmpi.XXXX) + exec tmux -L ${socket} new-session "${0} ${*}" +fi + +if [[ ${1} == runmpi ]] ; then + # we are being started as one of many processes by mpirun. + shift + + # start the processes in the order of their rank. + # this avoids races, as we have to push the variables in tmux' environment. + # it has the nice side-effect that the panes are also ordered by rank. + while [[ $(cat /tmp/tmpi.lock) -ne ${OMPI_COMM_WORLD_RANK} ]] ; do + sleep 0.02 + done + + # get all the variables that mpirun starts us with so that we can pass them through. + mpi_vars=( $( env | grep -e MPI -e OPAL -e PMIX -e PYTHON -e debug -e PATH | cut -d '=' -f1 ) ) + mpi_vars+=( "${additional_vars[@]}" ) + + # add the variables to tmux' session environment. + # we can't just export them because the process will be started as a child of tmux, not us. + for var in "${mpi_vars[@]}"; do + tmux set-environment -t ${session} "${var}" "${!var}" + done + + x=( $(tmux split-window -P -F '#{pane_pid} #{pane_id}' -t ${window} "${*}") ) + pid=${x[0]} + pane=${x[1]} + + for var in "${mpi_vars[@]}"; do + tmux set-environment -t ${session} -u "${var}" + done + + # kill the dummy pane that opened the new window + [[ ${OMPI_COMM_WORLD_RANK} -eq 0 ]] && tmux kill-pane -t ${dummy} &> /dev/null + + # set the window to tiled mode. + # have to do this after every new pane is spawned because otherwise the splits get + # smaller and smaller until tmux refuses to open new panes, despite plenty of space being left. + tmux select-layout -t ${pane} tiled &> /dev/null + + # let the next process start + echo $((${OMPI_COMM_WORLD_RANK}+1)) > /tmp/tmpi.lock + + # don't exit here as mpirun needs to be kept alive and it would also exit. + while [[ -d /proc/${pid} ]]; do + sleep 1 + done +else + # we are the parent and set everything up before we start ourselves a bunch of times via mpirun. + processes=${1} + self=${0} + shift + + # create an empty new dummy window which we sill later split up for the mpi processes. + x=( $(tmux new-window ${session} -P -F '#{pane_id} #{window_id} #{session_id}') ) + export dummy=${x[0]} + export window=${x[1]} + export session=${x[2]} + + # syncronize input to all panes. + tmux set-window-option -t ${window} synchronize-panes on &> /dev/null + tmux set-window-option -t ${window} remain-on-exit on &> /dev/null + + # always start with rank 0. + echo 0 > /tmp/tmpi.lock + + # re-execute ourself to spawn of the processes. + echo mpirun ${HOSTS_ARGS} ${MPI_ARGS} -np ${processes} ${self} runmpi "${@}" + mpirun ${HOSTS_ARGS} ${MPI_ARGS} -np ${processes} ${self} runmpi "${@}" +fi diff --git a/python/jittor/script/update.sh b/python/jittor/script/update.sh new file mode 100755 index 00000000..ac2a9428 --- /dev/null +++ b/python/jittor/script/update.sh @@ -0,0 +1,8 @@ +#!/bin/bash +bpath=$(dirname "${BASH_SOURCE[0]}") +cd $bpath +cd .. +pwd +git fetch --all +git reset --hard origin/master +python3.7 -c "import jittor" \ No newline at end of file diff --git a/python/jittor/sparse.py b/python/jittor/sparse.py new file mode 100644 index 00000000..795622be --- /dev/null +++ b/python/jittor/sparse.py @@ -0,0 +1,54 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# Xiangli Li <190569238@qq.com> +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import jittor as jt +import numpy as np + +class SparseVar: + def __init__(self,indices,values,shape): + assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector) + self.indices = indices + self.values = values + self.shape = shape + self.ndim = len(shape) + + def _indices(self): + return self.indices + + def _values(self): + return self.values + + def t(self): + indices = list(self.indices.split(1,dim=0)) + indices[-1],indices[-2] = indices[-2],indices[-1] + indices = jt.concat(indices,dim=0) + shape = list(self.shape) + shape[-1],shape[-2] = shape[-2],shape[-1] + shape = jt.NanoVector(shape) + return SparseVar(indices,self.values,shape) + + def to_dense(self): + ret = jt.zeros(self.shape,self.values.dtype) + indices = tuple(self.indices.split(1,dim=0)) + ret[indices]=self.values + return ret + +def sparse_array(indices,values,shape): + return SparseVar(indices,values,shape) + +def spmm(spase_x,y): + assert isinstance(spase_x,SparseVar) and isinstance(y,jt.Var) + assert spase_x.ndim==2 and y.ndim==2 and spase_x.shape[-1]==y.shape[0] + + # TODO + x = spase_x.to_dense() + return jt.matmul(x,y) + \ No newline at end of file diff --git a/python/jittor/src/async_queue.h b/python/jittor/src/async_queue.h new file mode 100644 index 00000000..6b510cc9 --- /dev/null +++ b/python/jittor/src/async_queue.h @@ -0,0 +1,77 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include +#include +#include + +namespace jittor { + +struct AsyncQueue { + AsyncQueue() : stop(false), prevTaskCompleted(true) { + worker = std::thread([this]() { this->workerThread(); }); + } + + ~AsyncQueue() { + { + std::unique_lock lock(mutex); + stop = true; + condition.notify_all(); + } + worker.join(); + } + + template + void enqueue(F&& f, Args&&... args) { + { + std::unique_lock lock(mutex); + tasks.push(std::bind(std::forward(f), std::forward(args)...)); + } + condition.notify_one(); + } + + void waitAllTasksComplete() { + std::unique_lock lock(mutex); + condition.wait(lock, [this]() { return tasks.empty() && prevTaskCompleted; }); + } + +private: + void workerThread() { + while (true) { + std::function task; + { + std::unique_lock lock(mutex); + condition.wait(lock, [this]() { return stop || (!tasks.empty() && prevTaskCompleted); }); + if (stop && tasks.empty()) { + return; + } + prevTaskCompleted = false; + task = std::move(tasks.front()); + tasks.pop(); + } + task(); + { + std::lock_guard lock(mutex); + prevTaskCompleted = true; + } + condition.notify_one(); // 完成一个任务后通知等待的线程 + } + } + + std::queue> tasks; + std::thread worker; + std::mutex mutex; + std::condition_variable condition; + bool stop; + bool prevTaskCompleted; +}; + +} // jittor diff --git a/python/jittor/src/common.h b/python/jittor/src/common.h new file mode 100644 index 00000000..58c6a37a --- /dev/null +++ b/python/jittor/src/common.h @@ -0,0 +1,45 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include "utils/log.h" +#include "../extern/acl/aclnn/aclnn.h" + +#define JIT_TEST(name) extern void jit_test_ ## name () +void expect_error(std::function func); + +#define VAR_MEMBER_NAME_AND_OFFSET(name, op) { #name , offsetof(struct op, name) } +#define GET_VAR_MEMBER(op, offset) (*((Var**)(((char*)(op))+(offset)))) + +#ifdef __clang__ +#pragma clang diagnostic ignored "-Winvalid-offsetof" +#pragma clang diagnostic ignored "-Wtautological-compare" +#else +#ifdef __GNUC__ +#pragma GCC diagnostic ignored "-Winvalid-offsetof" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wdiv-by-zero" +#endif +#endif + +#ifdef _WIN32 +#ifndef __restrict__ +#define __restrict__ __restrict +#endif +#endif + +#ifdef _MSC_VER +#define __builtin_popcount __popcnt +#endif + +#ifdef HAS_CUDA +#define _HAS_CUDA 1 +#else +#define _HAS_CUDA 0 +#endif diff --git a/python/jittor/src/core.h b/python/jittor/src/core.h new file mode 100644 index 00000000..76839c32 --- /dev/null +++ b/python/jittor/src/core.h @@ -0,0 +1,38 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once + +#include "var.h" +#include "op.h" +#include "var_holder.h" + +namespace jittor { + +// @pyjt(number_of_hold_vars) +inline static uint64 get_number_of_hold_vars() { + return hold_vars.size(); +} + +// @pyjt(number_of_lived_vars) +inline static int64 get_number_of_lived_vars() { + return Var::number_of_lived_vars; +} + +// @pyjt(number_of_lived_ops) +inline static int64 get_number_of_lived_ops() { + return Op::number_of_lived_ops; +} + +// @pyjt(print_trace) +inline static void __print_trace() { + print_trace(); +} + +// @pyjt(grad) +vector _grad(VarHolder* loss, const vector& targets, bool retain_graph=true); + +} // jittor diff --git a/python/jittor/src/event_queue.cc b/python/jittor/src/event_queue.cc new file mode 100644 index 00000000..0219560a --- /dev/null +++ b/python/jittor/src/event_queue.cc @@ -0,0 +1,60 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "event_queue.h" + +namespace jittor { + +#ifdef HAS_CUDA +EventQueue event_queue; + +void EventQueue::Worker::start() { + Worker* self = &event_queue.worker; + while (1) { + Func todo; + { + std::unique_lock l(self->mtx); + event_queue.cv.notify_one(); + self->cv.wait(l); + todo = self->todo; + } + if (!todo) break; + todo(); + } +} + + +void EventQueue::Worker::stop() { + LOGv << "stoping event queue worker..."; + event_queue.worker.run(nullptr); + event_queue.worker.thread.join(); + LOGv << "stopped event queue worker."; +} + +EXTERN_LIB vector cleanup_callback; + +EventQueue::Worker::Worker() : thread(EventQueue::Worker::start) { + cleanup_callback.push_back(&EventQueue::Worker::stop); +} + +void EventQueue::worker_caller() { + int status = OK; + try { + event_queue.func(); + } catch (const std::exception& e) { + LOGe << "Catch error:\n" >> e.what(); + status = ERROR; + } + { + std::lock_guard l(event_queue.mtx); + event_queue.run_sync_done = status; + } +} + +#endif + + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/event_queue.h b/python/jittor/src/event_queue.h new file mode 100644 index 00000000..0561c8ff --- /dev/null +++ b/python/jittor/src/event_queue.h @@ -0,0 +1,95 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include +#include "common.h" + +namespace jittor { + +#ifdef HAS_CUDA +struct EventQueue { + static constexpr int RUNNING = 0; + static constexpr int OK = 1; + static constexpr int ERROR = 2; + typedef void(*Func)(); + + list tasks; + std::condition_variable cv; + std::mutex mtx; + Func func; + volatile int run_sync_done; + + struct Worker { + Func todo; + std::condition_variable cv; + std::mutex mtx; + std::thread thread; + + static void start(); + static void stop(); + + Worker(); + + inline void run(Func func) { + { + std::lock_guard l(mtx); + todo = func; + } + cv.notify_one(); + } + } worker; + + inline void flush() { + list ts; + { + std::lock_guard g(mtx); + ts = move(tasks); + } + for (auto func : ts) + func(); + } + + static void worker_caller(); + + inline int run_sync(Func func) { + // send work to worker and do something by self + std::unique_lock l(mtx); + this->func = func; + run_sync_done = RUNNING; + // send func to worker + worker.run(worker_caller); + while (1) { + // check self work or worker's status + cv.wait(l); + list ts = move(tasks); + l.unlock(); + // do self works + for (auto func : ts) + func(); + l.lock(); + // worker is finished + if (int ret = run_sync_done) + return ret; + } + } + + inline void push(Func func) { + { + std::lock_guard g(mtx); + tasks.push_back(func); + } + cv.notify_one(); + } +}; + +EXTERN_LIB EventQueue event_queue; + +#endif + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/executor.cc b/python/jittor/src/executor.cc new file mode 100644 index 00000000..4b082ef5 --- /dev/null +++ b/python/jittor/src/executor.cc @@ -0,0 +1,741 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "mem/allocator/cuda_dual_allocator.h" +#include "event_queue.h" +#endif +#include "misc/cuda_flags.h" +#include "executor.h" +#include "var.h" +#include "op.h" +#include "mem/allocator.h" +#include "graph.h" +#include "fused_op.h" +#include "fuser.h" +#include "profiler/profiler_guard.h" +#include "parallel_compiler.h" +#include "memory_profiler.h" +#include "misc/nan_checker.h" +#include "memory_profiler.h" +#include "utils/seh.h" +#include "utils/cache_compile.h" +#include "var_holder.h" +#include "mem/swap.h" +#include "mem/mem_info.h" + +namespace jittor { + +Executor exe; +EXTERN_LIB MemoryProfiler memory_profiler; +DECLARE_FLAG(int, profile_memory_enable); +DEFINE_FLAG(int, gopt_disable, 0, "Disable graph optimizer."); +DEFINE_FLAG(int, use_threading, 0, "Allow to use python threading with jittor."); + +DEFINE_FLAG(int, exec_called, 0, "exec sync called"); + +// from fetch_op.cc +EXTERN_LIB list fetcher_to_free; +// from cuda_managed_allocator +#ifdef HAS_CUDA +DECLARE_FLAG(int, use_cuda_managed_allocator); +#endif + +void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, int ll, int rr, int64 tt) { + fused_op.ops.clear(); + fused_op.edges.clear(); + auto ntt = ++tflag_count; + for (int i=ll; icustom_data = fid1; + op->tflag = ntt; + fused_op.ops.push_back(op); + } + LOGvvv << "Prepare fused_op" << fused_op.ops; + fused_op.update_ops(); + for (Op* op : fused_op.ops) { + uint fid1 = op->custom_data; + int iid = 0; + for (auto ve : op->_inputs) { + // this is a control dependency edge, dont used + if (ve.back->index<0) continue; + auto v = ve.node->var(); + iid++; + int iop_id; + int iv_id; + if (v->_inputs.size() && v->input()->tflag == ntt) { + auto e = v->_inputs.front(); + iop_id = e.node->custom_data; + iv_id = e.back->index; + } else { + iv_id = v->custom_data >> 2; + // add iv_id, prevent iv_id jit key overflow + iop_id = fused_op.ops.size() + iv_id; + } + fused_op.edges.emplace_back(iop_id, iv_id, fid1, iid-1); + } + // TODO: can we remove this? + // uint oid = 0; + // for (Var* v : op->outputs()) { + // oid++; + // if (v->tflag != tt) { + // // this var node not belong to current execution + // // this will happend in multiple outputs fuseable op + // // v->custom_data = 0 represents this var cannot be fused + // v->custom_data = 0; + // continue; + // } + // // for (auto o : v->outputs_with_index()) { + // // Op* op2 = o.op; + // // uint iid = o.index; + // // if (op2->tflag != ntt) continue; + // // uint fid2 = op2->custom_data; + // // fused_op.edges.emplace_back(fid1, oid-1, fid2, iid); + // // } + // } + } +} + +static inline void propergate_needed_flags(FusedOp& fused_op) { + auto& ops = fused_op.ops; + for (int i=ops.size()-1; i>=0; i--) { + bool has_need = 0; + auto op = ops[i]; + for (auto o : op->outputs()) + if (o->flags.get(NodeFlags::_needed_by_backward) && + !(o->custom_data&1)) { + has_need = 1; + } + if (has_need) + for (auto i : op->inputs()) { + i->flags.set(NodeFlags::_needed_by_backward); + } + } +} + +void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jittor::Log& logf) { + vector stack; + if (is_fused_op) { + FusedOp& fused_op = *((FusedOp*)op); + logf >> "[OP TYPE]:" << "fused_op:("; + for (auto& op : fused_op.ops) + logf << op->name_ex() >> ","; + logf >> ")\n"; + logf >> "[Input]:"; + for (auto& vi : fused_op.vars) + if (vi.type == 0) logf << vi.var->dtype() >> vi.var->shape >> vi.var->name >> ","; + logf << "\n[Output]:"; + Var* ov = nullptr; + for (auto& vi : fused_op.vars) + if (vi.type == 2) { + logf << vi.var->dtype() >> vi.var->shape >> vi.var->name >> ","; + ov = vi.var; + } + if (ov) + stack = get_node_trace(ov); + } else { + logf >> "[OP TYPE]:" << op->name_ex(); + logf << "\n[Input]:"; + for (auto v : op->inputs()) + logf << v->dtype() >> v->shape >> v->name >> ","; + logf << "\n[Output]:"; + Var* ov = nullptr; + for (auto v : op->outputs()) { + logf << v->dtype() >> v->shape >> v->name >> ","; + ov = v; + } + if (ov) + stack = get_node_trace(ov); + } + logf << "\n[Async Backtrace]:"; + if (stack.size()) { + logf << "---"; + for (auto& s : stack) { + logf << "\n " << s.file_path >> ":" >> s.lineno; + if (s.module_type.size()) logf << '<' >> s.module_type >> '>'; + if (s.module_name.size() && s.module_name.find(":") == string::npos) + logf << '[' >> s.module_name >> ']'; + } + } else + logf << "not found, please set env JT_SYNC=1, trace_py_var=3"; + logf << "\n[Reason]:" << e.what(); + jittor::LogFatalVoidify() && logf; +} + +static void top_weak_sync(vector& vars) { + auto t = ++tflag_count; + int64 max_id=0; + for (auto v : vars) { + if (v->is_finished()) continue; + max_id = std::max(v->id, max_id); + v->tflag = t; + } + while (true) { + if (sync_ptr == hold_vars.begin()) + break; + auto next_ptr = std::prev(sync_ptr); + auto v = (*next_ptr)->var; + if (v->id > max_id) break; + sync_ptr = next_ptr; + if (v->tflag == t) continue; + if (v->_outputs.size()) continue; + if (v->is_finished()) continue; + vars.push_back(v); + } +} + +void Executor::run_sync(vector vars, bool device_sync, bool weak_sync) { + exec_called ++; + if (weak_sync && !use_threading) + top_weak_sync(vars); + auto allocator = get_allocator(); + auto temp_allocator = get_allocator(true); + this->allocator = allocator; + this->temp_allocator = temp_allocator; + // bfs find all ops need to run + int op_num = 0; + vector bfs_q; + bfs_q.reserve(vars.size()); + int start_var_num = 0; + while (1) { + op_num = 0; + start_var_num = 0; + bfs_q.clear(); + // get all nodes need to be executed + int need_opt = 0; + auto t = ++tflag_count; + int64 max_id = 0; + for (Var* v : vars) + if (!v->is_finished() && v->tflag != t) { + v->tflag = t; + start_var_num++; + bfs_q.push_back(v); + max_id = std::max(max_id, v->id); + } + for (int i=0; iis_var(); + for (auto i : node->_inputs) + if (i.node->tflag != t && !i.node->is_finished()) { + i.node->tflag = t; + need_opt += i.node->flags.get(NodeFlags::_has_gopt); + bfs_q.push_back(i.node); + } + // this var has been fetched + if (weak_sync || node->flags.get(NodeFlags::_fetch)) { + for (auto& n : node->_outputs) { + // if not in queue and is fetch op + if (n.node->tflag != t && + n.node->pending_liveness && + !n.node->is_finished() && + (n.node->id <= max_id || + n.node->flags.get(NodeFlags::_fetch))) { + n.node->tflag = t; + need_opt += n.node->flags.get(NodeFlags::_has_gopt); + bfs_q.push_back(n.node); + } + } + } + } + if (!need_opt || gopt_disable) break; + for (Node* n : bfs_q) { + if (n->flags.get(NodeFlags::_has_gopt)) { + n->op()->graph_optimize(); + n->flags.set(NodeFlags::_has_gopt, 0); + } + } + } + auto tt = tflag_count; + vector ops; + vector all_vars; + ops.reserve(op_num); + all_vars.reserve(bfs_q.size() - op_num); + for (Node* node : bfs_q) + if (!node->is_var()) { + node->custom_data = ops.size(); + ops.push_back(node->op()); + } else { + // set can't fuse flag to false + node->custom_data = all_vars.size(); + all_vars.push_back(node->var()); + } + int var_num = all_vars.size(); + + // father: father of union-find set + vector father(op_num); + for (int i=0; i int { + int j=i; + while (father[j] != j) j = father[j]; + while (i != j) { + int tmp = father[i]; + father[i] = j; + i = tmp; + } + return j; + }; + vector var_fused(var_num); + + if (V_ON(100)) { + for (uint i=0; itype()==OpType::reduce) st="reduce"; + if (op->type()==OpType::broadcast) st="broadcast"; + if (op->type()==OpType::element) st="element"; + + LOGvvv << "id:" << ops[i]->custom_data << " type:" << + st << " addr:" << op; + for (Var* v : op->inputs()) { + Op* next_op = v->input(); + // continue if is boundary + if (!next_op || next_op->tflag != tt) { + LOGvvv << "input:" << v; + continue; + } + LOGvvv << "input:" << next_op->custom_data << " addr:" << next_op; + } + LOGvvv << ""; + } + } + + count_fuse(tt, start_var_num, ops, all_vars, father, var_fused); + // var_fused represents: + // 0: can fused + // 1: cannot fused + // 2: weak shared(may turn into 1 or 3 by shared operator cutting) + // 3: strong shared(force shared) + vector roots, next(op_num, -1); + vector deps(op_num, 0); + roots.reserve(op_num); + for (int i=0; i queue; + queue.reserve(roots.size()); + + // ** toplogical_sort external ** + // output: + // queue: toplogical order of fused op + { + // queue.clear(); + #ifndef JT_bfs_executor + std::priority_queue> p_queue; + #endif + for (int root : roots) { + for (int i=root; i>=0; i=next[i]) { + Op* op = ops[i]; + for (Var* v : op->inputs()) { + if (v->tflag != tt) continue; + Op* opi = v->input(); + // if those two ops are not fused + if (father[opi->custom_data] != root) { + deps[root]++; + } + } + } + #ifdef JT_bfs_executor + if (deps[root] == 0) + queue.push_back(root); + #else + if (deps[root] == 0) + p_queue.emplace(-ops[root]->order(), root); + #endif + } + #ifdef JT_bfs_executor + for (uint s=0; s=0; i=next[i]) { + Op* op = ops[i]; + for (Var* v : op->outputs()) + { + if (v->tflag == tt) + for (Op* op2 : v->outputs()) + { + if (op2->tflag != tt) continue; + int op2_id = father[op2->custom_data]; + // continue if those two ops are fused + if (op2_id == op_id) continue; + deps[op2_id]--; + #ifdef JT_bfs_executor + if (deps[op2_id] == 0) + queue.push_back(op2_id); + #else + if (deps[op2_id] == 0) + p_queue.emplace(-op2->order(), op2_id); + #endif + } + } + } + } + ASSERTop(queue.size(),==,roots.size()); + } + + // ** toplogical_sort internal ** + // output: + // fuse_ops: fused op id [000|1111|22|3333] + // range: split index ^ ^ ^ ^ ^ + vector fuse_ops; + fuse_ops.reserve(op_num*2); + vector range(queue.size()); + { + vector subgraph; + subgraph.reserve(16); + vector sharegraph; + sharegraph.reserve(16); + vector sharegraph_q; + sharegraph_q.reserve(16); + vector shared_id(op_num, -1); + + // for fused op in reversed order + for (uint rid=0; rid=0; i=next[i], total++) { + Op* op = ops[i]; + for (Var* v : op->inputs()) { + if (v->tflag != tt) continue; + Op* opi = v->input(); + // if those two ops are fused + int opid = opi->custom_data; + auto fopid = father[opid]; + if (fopid == root) + deps[i]++; + else if (shared_id[opid] != root) { + auto& vf = var_fused[v->custom_data]; + // var_fused = 1 cannot share input op + // TODO: check this input op's output var all can be shared + if (vf == 1) + continue; + // if weak share, turn into strong share + if (vf == 2) vf = 3; + // new shared op + deps[opid] = 0; + shared_id[opid] = root; + sharegraph.push_back(opid); + } + } + if (deps[i] == 0) + queue.push_back(i); + } + // find all share graph + uint sn = sharegraph.size(); + for (uint i=0; iinputs()) { + if (v->tflag != tt) continue; + int vi = v->custom_data; + if (var_fused[vi] == 1) + continue; + // if weak share, cut off + if (var_fused[vi] == 2) { + if (sharegraph.size() - sn < 32) + var_fused[vi] = 3; + else { + var_fused[vi] = 1; + continue; + } + } + Op* opi = v->input(); + int opid = opi->custom_data; + int& dep = deps[opid]; + if (shared_id[opid] != root) { + shared_id[opid] = root; + dep = 1; + sharegraph.push_back(opid); + } else + dep ++; + } + } + sharegraph_q.clear(); + for (uint i=0; iinputs()) { + if (v->tflag != tt) continue; + int vi = v->custom_data; + if (var_fused[vi] == 1) + continue; + Op* opi = v->input(); + int opid = opi->custom_data; + int& dep = deps[opid]; + dep --; + if (dep == 0) + sharegraph_q.push_back(opid); + } + } + LOGvvvv << "sharegraph_q" << sharegraph_q; + ASSERTop(sharegraph.size(),==,sharegraph_q.size()); + // topsort fused op internal + for (uint s=0; soutputs()) + if (v->tflag == tt) + for (Op* op2 : v->outputs()) { + if (op2->tflag != tt) continue; + int op2_id = op2->custom_data; + // continue if those two ops are not fused + if (father[op2_id] != root) continue; + deps[op2_id]--; + if (deps[op2_id] == 0) + queue.push_back(op2_id); + } + } + ASSERTop(queue.size(),==,(uint)total); + LOGvvvv << "topsort internal" << queue; + for (int i=(int)sharegraph_q.size()-1; i>=0; i--) + fuse_ops.push_back(sharegraph_q[i]); + for (uint i=0; icustom_data = var_fused[i]==1; + } + FusedOp fused_op; + + // compile all ops, prevent compiling during running + parallel_compile_all_ops(queue, range, fused_op, fuse_ops, ops, tt); + + // running + SetupFreeBuffer setup_free_buffer; + vector outputs_bk; + #ifdef HAS_CUDA + int sync_times = 0; + #endif + auto& jkl = get_jk(); + for (uint rid=0; ridtype() != OpType::other) { + op = &fused_op; + is_fused_op = true; + int ll = (ridinputs()) { + var->tflag = swap_timestamp; + } + for (auto* var : op->inputs()) { + check_and_swap_out(var, allocator); + } + for (auto* var : op->outputs()) { + alloc_with_swap(var, allocator, true); + var->tflag = swap_timestamp; + } + } else { + for (auto* var : op->outputs()) { + var->alloc(allocator); + } + } + if (PREDICT_BRANCH_NOT_TAKEN(profile_memory_enable)) + memory_profiler.check(); + LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs(); + op->do_prepare(jkl); + bool is_cuda = op->flags.get(NodeFlags::_cuda); + #ifdef HAS_CUDA + if (!is_cuda) { + if (last_is_cuda) { + // if prev op in gpu and this op in cpu + // cuda sync + checkCudaErrors(cudaDeviceSynchronize()); + sync_times++; + } + for (Var* v : op->inputs()) { + if (v->allocator->is_cuda()) + migrate_to_cpu(v, allocator); + } + if (!use_cuda_managed_allocator) { + for (auto* var : op->outputs()) + if (var->allocator->is_cuda()) + migrate_to_cpu(var, allocator); + } + } else { + for (Var* v : op->inputs()) { + if (!v->allocator->is_cuda()) + migrate_to_gpu(v, allocator); + } + for (Var* v : op->outputs()) { + if (!v->allocator->is_cuda()) + migrate_to_gpu(v, allocator); + } + } + #endif + #ifdef NODE_MEMCHECK + if (is_fused_op) { + for (auto& vi : fused_op.vars) + if (vi.type == 0) + ASSERT(vi.var->mem_ptr) << vi.var; + } else { + for (auto* v : op->inputs()) + ASSERT(v->mem_ptr) << v; + } + #endif + last_is_cuda = is_cuda; + // _JT_SEH_START2; + op->do_run_after_prepare(jkl); + // _JT_SEH_END2; + #ifdef HAS_CUDA + // migrate to gpu + if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) { + for (Var* v : op->outputs()) { + migrate_to_gpu(v, allocator); + } + } + #endif + // record trace data + if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var>=2)) { + trace_data.record_execution(op, is_fused_op, jkl); + #ifdef HAS_CUDA + if (use_cuda) + checkCudaErrors(cudaDeviceSynchronize()); + #endif + } + #ifdef JT_CHECK_NAN + for (Var* var : op->outputs()) + check_nan(var, op); + #endif + #ifdef JT_SYNC + #ifdef HAS_CUDA + checkCudaErrors(cudaGetLastError()); + checkCudaErrors(cudaDeviceSynchronize()); + #endif + #endif + LOGvvv << "Finished Op(" >> op->name() << rid >> + "/" >> queue.size() >> ") output:" << op->outputs(); + if (is_fused_op) { + propergate_needed_flags(fused_op); + for (Var* var : op->outputs()) + var->finish_pending_liveness(); + continue; + } + // release liveness when op is finished + // outputs may change during free, we need to backup it; + outputs_bk.clear(); + for (Var* var : op->outputs()) { + /* only free not need_free output var. + For example o1, o2 = op1(i1) + o2 is not used, so its f:b:p liveness == 0 + when o1 is freed, op2 will be freed, o2 will be freed too. + so no need to free o2 again. + */ + if (!var->need_free()) + outputs_bk.push_back(var); + else { + // TODO: will this cause bug? + var->flags.set(NodeFlags::_finished); + } + } + op->finish_pending_liveness(); + for (Var* var : outputs_bk) + var->finish_pending_liveness(); + } catch (const std::exception& e) { + // log memory info + display_memory_info(__FILELINE__, false, true); + // log jit_key and file location + op->do_prepare(jkl); + string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc"); + jittor::Log logf(__FILELINE__, 'f', 0); + logf << "\nExecute fused operator(" >> rid >> '/' >> queue.size() >> ")" + << "failed."; + if (jit_compiler::file_exist(jit_src_path)) + logf << "\n[JIT Source]:" << jit_src_path << "\n"; + check_op_async_error(op, is_fused_op, e, logf); + } + } + LOGvv << "All" << op_num << "ops finished, return vars:" << vars; + for (Var* v : vars) ASSERT(v->mem_ptr || v->flags.get(NodeFlags::_is_swapped) || !v->backward_liveness) << v; + // clean fetcher free buffer + fetcher_to_free.clear(); + #ifdef HAS_CUDA + if (device_sync && use_cuda) { + last_is_cuda = false; + sync_times++; + try { + // CHECK(EventQueue::OK == event_queue.run_sync([]() { + checkCudaErrors(cudaDeviceSynchronize()); + // })); + // TODO: run_sync cause hang, tmp fix it + } catch (const std::exception& e) { + // log memory info + display_memory_info(__FILELINE__, false, true); + throw; + } + event_queue.flush(); + } + LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" < allocation_map; +unordered_map size_map; + +extern "C" void* jittor_cuda_malloc(void*, size_t size, int device_id) { + size_t allocation; + void* ptr=exe.allocator->alloc(size, allocation); + allocation_map[ptr]=allocation; + size_map[ptr]=size; + return ptr; +} + +extern "C" void jittor_cuda_free(void*, void* ptr, int device_id) { + exe.allocator->free(ptr, size_map[ptr], allocation_map[ptr]); +} + +extern "C" void* get_jittor_cuda_malloc() { + return (void*)jittor_cuda_malloc; +} + +extern "C" void* get_jittor_cuda_free() { + return (void*)jittor_cuda_free; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/executor.h b/python/jittor/src/executor.h new file mode 100644 index 00000000..9195ebe0 --- /dev/null +++ b/python/jittor/src/executor.h @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "mem/allocator.h" +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#endif + +namespace jittor { + +struct Executor { + Allocator* allocator; + Allocator* temp_allocator; + bool last_is_cuda = false; + void run_sync(vector vars, bool device_sync, bool weak_sync=true); + + inline Allocation alloc_temp(size_t size) { + return Allocation(temp_allocator, size); + } +}; + +EXTERN_LIB Executor exe; + +void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, int ll, int rr, int64 tt); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/fused_op.cc b/python/jittor/src/fused_op.cc new file mode 100644 index 00000000..5220d5ad --- /dev/null +++ b/python/jittor/src/fused_op.cc @@ -0,0 +1,274 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "fused_op.h" +#include "var.h" +#include "op_compiler.h" +#include "profiler/profiler.h" +#include "misc/fast_shared_ptr.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +#ifndef JIT + +string_view_map jit_fused_ops; + +std::ostream& operator<<(std::ostream& os, const VarInfo& vi) { + return os << vi.var << " type:" << vi.type; +} + +int FusedOp::get_loop_option(const string& key, const int& _default) { + auto iter = loop_options->find(key); + return iter == loop_options->end() ? _default : iter->second; +} + +loop_options_t& FusedOp::get_loop_options_tuned() { + loop_options_tuned = *loop_options_origin; + loop_options = &loop_options_tuned; + return loop_options_tuned; +} + +void FusedOp::update_jit_key() { + JK& jk = get_jk(); + jk.clear(); + do_jit_prepare(jk); +} + +void FusedOp::update_ops() { + loop_options_merged.clear(); + loop_options_tuned.clear(); + loop_options = loop_options_origin = nullptr; + + _inputs.clear(); + _outputs.clear(); + vars.clear(); + for (Op* op : ops) { + for (Var* o : op->outputs()) { + if (o->loop_options) { + if (loop_options_origin == nullptr) + loop_options_origin = &o->loop_options.data(); + else if (loop_options_origin != &o->loop_options.data()) { + // merge loop options + for (auto& kv : o->loop_options.data()) + loop_options_merged[kv.first] = kv.second; + } + } + // bit0 represents can fuse or not + if (o->custom_data&1) + // this var can not fuse + _outputs.emplace_back((Node*)o, 0); + } + } + + if (loop_options_origin) { + if (loop_options_merged.size()) { + // merge loop_options_origin into loop_options_merged + for (auto& kv : *loop_options_origin) + loop_options_merged.emplace(kv); + } + } else { + loop_options_origin = &loop_options_merged; + } + loop_options = loop_options_origin; + + ASSERT(outputs().size()); + LOGvvvv << "set fused output" << outputs(); + + // var.custom_data + // meaning of custom_data&1(input): 1: cannot fuse, 0 can fuse + // meaning of custom_data&2: visited or not + // meaning of custom_data>>2: index of vars + + // op.custom_data: opid + for (uint i=0; icustom_data = i; + for (Var* i : opi->inputs()) { + i->custom_data &= 1; + } + for (Var* o : opi->outputs()) { + o->custom_data &= 1; + } + } + for (Op* opi : ops) { + for (Var* i : opi->inputs()) { + auto &c = i->custom_data; + // if not visited + if (!(c&2)) { + c += 2 + vars.size()*4; + vars.push_back({i, 0}); + _inputs.emplace_back((Node*)i); + } + } + for (Var* o : opi->outputs()) { + auto &c = o->custom_data; + // if not visited + if (!(c&2)) { + c += 2 + vars.size()*4; + // intermediate(can fuse) or output + vars.push_back({o, int((c&1)+1)}); + } + } + } + LOGvvvv << "Var info" << vars; +} + + +FusedOp::FusedOp() { + Op::number_of_lived_ops--; +} + +FusedOp::FusedOp(const FusedOp& other) { + Op::number_of_lived_ops--; + ops = other.ops; + edges = other.edges; + vars = other.vars; + loop_options_merged = other.loop_options_merged; + loop_options_tuned = other.loop_options_tuned; + loop_options = other.loop_options; + loop_options_origin = other.loop_options_origin; + context = other.context; +} + +FusedOp::~FusedOp() { + _inputs.clear(); + _outputs.clear(); + Op::number_of_lived_ops++; +} + +void FusedOp::infer_shape() { + for (Op* op : ops) { + op->init(); + } +} + +void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) { + in = out = compute = 0; + for (auto& vi : vars) { + compute = std::max(compute, (uint64_t)vi.var->num); + if (vi.type == 0) in += vi.var->size; + if (vi.type == 2) out += vi.var->size; + } +} + +void FusedOp::do_jit_prepare(JK& jk) { + jk.clear(); + for (uint i=0; iname(); + op->jit_prepare(jk); + } + jk << "«JIT:1"; + if (!use_cuda) { + // only cpu + jk << "«JIT_cpu:1"; + this->flags.set(NodeFlags::_cuda, 0); + this->flags.set(NodeFlags::_cpu, 1); + } else { + jk << "«JIT_cuda:1"; + this->flags.set(NodeFlags::_cpu, 0); + this->flags.set(NodeFlags::_cuda, 1); + } + jk << "«graph:"; + for (auto& t : edges) { + uint i,j,k,l; + std::tie(i,j,k,l) = t; + jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ','; + } + jk << "«var_info:" << JK::val; + bool use_int64_t = false; + for (auto& vi : vars) { + jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size()); + if (vi.type != 1 && vi.var->num >= std::numeric_limits::max()) + use_int64_t = true; + } + if (use_int64_t) + jk << "«index_t:int64"; + else + jk << "«index_t:int32"; + if (loop_options->size()) { + if (get_loop_option("compile_shapes")) { + jk << "«shapes:"; + for (auto& vi : vars) { + jk << '['; + for (auto a : vi.var->shape) + jk << a << ','; + jk << "],"; + } + } + jk << "«choices:"; + for (auto& kv : *loop_options) { + if (kv.first.size() && kv.first[0] != '_') + jk << kv.first << ':' << kv.second << ','; + } + } + jk.finilize(); +} + +void FusedOp::do_prepare(JK& jk) { + do_jit_prepare(jk); +} + +void FusedOp::do_run_after_prepare(JK& jk) { + const char* jit_key = jk.to_cstring(); + auto iter = jit_fused_ops.find(string_view(jit_key, jk.size)); + if (iter != jit_fused_ops.end()) { + LOGvvv << "Jit fused op key found:" << jit_key << "jit op entry:" << (void*)iter->second; + context = iter->second; + iter->second->vrm.fop = this; + Profiler::record_and_run(iter->second->entry, this, jit_key); + return; + } + LOGvv << "Jit op key not found:" << jit_key; + // compile JIT op + context = new FusedOpContext(); + context->setup(this); + string prev_jit_key = jit_key; + context->entry = OpCompiler::do_compile(this); + string new_jit_key = get_jit_key(jk); + jit_fused_ops[new_jit_key] = jit_fused_ops[prev_jit_key] = context; + jit_key_mapper[prev_jit_key] = new_jit_key; + LOGvv << "Get jit op entry:" << (void*)(context->entry); + Profiler::record_and_run(context->entry, this, new_jit_key.c_str()); +} + +void FusedOpContext::setup(FusedOp* fop) { + node_id.clear(); + vrm.fop = fop; + for (int i=0; iops.size(); i++) + node_id[fop->ops[i]] = i; + for (int i=0; ivars.size(); i++) + node_id[fop->vars[i].var] = i; +} + +int FusedOp::get_node_id(Node* node) { + ASSERT(context); + return context->node_id.at(node); +} + +int FusedOp::has(Node* node) { + ASSERT(context); + return context->node_id.count(node); +} + +void FusedOp::do_run() { + JK& jk = get_jk(); + do_prepare(jk); + do_run_after_prepare(jk); +} + +#else // JIT +void FusedOp::jit_run() { + for (uint i=0; iinputs() << ops[i]->outputs(); + ops[i]->do_run(); + } +} +#endif // JIT + +} diff --git a/python/jittor/src/fused_op.h b/python/jittor/src/fused_op.h new file mode 100644 index 00000000..3611fc04 --- /dev/null +++ b/python/jittor/src/fused_op.h @@ -0,0 +1,62 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" +#include "opt/var_relay.h" + +namespace jittor { + +struct VarInfo { + Var* var; + // 0: input, 1: intermediate, 2: output + int type; +}; +std::ostream& operator<<(std::ostream& os, const VarInfo& vi); + +struct FusedOpContext { + VarRelayManager vrm; + jit_op_entry_t entry; + unordered_map node_id; + void setup(FusedOp* fop); +}; + +EXTERN_LIB string_view_map jit_fused_ops; + +struct FusedOp final : Op { + vector ops; + // edges: [[i,j,k,l], ...] represents opi.output(j) == opk.input(l) + vector> edges; + vector vars; + loop_options_t loop_options_merged, loop_options_tuned; + loop_options_t* loop_options, * loop_options_origin; + loop_options_t& get_loop_options_tuned(); + FusedOpContext* context; + + int get_node_id(Node* node); + int has(Node* node); + void update_ops(); + FusedOp(); + FusedOp(const FusedOp& other); + ~FusedOp(); + + int get_loop_option(const string& key, const int& _default=0); + void add_loop_option_candidate(const string& key, int x); + void update_jit_key(); + + const char* name() const override { return "fused"; } + void statistics(uint64_t& in, uint64_t& out, uint64_t& compute) override; + void infer_shape() override; + void do_jit_prepare(JK& jk) override; + void do_prepare(JK& jk) override; + void do_run_after_prepare(JK& jk) override; + void do_run() override; +#ifdef JIT + void jit_run(); +#endif +}; + +} \ No newline at end of file diff --git a/python/jittor/src/fuser.h b/python/jittor/src/fuser.h new file mode 100644 index 00000000..d1aa3ca9 --- /dev/null +++ b/python/jittor/src/fuser.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +void count_fuse(int64_t tt, int start_var_num, const vector& ops, const vector& vars, vector &father, vector &var_fused); + +} // jittor diff --git a/python/jittor/src/grad.cc b/python/jittor/src/grad.cc new file mode 100644 index 00000000..fb6d6c46 --- /dev/null +++ b/python/jittor/src/grad.cc @@ -0,0 +1,287 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "pybind/py_var_tracer.h" +#include "grad.h" +#include "var.h" +#include "op.h" +#include "graph.h" +#include "ops/op_register.h" +#include "var_holder.h" + +namespace jittor { + +#define PREVENT_LARGE_FUSED_OP 16 + +DECLARE_FLAG(int, auto_mixed_precision_level); + +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); +static auto make_number = get_op_info("number") + .get_constructor(); + +#ifdef _WIN32 +template struct StackIniter { + T* a; + int n; + inline StackIniter(T* a, int n) :a(a), n(n) { + for (int i=0; i __init_##a(a, n); +#else +#define STACK_ALLOC2(T, a, n) T a[n] +#endif + +struct AmpGradGuard { + int amp_reg_bk; + AmpGradGuard(Op* op) { + amp_reg_bk = amp_reg; + amp_reg |= (op->flags.flags >> NodeFlags::_prefer_32); + } + + ~AmpGradGuard() { + amp_reg = amp_reg_bk; + } +}; + +VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) { + if (dout == nullptr) return nullptr; + if (x_index<0) return nullptr; + LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs() + << "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index; + AmpGradGuard agg(op); + auto dx = op->grad(out, dout, x, x_index); + if (x->loop_options) + dx->loop_options = x->loop_options; + return dx; +} + +inline static void assign_attrs(Var* a, Var* b) { + if (b->flags.get(NodeFlags::_stop_fuse)) + a->flags.set(NodeFlags::_stop_fuse); +} + +map grad_breaks; + +void warn_grad_break(int i, Var* v) { + if (grad_breaks.count(v->name.c_str())) return; + grad_breaks[v->name.c_str()] = 1; + LOGw << "grads[">>i>>"] '">> v->name>>"' doesn't have gradient. It will be set to zero:" << v; +} + +vector grad(Var* loss, vector targets, bool retain_graph) { + LOGvv << "loss:" >> loss << "targets:" >> targets; + CHECK(loss->is_float()) << "Loss should be float"; + for (Var* var : targets) + CHECK(var->is_float()) << "Targets of grad should be float"; + // successors of targets + vector ts(targets.begin(), targets.end()); + // bfs visit find all successors of targets + LOGvv << "Size of successors:" << ts.size(); + bfs_forward(ts, [](Node*){ return true; }); + vector gnodes; + gnodes.reserve(ts.size()); + auto nt = tflag_count; + if (loss->tflag == nt) + gnodes.push_back(loss); + bfs_backward(gnodes, [&](Node* node) { + if (node->tflag != nt) + return false; + if (node->is_stop_grad()) + return false; + return true; + }); + LOGvv << "Size of grad nodes:" << gnodes.size(); + + vector sorted; + toplogical_sort_backward(gnodes, sorted, [](Node*){}); + nt = tflag_count; + vector gvars; + gvars.reserve(sorted.size()); + for (Node* node : sorted) + if (node->is_var()) { + Var* v = node->var(); + v->custom_data = gvars.size(); + gvars.push_back(v); + } + LOGvv << "Size of grad vars:" << gvars.size(); + + vector grads(gvars.size()); + vector results(targets.size()); + vector target_id(targets.size()); + for (int i=0; itflag == nt) ? + var->custom_data : -1; + } + + if (grads.size()) { + grads[0] = make_number(1.f, loss); + assign_attrs(grads[0].ptr, loss); + } + + vector> id_buffer; + id_buffer.reserve(sorted.size()+10); + + // backup id in custum data + for (int i=1; ioutputs_with_index()) { + Op* op = it.op; + auto index = it.index; + if (op->tflag != nt) continue; + id_buffer.emplace_back(op, index); + + // backward together + if (op->flags.get(NodeFlags::_grads)) { + // dont backward next time + op->tflag = 0; + for (Var* out : op->outputs()) { + id_buffer.emplace_back( + out, + out->tflag == nt ? out->custom_data : -1); + } + for (Var* in : op->inputs()) { + id_buffer.emplace_back( + in, + in->tflag == nt ? in->custom_data : -1); + } + } else { + // single var backward + for (Var* out : op->outputs()) { + id_buffer.emplace_back( + out, + out->tflag == nt ? out->custom_data : -1); + } + } + } + // end of var output + id_buffer.emplace_back(nullptr, 0); + } + + // real backward construction from prev backuped ids + int j=0; + for (int i=1; ioutputs_with_index())" + while (id_buffer[j].first) { + Op* op = id_buffer[j].first->op(); + auto index = id_buffer[j].second; + j++; + auto n_o = op->outputs().size(); + + if (op->flags.get(NodeFlags::_grads)) { + // backward together + auto n_i = op->inputs().size(); + STACK_ALLOC(Var*, douts, n_o); + STACK_ALLOC2(VarPtr, dins, n_i); + // dump "for (Var* out : op->outputs())" + for (int i=0; i=0) { + douts[i] = grads[id]; + } else + douts[i] = nullptr; + } + trace_grad_op = op; + { + AmpGradGuard agg(op); + op->grads(douts, dins); + } + // dump "for (Var* in : op->inputs())" + for (int i=0; i=0) { + auto& din = dins[i]; + auto& grad = grads[id]; + if (din && grad) { + grad = make_binary(grad, din, ns_add); + } else + grad = move(din); + } + } + } else { + // single var backward + // dump "for (Var* out : op->outputs())" + for (int i=0; ivar(); + if (id<0) continue; + Var* dout = grads[id]; + trace_grad_op = op; + VarPtr dvar = make_grad(op, out, dout, var, index); + if (dvar && dvar->num>=0 && var->num>0) + // var->num == 0 represents a any match var + ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size()) + << "dvar" << dvar << "var" << var; + if (!grad) + grad = move(dvar); + else if (dvar) { + grad = make_binary(grad, dvar, ns_add); + #ifdef PREVENT_LARGE_FUSED_OP + gsum ++; + if (gsum>=PREVENT_LARGE_FUSED_OP) { + // TODO: this is a dirty fix for + // stopping fuse lots of op together, + // try to find a better solution + grad->flags.set(NodeFlags::_stop_fuse); + } + #endif + assign_attrs(grad.ptr, var); + } + } + } + } + if (auto_mixed_precision_level == 3 && grad->ns != var->ns) { + grad = make_unary(grad, var->ns); + } + } + trace_grad_op = nullptr; + // set zero grad + for (size_t i=0; i=0) + grad = move(grads[id]); + if (!grad) { + // TODO: better warning message + warn_grad_break(i, var); + grad = make_number(0.f, var); + assign_attrs(grad.ptr, var); + } + } + if (!retain_graph) { + auto t = ++tflag_count; + for (auto& vh : hold_vars) + if (vh->var->tflag != t) { + vh->var->tflag = t; + } + SetupFreeBuffer setup_free_buffer; + for (int i=int(gvars.size())-1; i>=0; i--) + if (gvars[i]->tflag != t && gvars[i]->backward_liveness) + gvars[i]->set_stop_grad(); + for (int i=0; iset_stop_grad(); + } + return results; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/grad.h b/python/jittor/src/grad.h new file mode 100644 index 00000000..e3ba9e62 --- /dev/null +++ b/python/jittor/src/grad.h @@ -0,0 +1,21 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "ops/tape_op.h" +#include "common.h" + +namespace jittor { + +vector grad(Var* loss, vector targets, bool retain_graph=true); + +// @pyjt(tape_together) +void tape_together( + const vector& taped_inputs, + const vector& taped_outputs, + GradCallback&& grad_callback +); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/graph.cc b/python/jittor/src/graph.cc new file mode 100644 index 00000000..2de5051e --- /dev/null +++ b/python/jittor/src/graph.cc @@ -0,0 +1,180 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "graph.h" +#include "var_holder.h" +#include "var.h" + +namespace jittor { + +DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check."); + + +template +string ss_convert(T x) { + std::stringstream ss; + ss << x; + return ss.str(); +} + +void do_graph_check() { + vector queue; + unordered_map visited; + for (auto& vh : hold_vars) { + if (0==visited[vh->var]++) + queue.push_back(vh->var); + } + LOGvv << "Check hold_vars size" << queue.size(); + int vhsize = queue.size(); + for (auto* node : queue) { + // ASSERTop(node->forward_liveness,>,0); + ASSERTop(node->backward_liveness,>,0); + } + for (uint i=0; iinputs()) { + if (visited.count(i)) continue; + visited[i] = 0; + queue.push_back(i); + } + } + LOGvv << "Check all var size" << queue.size(); + for (int i=0; i<(int)queue.size(); i++) { + auto* node = queue[i]; + LOGvvvv << "Check node" << i << node; + int f=0, b=0, p=0; + if (iinputs()) { + if (i->is_stop_grad()) continue; + if (!i->forward_liveness) continue; + f ++; + } + for (auto* o : node->outputs()) { + if (o->backward_liveness) + b ++; + if (o->pending_liveness && !o->is_finished()) + p++; + } + // if (f>0 && b>0 && !node->is_finished()) p++; + if (f!=node->forward_liveness || b!=node->backward_liveness || p!=node->pending_liveness) { + LOGf << "ERROR" << node << '\n' + << f << b << p << i << '\n' + << node->inputs() << '\n' + << node->outputs(); + continue; + } + } + for (auto& kv : lived_nodes) { + if (!kv.second) continue; + auto* node = (Node*) kv.first; + if (!visited.count(node) && node->tflag != -1) { + if (node->is_var() && node->_inputs.size()) + continue; + LOGf << "ERROR dnode" << (void*)node << kv.second << node; + } + } +} + +DumpGraphs dump_all_graphs() { + DumpGraphs graphs; + vector queue; + auto t = ++tflag_count; + for (auto& vh : hold_vars) + if (vh->var->tflag != t) { + vh->var->tflag = t; + queue.push_back(vh->var); + graphs.hold_vars.emplace_back(ss_convert(vh->var)); + } + bfs_both(queue, [](Node*){return true;}); + std::sort(queue.begin(), queue.end(), + [](Node* a, Node* b) { return a->id < b->id;}); + for (uint i=0; icustom_data = i; + for (Node* node : queue) { + graphs.nodes_info.emplace_back(ss_convert(node)); + + graphs.inputs.emplace_back(); + auto& inputs = graphs.inputs.back(); + inputs.reserve(node->_inputs.size()); + for (auto i : node->_inputs) + inputs.push_back(i.node->custom_data); + + graphs.outputs.emplace_back(); + auto& outputs = graphs.outputs.back(); + outputs.reserve(node->_outputs.size()); + for (auto o : node->_outputs) + outputs.push_back(o.node->custom_data); + } + return graphs; +} + +void clean_graph() { + vector queue; + auto t = ++tflag_count; + for (auto& vh : hold_vars) + if (vh->var->tflag != t) { + vh->var->tflag = t; + queue.push_back(vh->var); + } + bfs_both(queue, [](Node*){return true;}); + t = ++tflag_count; + for (auto& vh : hold_vars) + vh->var->tflag = t; + SetupFreeBuffer setup_free_buffer; + for (auto node : queue) { + if (node->tflag != t) { + node->set_stop_grad(); + } + } +} + +void check_circle(Node* s) { + vector q = {s}; + vector fa = {-1}; + unordered_set visited = {s}; + for (int i=0; ioutputs()) { + if (o == s) { + LOGe << "Found circle:"; + int j=i; + vector nodes{o}; + while (j) { + nodes.push_back(q[j]); + j = fa[j]; + } + for (int i=0; iinputs()) { + if (ii == in) break; + in_id ++; + } + for (auto oo : n->outputs()) { + if (oo == out) break; + out_id ++; + } + LOGe << n << "in:" >> in_id >> '/' >> n->inputs().size() << "out:" >> out_id >> '/' >> n->outputs().size(); + } + LOGf << "found circle"; + } + if (!visited.count(o)) { + visited.emplace(o); + q.push_back(o); + fa.push_back(i); + } + } + } + +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/graph.h b/python/jittor/src/graph.h new file mode 100644 index 00000000..10d2627d --- /dev/null +++ b/python/jittor/src/graph.h @@ -0,0 +1,155 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "node.h" + +namespace jittor { + +DECLARE_FLAG(int, check_graph); + +// this struct is used for debug and visualization +// @pyjt(DumpGraphs) +struct DumpGraphs { + // @pyjt(hold_vars) + vector hold_vars; + // @pyjt(nodes_info) + vector nodes_info; + // @pyjt(inputs) + vector> inputs; + // @pyjt(outputs) + vector> outputs; +}; + +// @pyjt(graph_check) +void do_graph_check(); +inline void graph_check() { if (check_graph) do_graph_check(); }; +// @pyjt(dump_all_graphs) +DumpGraphs dump_all_graphs(); +/** + * Clean graph, try to reduce memory usage. + * This operation will stop grad for all previous nodes. + * Backpropegation for previous nodes will be unavailable. + * This operation offen used between train and eval. + */ +// @pyjt(clean_graph) +void clean_graph(); + +template +void bfs_backward(vector& queue, Func&& func) { + auto t = ++tflag_count; + size_t i=0; + for (Node* node : queue) node->tflag = t; + while (i < queue.size()) { + Node* node = queue[i++]; + for (auto i : node->_inputs) + if (i.node->tflag != t && func(i.node)) { + i.node->tflag = t; + queue.push_back(i.node); + } + } +} + +template +void bfs_backward(vector& seed, vector& queue, Func&& func) { + for (Node* node : seed) + if (func(node)) queue.push_back(node); + bfs_backward(queue, func); +} + +template +void bfs_forward(vector& queue, Func&& func) { + auto t = ++tflag_count; + size_t i=0; + for (Node* node : queue) node->tflag = t; + while (i < queue.size()) { + Node* node = queue[i++]; + for (auto o : node->_outputs) + if (o.node->tflag != t && func(o.node)) { + o.node->tflag = t; + queue.push_back(o.node); + } + } +} + +template +void bfs_both(vector& queue, Func&& func) { + auto t = ++tflag_count; + size_t i=0; + for (Node* node : queue) node->tflag = t; + while (i < queue.size()) { + Node* node = queue[i++]; + for (auto o : node->_outputs) + if (o.node->tflag != t && func(o.node)) { + o.node->tflag = t; + queue.push_back(o.node); + } + for (auto i : node->_inputs) + if (i.node->tflag != t && func(i.node)) { + i.node->tflag = t; + queue.push_back(i.node); + } + } +} + +template +void toplogical_sort_forward(vector& nodes, vector& sorted, Func&& func) { + auto t = ++tflag_count; + sorted.reserve(nodes.size()); + for (auto node : nodes) node->tflag = t; + for (auto node : nodes) { + auto& deps = node->custom_data; + deps = 0; + for (auto i : node->_inputs) + if (i.node->tflag == t) + deps++; + if (deps == 0) sorted.push_back(node); + } + size_t i=0; + while (i < sorted.size()) { + Node* node = sorted[i++]; + for (auto o : node->_outputs) + if (o.node->tflag == t) { + o.node->custom_data--; + if (o.node->custom_data == 0) + sorted.push_back(o.node); + } + func(node); + } + ASSERTop(nodes.size(),==,sorted.size()); +} + + +template +void toplogical_sort_backward(vector& nodes, vector& sorted, Func&& func) { + auto t = ++tflag_count; + sorted.reserve(nodes.size()); + for (auto node : nodes) node->tflag = t; + for (auto node : nodes) { + auto& deps = node->custom_data; + deps = 0; + for (auto o : node->_outputs) + if (o.node->tflag == t) + deps++; + if (deps == 0) sorted.push_back(node); + } + size_t i=0; + while (i < sorted.size()) { + Node* node = sorted[i++]; + for (auto i : node->_inputs) + if (i.node->tflag == t) { + i.node->custom_data--; + if (i.node->custom_data == 0) + sorted.push_back(i.node); + } + func(node); + } + ASSERTop(nodes.size(),==,sorted.size()); +} + +void check_circle(Node* s); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/init.cc b/python/jittor/src/init.cc new file mode 100644 index 00000000..51c00bdc --- /dev/null +++ b/python/jittor/src/init.cc @@ -0,0 +1,110 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "misc/cuda_flags.h" +#endif +#include + +#include "init.h" +#include "ops/op_register.h" +#include "var.h" +#include "op.h" +#include "executor.h" + +namespace jittor { + +DEFINE_FLAG(vector, cuda_archs, {}, "Cuda arch"); +DEFINE_FLAG(int, use_tensorcore, 0, "use tensor core"); + +unique_ptr eng; + +vector callbacks; +int current_seed; +int64 current_offset; + +// fron fetch_op.cc +EXTERN_LIB list fetcher; +EXTERN_LIB list fetcher_to_free; +EXTERN_LIB vector cleanup_callback; +EXTERN_LIB bool exited; + +void cleanup() { + exited = true; + fetcher_to_free.clear(); + fetcher.clear(); + for (auto cb : cleanup_callback) + cb(); + cleanup_callback.clear(); +} + +static void init_cuda_devices() { +#ifdef IS_CUDA + if (cuda_archs.size()) return; + int count=0; + cudaGetDeviceCount(&count); + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "common.h" + +namespace jittor { + +typedef void (*set_seed_callback)(int); + +void init(); + +/** +Sets the seed of jittor random number generator. Also see @jittor.set_global_seed. + +---------------- + +* [in] seed: a python number. + + */ +// @pyjt(set_seed, seed) +void set_seed(int seed); + +/** +Returns the seed of jittor random number generator. + */ +// @pyjt(get_seed) +int get_seed(); + +void add_set_seed_callback(set_seed_callback callback); + +extern +std::default_random_engine* get_random_engine(); + +// things need to be clean before python exit +// @pyjt(cleanup) +void cleanup(); + +// @pyjt(jt_init_subprocess) +void jt_init_subprocess(); + +} // jittor diff --git a/python/jittor/src/jit_compiler.cc b/python/jittor/src/jit_compiler.cc new file mode 100644 index 00000000..cf4c1128 --- /dev/null +++ b/python/jittor/src/jit_compiler.cc @@ -0,0 +1,277 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#ifdef _WIN32 +#include +#else +#include +#endif +#include + +#include "jit_compiler.h" +#include "op.h" +#include "utils/cache_compile.h" +#include "utils/flags.h" +#include "fused_op.h" +#include "utils/str_utils.h" +JPU(header) + +namespace jittor { + +DEFINE_FLAG(string, jittor_path, "", "Source path of jittor"); +DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler"); +DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)"); +DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler"); +DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler"); +DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler"); +DEFINE_FLAG(string, python_path, "", "Path of python interpreter"); +DEFINE_FLAG(string, cache_path, "", "Cache path of jittor"); +DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not"); + +vector shsplit(const string& s) { + auto s1 = split(s, " "); + vector s2; + int count = 0; + for (auto& s : s1) { + int nc = 0; + for (auto& c : s) + nc += c=='"' || c=='\''; + if (count&1) { + count += nc; + s2.back() += " "; + s2.back() += s; + } else { + count = nc; + s2.push_back(s); + } + } + return s2; +} + +string fix_cl_flags(const string& cmd, bool is_cuda) { +#ifdef _MSC_VER + auto flags = shsplit(cmd); + vector output, output2; + + for (auto& f : flags) { + if (startswith(f, "-link")) + continue; + else if (startswith(f, "-l")) + output2.push_back(f.substr(2)+".lib"); + else if (startswith(f, "-LIB")) + output2.push_back(f); + else if (startswith(f, "-LD")) + output.push_back(f); + else if (startswith(f, "-L")) + output2.push_back("-LIBPATH:"+f.substr(2)); + else if (f.find(".lib") != string::npos) + output2.push_back(f); + else if (startswith(f, "-DEF:")) + output2.push_back(f); + else if (startswith(f, "-W") || startswith(f,"-f")) + continue; + else if (startswith(f,"-std=")) + output.push_back("-std:"+f.substr(5)); + else if (startswith(f,"-include")) + output.push_back("-FI"); + else if (startswith(f,"-shared")) + output.push_back("-LD"); + else + output.push_back(f); + } + string cmdx = ""; + for (auto& s : output) { + cmdx += s; + cmdx += " "; + } + cmdx += "-link "; + for (auto& s : output2) { + cmdx += s; + cmdx += " "; + } + return cmdx; +#else + auto flags = shsplit(cmd); + vector output; + #ifdef __APPLE__ + vector libpaths; + #endif + + for (auto& f : flags) { + if (startswith(f, "-l") && + (f.find("cpython") != string::npos || + f.find("lib") != string::npos)) { + #ifdef __APPLE__ + auto fname = f.substr(2) + ".so"; + int i; + for (i=libpaths.size()-1; i>=0; i--) { + auto full = libpaths[i] + '/' + fname; + string full2; + for (auto c : full) + if (c != '\"') full2 += c; + if (jit_compiler::file_exist(full2)) { + output.push_back(full2); + break; + } + } + if (i<0) output.push_back(f); + #else + output.push_back("-l:"+f.substr(2)+".so"); + #endif + } + else if (startswith(f, "-L")) { + if (is_cuda) + output.push_back(f+" -Xlinker -rpath="+f.substr(2)); + else + output.push_back(f+" -Wl,-rpath,"+f.substr(2)); + #ifdef __APPLE__ + libpaths.push_back(f.substr(2)); + #endif + } else + output.push_back(f); + } + string cmdx = ""; + for (auto& s : output) { + cmdx += s; + cmdx += " "; + } + return cmdx; +#endif +} + +namespace jit_compiler { + +std::mutex dl_open_mutex; + +jit_op_entry_t load_jit_lib( + string name, string symbol_name="jit_entry", const string& extra_flags="") { + std::lock_guard lock(dl_open_mutex); + const char* msg = ""; + LOGvv << "Opening jit lib:" << name; + #ifdef _WIN32 + void* handle = (void*)LoadLibraryExA(_to_winstr(name).c_str(), nullptr, + LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | + LOAD_LIBRARY_SEARCH_USER_DIRS); + #elif defined(__linux__) + auto flags = RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL; + if (extra_flags.find("GLOBAL_VAR") != string::npos) + flags = RTLD_LAZY | RTLD_DEEPBIND | RTLD_GLOBAL; + void* handle = dlopen(name.c_str(), flags); + msg = dlerror(); + #else + auto flags = RTLD_LAZY | RTLD_LOCAL; + if (extra_flags.find("GLOBAL_VAR") != string::npos) + flags = RTLD_LAZY | RTLD_GLOBAL; + void *handle = dlopen(name.c_str(), flags); + msg = dlerror(); + #endif + + CHECK(handle) << "Cannot open library" << name << ":" << msg; + + #ifdef _WIN32 + auto jit_entry = (jit_op_entry_t)GetProcAddress((HINSTANCE)handle, symbol_name.c_str()); + #else + //dlerror(); + auto jit_entry = (jit_op_entry_t)dlsym(handle, symbol_name.c_str()); + msg = dlerror(); + #endif + CHECK(jit_entry) << "Loading symbol" << symbol_name << "from" << name << "failed:" << msg; + + return jit_entry; +} + +void run_cmd(string cmd, string cwd="") { + if (cwd.size()) cmd = "cd '"+cwd + "' && " + cmd; + LOGvvv << "Run cmd:" << cmd; + system_with_check(cmd.c_str()); +} + +static string get_symbol_name(const string& jit_key) { + int i=0; + while (i=0 && jit_key[i]<=127) i++; + string op_name = i ? jit_key.substr(0, i) : "fused"; + op_name = Op::file_name_to_class_name(op_name); + // _ZN7jittorXyyyyyy7jit_runEv + // jittor::yyyyyy::jit_run + #ifdef _MSC_VER + op_name = "?jit_run@"+op_name+"Op@jittor@@QEAAXXZ"; + #else + op_name = "_ZN6jittor"+S(op_name.size()+2)+op_name+"Op7jit_runEv"; + #endif + return op_name; +} + +jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_cuda_op, const string& extra_flags) { + LOGvv << "Compile op" << jit_key; + // compiler do not allowed filename too long + CHECK(cc_path.size()); + string jit_src_path; + if (is_cuda_op && extra_flags.find("-dc") != string::npos) + jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cu"); + else + jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc"); + string* src2 = (string*)&src; + string* extra_flags2 = (string*)&extra_flags; + JPU(op_compiler(jit_src_path, *src2, is_cuda_op, *extra_flags2)); + #ifdef _WIN32 + string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".dll"); + string jit_src_path2 = _to_winstr(jit_src_path); + #else + string jit_lib_path = Op::get_filename_from_jit_key(jit_key, ".so"); + string& jit_src_path2 = jit_src_path; + #endif + string other_src; + LOGvvv << "Generate" << jit_src_path >> "\n" >> src; + if (rewrite_op || !file_exist(jit_src_path2)) + write(jit_src_path2, src); + string cmd; + + auto symbol_name = get_symbol_name(jit_key); +#ifndef _MSC_VER + if (is_cuda_op) { + cmd = "\"" + nvcc_path + "\"" + + " \"" + jit_src_path + "\"" + other_src + + fix_cl_flags(nvcc_flags + extra_flags, is_cuda_op) + + " -o \"" + jit_lib_path + "\""; + if (cmd.find("-dc") != string::npos) { + cmd = python_path+" "+jittor_path+"/utils/dlink_compiler.py " + cmd; + } + } else { + cmd = "\"" + cc_path + "\"" + + " \"" + jit_src_path + "\"" + other_src + + fix_cl_flags(cc_flags + extra_flags, is_cuda_op) + + " -o \"" + jit_lib_path + "\""; +#ifdef __linux__ + cmd = python_path+" "+jittor_path+"/utils/asm_tuner.py " + "--cc_path=" + cmd; +#endif + } +#else // Windows _MSC_VER + if (is_cuda_op) { + cmd = "\"" + nvcc_path + "\"" + + " \"" + jit_src_path + "\"" + other_src + + nvcc_flags + extra_flags + + " -o \"" + jit_lib_path + "\"" + + " -Xlinker -EXPORT:\"" + + symbol_name + "\""; + } else { + cmd = "\"" + cc_path + "\"" + + " \"" + jit_src_path + "\"" + other_src + + " -Fe: \"" + jit_lib_path + "\" " + + fix_cl_flags(cc_flags + extra_flags, is_cuda_op) + " -EXPORT:\"" + + symbol_name + "\""; + } +#endif + cache_compile(cmd, cache_path, jittor_path); + auto jit_entry = load_jit_lib(jit_lib_path, symbol_name, extra_flags); + return jit_entry; +} + +} // jit_compiler +} // jittor \ No newline at end of file diff --git a/python/jittor/src/jit_compiler.h b/python/jittor/src/jit_compiler.h new file mode 100644 index 00000000..00ead363 --- /dev/null +++ b/python/jittor/src/jit_compiler.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "op_compiler.h" + +namespace jittor { +namespace jit_compiler { + +jit_op_entry_t compile( + const string& jit_key, + const string& src, + const bool is_cuda_op = false, + const string& extra_flags=""); + +} // jit_compiler +} // jittor \ No newline at end of file diff --git a/python/jittor/src/jit_key.cc b/python/jittor/src/jit_key.cc new file mode 100644 index 00000000..ab5f9ebd --- /dev/null +++ b/python/jittor/src/jit_key.cc @@ -0,0 +1,112 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifndef _WIN32 +#include +#include +#endif +#include +#include "jit_key.h" +#include "utils/str_utils.h" + +namespace jittor { + +#ifndef _WIN32 +EXTERN_LIB thread_local size_t protected_page; + +static size_t get_buffer_end_page(size_t buffer_end) { + // get the last complete page in buffer + // 4k align : + // | | | | | + // buffer: xxxxxxxxxxxxxxxxxxxxxxxx + // ^ buffer_end_page + size_t buffer_end_page = buffer_end - buffer_end % getpagesize(); + if (buffer_end_page + getpagesize()-1 > buffer_end) + buffer_end_page -= getpagesize(); + return buffer_end_page; +} +#endif + +JitKey::JitKey() { +#ifndef _WIN32 + auto buffer_end_page = get_buffer_end_page((size_t)&buffer[buffer_size-1]); + LOGvv << "protect page" << (void*)buffer_end_page; + ASSERT(0==mprotect((void*)buffer_end_page, getpagesize(), PROT_NONE)); + protected_page = buffer_end_page; +#endif +} + +JitKey::~JitKey() { +#ifndef _WIN32 + auto buffer_end_page = get_buffer_end_page((size_t)&buffer[buffer_size-1]); + LOGvv << "un-protect page" << (void*)buffer_end_page; + mprotect((void*)buffer_end_page, getpagesize(), PROT_READ|PROT_WRITE|PROT_EXEC); + protected_page = 0; +#endif +} + +static void hex_to_dec(string& s) { + // check s is hex or not, if yes, convert to dec + if (!s.size()) return; + unsigned int x; + std::stringstream ss; + ss << std::hex << s; + ss >> x; + s = S(x); +} + +static void convert_itof(string& s) { + uint64 x; + std::stringstream ss; + // itof(0x...) + // ^ ^ + // 7 + ASSERT(s.size()>=8); + ss << std::hex << s.substr(7, s.size()-7-1); + ASSERT(ss >> x); + ss.str(""); ss.clear(); + ss << std::hexfloat << itof(x); + s = ss.str(); + // 0x0p+0 ---> 0x0p0 + if (s.find("p+") != string::npos) + s.erase(s.find("p+")+1, 1); + if (s=="inf") s = "(1.0/0)"; + if (s=="-inf") s = "(-1.0/0)"; + if (s=="nan" || s=="-nan") s = "(0.0/0)"; +} + +vector> parse_jit_keys(const string& s) { + vector> jit_keys; + auto sp = split(s, JitKey::key); + for (auto& ss : sp) { + if (!ss.size()) continue; + string key, val; + char state=0; + for (auto c : ss) { + if (state == 0 && + (c==JK::val || c==JK::hex_val)) { + state = c; + continue; + } + if (state == 0) key += c; + else val += c; + } + if (state == JK::hex_val) + hex_to_dec(val); + if (startswith(val, "itof")) + convert_itof(val); + jit_keys.emplace_back(move(key), move(val)); + } + return jit_keys; +} + +thread_local JitKey jk; + +JK& get_jk() { + return jk; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/jit_key.h b/python/jittor/src/jit_key.h new file mode 100644 index 00000000..7dd27292 --- /dev/null +++ b/python/jittor/src/jit_key.h @@ -0,0 +1,308 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "common.h" +#include "misc/nano_string.h" +#include "misc/nano_vector.h" + +namespace jittor { + +struct JitKey { + static constexpr size_t buffer_size = 2*1024*1024; + static constexpr const char + *key = "«", + val = ':', + hex_val = '='; + int64 size=0; + uint64 flags=0; + char buffer[buffer_size]; + + JitKey(); + ~JitKey(); + + inline void clear() {size = flags = 0;} + inline void finilize() { buffer[size] = 0; } + inline bool empty() { return !size; } + inline const char* to_cstring() { + return &buffer[0]; + } + inline string to_string() { + return string(&buffer[0], size); + } + + struct hex { + uint64 data; + explicit hex(uint64 data) : data(data) {} + }; + + struct hex1 { + uint data; + explicit hex1(uint data) : data(data) {} + }; + + struct shex1 { + int data; + explicit shex1(int data) : data(data) {} + }; + + struct hex2 { + uint data; + explicit hex2(uint data) : data(data) {} + }; + + struct Oxhex { + uint64 data; + explicit Oxhex(uint64 data) : data(data) {} + }; + + struct Oxhex1 { + uint data; + explicit Oxhex1(uint data) : data(data) {} + }; + + struct Oxhex2 { + uint data; + explicit Oxhex2(uint data) : data(data) {} + }; + + struct dec1 { + uint data; + explicit dec1(uint data) : data(data) {} + }; + + struct dec2 { + uint data; + explicit dec2(uint data) : data(data) {} + }; +}; + +struct __jk_int128 { + int64 a,b; +}; +struct __jk_int256 { + int64 a,b,c,d; +}; + +typedef JitKey JK; +EXTERN_LIB JK& get_jk(); + +inline void jk_put_str_with_len(JK& jk, const char* a, int n) { + char* xx = &jk.buffer[jk.size]; + int i=0; + while (i+32<=n) { + ((__jk_int256*)(xx+i))[0] = ((const __jk_int256*)(a+i))[0]; + i+=32; + } + while (i+16<=n) { + ((__jk_int128*)(xx+i))[0] = ((const __jk_int128*)(a+i))[0]; + i+=16; + } + while (i+8<=n) { + ((long long*)(xx+i))[0] = ((const long long*)(a+i))[0]; + i+=8; + } + while (i+4<=n) { + ((int*)(xx+i))[0] = ((const int*)(a+i))[0]; + i+=4; + } + while (i+2<=n) { + ((int16_t*)(xx+i))[0] = ((const int16_t*)(a+i))[0]; + i+=2; + } + while (i+1<=n) { + ((char*)(xx+i))[0] = ((const char*)(a+i))[0]; + i+=1; + } + jk.size += n; +} + +inline JK& operator<<(JK& jk, const char* s) { + jk_put_str_with_len(jk, s, strlen(s)); + return jk; +} + +inline JK& operator<<(JK& jk, const string& s) { + auto a = (__jk_int256*)(jk.buffer+jk.size); + auto b = (__jk_int256*)(&s[0]); + auto len = s.size(); + uint64 i=0; + for (; i+32<=len; i+=32) + a[i/32] = b[i/32]; + + for (; i>4) << JK::hex1(h.data); +} + +inline JK& operator<<(JK& jk, const JK::hex& h) { + auto a = h.data; + uint nbits = 64 - lzcnt(a); + nbits = a ? nbits-1 : 0; + int i=nbits/4; + for (; i>=0; i--) + jk << JK::hex1(a >> (i*4)); + return jk; +} + +inline JK& operator<<(JK& jk, const JK::Oxhex& h) { + return jk << "0x" << JK::hex(h.data); +} + +inline JK& operator<<(JK& jk, const JK::Oxhex1& h) { + return jk << "0x" << JK::hex1(h.data); +} + +inline JK& operator<<(JK& jk, const JK::Oxhex2& h) { + return jk << "0x" << JK::hex2(h.data); +} + +inline JK& operator<<(JK& jk, const JK::dec2& h) { + uint8 a = h.data % 10; + uint8 b = h.data / 10; + if (b) jk << (char)(b+'0'); + return jk << (char)(a+'0'); +} + +inline JK& operator<<(JK& jk, const JK::dec1& h) { + uint8 a = h.data % 10; + return jk << (char)(a+'0'); +} + +inline std::ostream& operator<<(std::ostream& os, const JK::dec2& h) { + uint8 a = h.data % 10; + uint8 b = h.data / 10; + if (b) os << (char)(b+'0'); + return os << (char)(a+'0'); +} + +inline std::ostream& operator<<(std::ostream& os, const JK::dec1& h) { + uint8 a = h.data % 10; + return os << (char)(a+'0'); +} + +inline JK& operator<<(JK& jk, int c) { + if (c<0) { + c = -c; + jk << '-'; + } + return jk << JK::hex(c); +} + +inline JK& operator<<(JK& jk, uint c) { + return jk << JK::hex(c); +} + +inline JK& operator<<(JK& jk, int64 c) { + if (c<0) { + c = -c; + jk << '-'; + } + return jk << JK::hex(c); +} + +#ifdef __linux__ +inline JK& operator<<(JK& jk, int64_t c) { + return jk << (int64)c; +} +#endif + +inline JK& operator<<(JK& jk, uint64 c) { + return jk << JK::hex(c); +} + +#ifdef __GNUC__ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif +static inline uint64 ftoi(float64 a) { return *(uint64*)&a; } +static inline float64 itof(uint64 a) { return *(float64*)&a; } +#ifdef __GNUC__ +#pragma GCC diagnostic pop +#endif + +inline JK& operator<<(JK& jk, const NanoString& ns) { + auto a = (__jk_int128*)(jk.buffer+jk.size); + auto b = (__jk_int128*)(ns.to_cstring()); + auto len = ns.len(); + a[0] = b[0]; + jk.size += len; + return jk; +} + +vector> parse_jit_keys(const string& s); + +template +void add_jit_define(JK& jk, const Ta& key, const Tb& val) { + jk << JK::key << key << JK::val << val; +} + +template +void add_jit_define(JK& jk, const Ta& key, const Tb& i, const Tc& val) { + jk << JK::key << key << i << JK::val << val; +} + + +template +void add_jit_define(JK& jk, const Ta& key, const JK::hex& val) { + jk << JK::key << key << JK::hex_val << val; +} + +template +void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex& val) { + jk << JK::key << key << i << JK::hex_val << val; +} + +template +void add_jit_define(JK& jk, const Ta& key, const JK::hex1& val) { + jk << JK::key << key << JK::hex_val << val; +} + +template +void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex1& val) { + jk << JK::key << key << i << JK::hex_val << val; +} + +template +void add_jit_define(JK& jk, const Ta& key, const JK::hex2& val) { + jk << JK::key << key << JK::hex_val << val; +} + +template +void add_jit_define(JK& jk, const Ta& key, const Tb& i, const JK::hex2& val) { + jk << JK::key << key << i << JK::hex_val << val; +} + +#define _CS(x) x + +inline JK& operator<<(JK& jk, float64 f) { + return jk << "itof(0x" << JK::hex(ftoi(f)) << ')'; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/lock.cc b/python/jittor/src/lock.cc new file mode 100644 index 00000000..df8361cb --- /dev/null +++ b/python/jittor/src/lock.cc @@ -0,0 +1,79 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Wenyang Zhou <576825820@qq.com> +// Dun Liang +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#ifdef _WIN32 +#include +#include +#include +#include +#define getpid _getpid +#define open _open +#else +#include +#endif +#include +#include + +#include "lock.h" + +namespace jittor { + +static int lock_fd = -1; +int _has_lock = 0; + +DEFINE_FLAG(bool, disable_lock, 0, "Disable file lock"); + +void set_lock_path(string path) { + lock_fd = open(_to_winstr(path).c_str(), O_RDWR); + ASSERT(lock_fd >= 0); + LOGv << "OPEN LOCK path:" << path << "Pid:" << getpid(); +} + +void lock() { + if (disable_lock) return; + ASSERT(lock_fd >= 0); +#ifdef _WIN32 + OVERLAPPED offset = {0, 0, 0, 0, NULL}; + auto hfile = (HANDLE)_get_osfhandle(lock_fd); + ASSERT(LockFileEx(hfile, 2, 0, -0x10000, 0, &offset)); +#else + struct flock lock = { + .l_type = F_WRLCK, + .l_whence = SEEK_SET, + .l_start = 0, + .l_len = 0 + }; + ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0); +#endif + _has_lock = 1; + LOGvv << "LOCK Pid:" << getpid(); +} + +void unlock() { + if (disable_lock) return; + ASSERT(lock_fd >= 0); +#ifdef _WIN32 + OVERLAPPED offset = {0, 0, 0, 0, NULL}; + auto hfile = (HANDLE)_get_osfhandle(lock_fd); + ASSERT(UnlockFileEx(hfile, 0, -0x10000, 0, &offset)); +#else + struct flock lock = { + .l_type = F_UNLCK, + .l_whence = SEEK_SET, + .l_start = 0, + .l_len = 0 + }; + ASSERT(fcntl(lock_fd, F_SETLKW, &lock) == 0); +#endif + _has_lock = 0; + LOGvv << "UNLOCK Pid:" << getpid(); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/lock.h b/python/jittor/src/lock.h new file mode 100644 index 00000000..9a4428ce --- /dev/null +++ b/python/jittor/src/lock.h @@ -0,0 +1,37 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Wenyang Zhou <576825820@qq.com> +// Dun Liang +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +// @pyjt(set_lock_path) +void set_lock_path(string path); + +void lock(); + +void unlock(); + +EXTERN_LIB int _has_lock; + +struct lock_guard { + int has_lock = 0; + inline lock_guard() { + if (_has_lock) return; + has_lock = 1; + lock(); + } + inline ~lock_guard() { + if (!has_lock) return; + unlock(); + } +}; + +} // jittor diff --git a/python/jittor/src/mem/allocator.cc b/python/jittor/src/mem/allocator.cc new file mode 100644 index 00000000..160c76cf --- /dev/null +++ b/python/jittor/src/mem/allocator.cc @@ -0,0 +1,175 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "misc/cuda_flags.h" + +#include "mem/allocator/aligned_allocator.h" +#ifdef HAS_CUDA +#include "mem/allocator/cuda_managed_allocator.h" +#include "mem/allocator/cuda_device_allocator.h" +#include "mem/allocator/cuda_host_allocator.h" +#include "mem/allocator/cuda_dual_allocator.h" +#endif +#include "mem/allocator/stat_allocator.h" +#include "mem/allocator/sfrl_allocator.h" +#include "mem/allocator/nfef_allocator.h" +#include "mem/allocator/temp_allocator.h" +#include "mem/swap.h" +#include "var.h" + +namespace jittor { + + +struct pair_hash { + template + std::size_t operator() (const std::pair &pair) const { + return std::hash()(pair.first) ^ std::hash()(pair.second); + } +}; + + +std::unordered_map< + pair, + unique_ptr, + pair_hash> allocators; + +template +Allocator* setup_allocator(Allocator* underlying) { + pair key{typeid(T).name(), underlying}; + auto iter = allocators.find(key); + if (iter != allocators.end()) return iter->second.get(); + auto a = std::make_unique(); + auto* p = a.get(); + a->setup(underlying); + allocators[key] = move(a); + return p; +} + +Allocator* cpu_allocator = setup_allocator(&aligned_allocator); + +DEFINE_FLAG_WITH_SETTER(int, use_cuda_host_allocator, 1, "use cuda host allocator for cpu memory globally"); + +void setter_use_cuda_host_allocator(int value) { + #ifdef HAS_CUDA + auto use_cuda_bk = use_cuda; + use_cuda = 0; + use_cuda_host_allocator = value; + cpu_allocator = get_allocator(); + use_cuda = use_cuda_bk; + #endif +} + +extern int64 sfrl_large_block_size_device; + +Allocator* get_allocator(bool temp_allocator) { + Allocator* allocator = nullptr; + if (use_cuda && sfrl_large_block_size_device >= (1ll<<40)) { + // if super large block is used, don't use + // temp allocator + temp_allocator = false; + } +#ifdef HAS_CUDA + if (use_cuda && !allocator) { + if (use_cuda_managed_allocator) { + LOGvv << "Using cuda_managed_allocator"; + allocator = &cuda_managed_allocator; + } else { + LOGvv << "Using cuda_device_allocator"; + allocator = &cuda_device_allocator; + } + } else + if (use_cuda_host_allocator) + allocator = &cuda_host_allocator; +#endif + if (!allocator) { + LOGvv << "Using aligned_allocator"; + allocator = &aligned_allocator; + } + if (use_stat_allocator==1) { + LOGvv << "Using stat_allocator"; + allocator = setup_allocator(allocator); + } + if (use_nfef_allocator) { + LOGvv << "Using use_nfef_allocator"; + allocator = setup_allocator(allocator); + return allocator; + } + if (temp_allocator && use_temp_allocator) { + LOGvv << "Using temp_allocator"; + allocator = setup_allocator(allocator); + } else if (use_sfrl_allocator) { + LOGvv << "Using sfrl_allocator"; + allocator = setup_allocator(allocator); + } + if (use_stat_allocator==2) { + LOGvv << "Using stat_allocator at last"; + allocator = setup_allocator(allocator); + } + return allocator; +} + +void gc_all() { + for (auto& kv : allocators) kv.second->gc(); +} + +void migrate_to_cpu(Var* var, Allocator* allocator) { + #ifdef HAS_CUDA + if (!use_cuda_managed_allocator) + allocator = cpu_allocator; + #endif + if (save_mem) { + if (swap_timestamp != var->tflag) { + swap_timestamp = ++tflag_count; + var->tflag = swap_timestamp; + } + move_with_swap(var, cpu_allocator, true); + return; + } + #ifdef HAS_CUDA + if (var->allocator == &delay_free) { + var->allocator = allocator; + delay_free.migrate_to_cpu( + var->mem_ptr, var->allocation, var->size, var->allocator + ); + } else + if (!use_cuda_managed_allocator) { + if (!var->allocator->is_cuda()) return; + // must be a device allocator + Allocation a(allocator, var->size); + checkCudaErrors(cudaMemcpy(a.ptr, var->mem_ptr, var->size, cudaMemcpyDeviceToHost)); + var->allocator->free(var->mem_ptr, var->size, var->allocation); + var->mem_ptr = a.ptr; + var->allocation = a.allocation; + var->allocator = a.allocator; + a.ptr = nullptr; + } + #endif +} + + +void migrate_to_gpu(Var* var, Allocator* allocator) { + #ifdef HAS_CUDA + // only happend when not using use_cuda_managed_allocator + if (save_mem) { + if (swap_timestamp != var->tflag) { + swap_timestamp = ++tflag_count; + var->tflag = swap_timestamp; + } + move_with_swap(var, allocator, true); + return; + } + Allocation a(allocator, var->size); + checkCudaErrors(cudaMemcpy(a.ptr, var->mem_ptr, var->size, cudaMemcpyHostToDevice)); + var->allocator->free(var->mem_ptr, var->size, var->allocation); + var->mem_ptr = a.ptr; + var->allocation = a.allocation; + var->allocator = a.allocator; + a.ptr = nullptr; + #endif +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator.h b/python/jittor/src/mem/allocator.h new file mode 100644 index 00000000..0b204b80 --- /dev/null +++ b/python/jittor/src/mem/allocator.h @@ -0,0 +1,59 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +struct Allocator { + enum Flag { + _cuda=1, + _aligned=2 + }; + int64 used_memory=0, unused_memory=0; + inline virtual uint64 flags() const { return 0; }; + inline bool is_cuda() const { return flags() & _cuda; } + inline bool is_aligned() const { return flags() & _aligned; } + virtual const char* name() const = 0; + virtual void* alloc(size_t size, size_t& allocation) = 0; + virtual void free(void* mem_ptr, size_t size, const size_t& allocation) = 0; + inline virtual void gc() {}; + inline virtual bool share_with(size_t size, size_t allocation) { return false; }; + inline virtual ~Allocator() {} +}; + +struct AlignedAllocator; +EXTERN_LIB AlignedAllocator aligned_allocator; + +struct Allocation { + void* ptr; + size_t allocation, size; + Allocator* allocator; + inline Allocation() = default; + inline Allocation(void* ptr, size_t allocation, size_t size, Allocator* allocator) + : ptr(ptr), allocation(allocation), size(size), allocator(allocator) {} + inline Allocation(Allocation&& o) + : ptr(o.ptr), allocation(o.allocation), size(o.size), allocator(o.allocator) + { o.ptr = nullptr; } + inline Allocation(unique_ptr&& p) + { ptr = p.release(); allocator = (Allocator*)&aligned_allocator; } + inline Allocation(Allocator* at, size_t size) + : size(size), allocator(at) + { allocator = at; ptr = at->alloc(size, allocation); } + inline ~Allocation() + { if (ptr) allocator->free(ptr, size, allocation); } +}; + +EXTERN_LIB Allocator* cpu_allocator; +Allocator* get_allocator(bool temp_allocator=false); +// @pyjt(gc) +void gc_all(); + +void migrate_to_cpu(Var* var, Allocator* allocator); +void migrate_to_gpu(Var* var, Allocator* allocator); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/aligned_allocator.cc b/python/jittor/src/mem/allocator/aligned_allocator.cc new file mode 100644 index 00000000..810c9b1d --- /dev/null +++ b/python/jittor/src/mem/allocator/aligned_allocator.cc @@ -0,0 +1,42 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator/aligned_allocator.h" +#include "var.h" + +namespace jittor { + +AlignedAllocator aligned_allocator; + +const char* AlignedAllocator::name() const {return "aligned";} + +void* AlignedAllocator::alloc(size_t size, size_t& allocation) { + #ifndef _WIN32 + #ifdef __APPLE__ + size += 32-size%32; + // low version of mac don't have aligned_alloc + return new char[size]; + #else + return aligned_alloc(alignment, size); + #endif + #else + return _aligned_malloc(size, alignment); + #endif +} + +void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + #ifdef _WIN32 + _aligned_free(mem_ptr); + #else + #ifdef __APPLE__ + delete[] (char*)mem_ptr; + #else + ::free(mem_ptr); + #endif + #endif +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/aligned_allocator.h b/python/jittor/src/mem/allocator/aligned_allocator.h new file mode 100644 index 00000000..c6c7561c --- /dev/null +++ b/python/jittor/src/mem/allocator/aligned_allocator.h @@ -0,0 +1,21 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "mem/allocator.h" + +namespace jittor { + +struct AlignedAllocator : Allocator { + uint64 flags() const override { return _aligned; } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; +}; + +EXTERN_LIB AlignedAllocator aligned_allocator; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/cuda_device_allocator.cc b/python/jittor/src/mem/allocator/cuda_device_allocator.cc new file mode 100644 index 00000000..195c2468 --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_device_allocator.cc @@ -0,0 +1,44 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#include "mem/mem_info.h" +#include "helper_cuda.h" +#include "mem/allocator/cuda_device_allocator.h" + +namespace jittor { + +CudaDeviceAllocator cuda_device_allocator; +EXTERN_LIB bool no_cuda_error_when_free; + +const char* CudaDeviceAllocator::name() const {return "cuda_device";} + +void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) { + if (size==0) return (void*)0x10; + void* ptr; + try { + checkCudaErrors(cudaMalloc(&ptr, size)); + return ptr; + } catch (...) { + // clean the last error + cudaGetLastError(); + } + display_memory_info(__FILELINE__); + LOGf << "Unable to alloc cuda device memory for size" << size; + checkCudaErrors(cudaMallocManaged(&ptr, size)); + return ptr; +} + +void CudaDeviceAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + if (size==0) return; + if (no_cuda_error_when_free) return; + checkCudaErrors(cudaFree(mem_ptr)); +} + +} // jittor + +#endif diff --git a/python/jittor/src/mem/allocator/cuda_device_allocator.h b/python/jittor/src/mem/allocator/cuda_device_allocator.h new file mode 100644 index 00000000..0dca72ca --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_device_allocator.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#ifdef HAS_CUDA +#include "mem/allocator.h" + +namespace jittor { + +struct CudaDeviceAllocator : Allocator { + uint64 flags() const override { return _cuda; } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; +}; + +EXTERN_LIB CudaDeviceAllocator cuda_device_allocator; + +} + +#endif diff --git a/python/jittor/src/mem/allocator/cuda_dual_allocator.cc b/python/jittor/src/mem/allocator/cuda_dual_allocator.cc new file mode 100644 index 00000000..b942621b --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_dual_allocator.cc @@ -0,0 +1,38 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include "misc/cuda_flags.h" +#include "mem/allocator/cuda_dual_allocator.h" +#include "mem/allocator/cuda_host_allocator.h" +#include "mem/allocator/cuda_device_allocator.h" +#include "event_queue.h" + +namespace jittor { + +SFRLAllocator cuda_dual_host_allocator(&cuda_host_allocator, 0.3, 1<<28); +SFRLAllocator cuda_dual_device_allocator(&cuda_device_allocator, 0.3, 1<<28); +CudaDualAllocator cuda_dual_allocator; +DelayFree delay_free; + +namespace cuda_dual_local { + +list allocations; + +static void free_caller() { + allocations.pop_front(); +} + +} + +void to_free_allocation(CUDA_HOST_FUNC_ARGS) { + using namespace cuda_dual_local; + event_queue.push(free_caller); +} + +} + +#endif diff --git a/python/jittor/src/mem/allocator/cuda_dual_allocator.h b/python/jittor/src/mem/allocator/cuda_dual_allocator.h new file mode 100644 index 00000000..0cc94a31 --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_dual_allocator.h @@ -0,0 +1,122 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#ifdef HAS_CUDA +#include +#include +#include +#include +#include "helper_cuda.h" +#include "misc/cuda_flags.h" +#include "var.h" +#include "mem/allocator.h" +#include "mem/allocator/sfrl_allocator.h" + +namespace jittor { + +struct DualAllocation { + size_t ref_cnt; + void* host_ptr, * device_ptr; + size_t host_allocation, device_allocation; +}; + +EXTERN_LIB SFRLAllocator cuda_dual_host_allocator; +EXTERN_LIB SFRLAllocator cuda_dual_device_allocator; +EXTERN_LIB bool no_cuda_error_when_free; + +struct CudaDualAllocator : Allocator { + //for recycle block_id + static const size_t ID_LIMIT = 1 << 16; + int n_free_ids; + int free_ids[ID_LIMIT]; + DualAllocation allocations[ID_LIMIT]; + + CudaDualAllocator() { + n_free_ids = ID_LIMIT; + for (int i=0; i allocations; + +} + +void to_free_allocation(CUDA_HOST_FUNC_ARGS); + +struct DelayFree final : Allocator { + inline uint64 flags() const override { return _cuda; }; + const char* name() const override { return "delay_free"; }; + void* alloc(size_t size, size_t& allocation) override { + LOGf << "Should not call this"; + return nullptr; + } + bool share_with(size_t size, size_t allocation) override { + return cuda_dual_allocator.share_with(size, allocation); + }; + void free(void* mem_ptr, size_t size, const size_t& allocation) override { + using namespace cuda_dual_local; + if (no_cuda_error_when_free) return; + allocations.emplace_back(mem_ptr, allocation, size, &cuda_dual_allocator); + peekCudaErrors(_cudaLaunchHostFunc(0, &to_free_allocation, 0)); + } + + void migrate_to_cpu(void*& mem_ptr, size_t& allocation, size_t size, Allocator* allocator) { + auto da = cuda_dual_allocator.get_dual_allocation(allocation); + auto pre_allocation = allocation; + auto offset = (int64)mem_ptr - (int64)da.device_ptr; + + mem_ptr = allocator->alloc(size, allocation); + + checkCudaErrors(cudaMemcpy(mem_ptr, + (void*)((int64)da.device_ptr+offset), size, cudaMemcpyDeviceToHost)); + // std::memcpy(mem_ptr, (void*)((int64)da.host_ptr+offset), size); + free(da.device_ptr, size, pre_allocation); + } +}; + +EXTERN_LIB DelayFree delay_free; + +} + +#endif diff --git a/python/jittor/src/mem/allocator/cuda_host_allocator.cc b/python/jittor/src/mem/allocator/cuda_host_allocator.cc new file mode 100644 index 00000000..0d1eaccf --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_host_allocator.cc @@ -0,0 +1,34 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "mem/allocator/cuda_host_allocator.h" + +namespace jittor { + +CudaHostAllocator cuda_host_allocator; +EXTERN_LIB bool no_cuda_error_when_free; + +const char* CudaHostAllocator::name() const {return "cuda_host";} + +void* CudaHostAllocator::alloc(size_t size, size_t& allocation) { + if (size==0) return (void*)0x10; + void* ptr; + checkCudaErrors(cudaMallocHost(&ptr, size)); + return ptr; +} + +void CudaHostAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + if (size==0) return; + if (no_cuda_error_when_free) return; + checkCudaErrors(cudaFreeHost(mem_ptr)); +} + +} // jittor + +#endif diff --git a/python/jittor/src/mem/allocator/cuda_host_allocator.h b/python/jittor/src/mem/allocator/cuda_host_allocator.h new file mode 100644 index 00000000..edf1f74b --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_host_allocator.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#ifdef HAS_CUDA +#include "mem/allocator.h" + +namespace jittor { + +struct CudaHostAllocator : Allocator { + inline uint64 flags() const override { return 0; } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; +}; + +EXTERN_LIB CudaHostAllocator cuda_host_allocator; + +} + +#endif diff --git a/python/jittor/src/mem/allocator/cuda_managed_allocator.cc b/python/jittor/src/mem/allocator/cuda_managed_allocator.cc new file mode 100644 index 00000000..0be42c68 --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_managed_allocator.cc @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "mem/allocator/cuda_managed_allocator.h" + +namespace jittor { + +CudaManagedAllocator cuda_managed_allocator; +DEFINE_FLAG(int, use_cuda_managed_allocator, 0, "Enable cuda_managed_allocator"); +EXTERN_LIB bool no_cuda_error_when_free; + +const char* CudaManagedAllocator::name() const {return "cuda_managed";} + +void* CudaManagedAllocator::alloc(size_t size, size_t& allocation) { + if (size==0) return (void*)0x10; + void* ptr; + checkCudaErrors(cudaMallocManaged(&ptr, size)); + return ptr; +} + +void CudaManagedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + if (size==0) return; + if (no_cuda_error_when_free) return; + checkCudaErrors(cudaFree(mem_ptr)); +} + +} // jittor + +#endif diff --git a/python/jittor/src/mem/allocator/cuda_managed_allocator.h b/python/jittor/src/mem/allocator/cuda_managed_allocator.h new file mode 100644 index 00000000..93e8d4d6 --- /dev/null +++ b/python/jittor/src/mem/allocator/cuda_managed_allocator.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#ifdef HAS_CUDA +#include "mem/allocator.h" + +namespace jittor { + +struct CudaManagedAllocator : Allocator { + uint64 flags() const override { return _cuda; } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; +}; + +EXTERN_LIB CudaManagedAllocator cuda_managed_allocator; +DECLARE_FLAG(int, use_cuda_managed_allocator); + +} + +#endif diff --git a/python/jittor/src/mem/allocator/foreign_allocator.cc b/python/jittor/src/mem/allocator/foreign_allocator.cc new file mode 100644 index 00000000..f35cd050 --- /dev/null +++ b/python/jittor/src/mem/allocator/foreign_allocator.cc @@ -0,0 +1,49 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator/foreign_allocator.h" +#include "var.h" + +namespace jittor { + +struct ForeignAllocation { + std::function del_func; + int64 cnt; + ForeignAllocation(std::function&& del_func) + : del_func(std::move(del_func)), cnt(1) {} +}; + +ForeignAllocator foreign_allocator; + +const char* ForeignAllocator::name() const {return "foreign";} + +void* ForeignAllocator::alloc(size_t size, size_t& allocation) { + return nullptr; +} + +void ForeignAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + auto a = (ForeignAllocation*)allocation; + a->cnt--; + if (!a->cnt) { + a->del_func(); + delete a; + } +} + +void make_foreign_allocation(Allocation& a, void* ptr, size_t size, std::function&& del_func) { + auto fa = new ForeignAllocation(std::move(del_func)); + a.allocator = &foreign_allocator; + a.allocation = (size_t)fa; + a.ptr = ptr; + a.size = size; +} + +bool ForeignAllocator::share_with(size_t size, size_t allocation) { + ((ForeignAllocation*)allocation)->cnt++; + return true; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/foreign_allocator.h b/python/jittor/src/mem/allocator/foreign_allocator.h new file mode 100644 index 00000000..3d56cdc3 --- /dev/null +++ b/python/jittor/src/mem/allocator/foreign_allocator.h @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "mem/allocator.h" + +namespace jittor { + +struct ForeignAllocator : Allocator { + uint64 flags() const override { return _aligned; } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; + bool share_with(size_t size, size_t allocation) override; +}; + +void make_foreign_allocation(Allocation& a, void* ptr, size_t size, std::function&& del_func); + +EXTERN_LIB ForeignAllocator foreign_allocator; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/nfef_allocator.cc b/python/jittor/src/mem/allocator/nfef_allocator.cc new file mode 100644 index 00000000..cd5b910f --- /dev/null +++ b/python/jittor/src/mem/allocator/nfef_allocator.cc @@ -0,0 +1,33 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator/nfef_allocator.h" +#include "var.h" + +namespace jittor { + +DEFINE_FLAG(int, use_nfef_allocator, 0, "Enable never free exact fit allocator"); + +void NFEFAllocator::setup(Allocator* underlying) { + this->underlying = underlying; +} + +const char* NFEFAllocator::name() const {return "nfef";} + +void* NFEFAllocator::alloc(size_t size, size_t& allocation) { + auto iter = freed.find(size); + if (iter == freed.end() || iter->second.empty()) + return underlying->alloc(size, allocation); + auto ptr = iter->second.front(); + iter->second.pop_front(); + return ptr; +} + +void NFEFAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + freed[size].push_front(mem_ptr); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/nfef_allocator.h b/python/jittor/src/mem/allocator/nfef_allocator.h new file mode 100644 index 00000000..66f519a9 --- /dev/null +++ b/python/jittor/src/mem/allocator/nfef_allocator.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include "mem/allocator.h" + +namespace jittor { + +// Never free exact fit allocator +struct NFEFAllocator : Allocator { + Allocator* underlying; + std::unordered_map> freed; + + void setup(Allocator* underlying); + uint64 flags() const override { return underlying->flags(); } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; +}; + +DECLARE_FLAG(int, use_nfef_allocator); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/sfrl_allocator.cc b/python/jittor/src/mem/allocator/sfrl_allocator.cc new file mode 100644 index 00000000..add3aef9 --- /dev/null +++ b/python/jittor/src/mem/allocator/sfrl_allocator.cc @@ -0,0 +1,320 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include +#include "mem/allocator/sfrl_allocator.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +DEFINE_FLAG(int, use_sfrl_allocator, 1, "Enable sfrl allocator"); +DEFINE_FLAG(int64, sfrl_large_block_size_device, 20971520, "sfrl_large_block_size, larger will reduce memory shard, only affect device"); +constexpr int64 sfrl_large_block_size_cpu=20971520; + +std::vector CachingBlockPool::block_ids; + //start from 1 +size_t CachingBlockPool::tot_block_id = 0; +std::unique_ptr CachingBlockPool::occupied_id_mapper( + new CachingBlock*[CachingBlockPool::ID_LIMIT]); + +//CachingBlock +CachingBlock::CachingBlock(size_t size, size_t origin_size) : + size(size), origin_size(origin_size), id(0), share_times(0), memory_ptr(nullptr), blocks(nullptr), prev(nullptr), next(nullptr), occupied(false) {} + +CachingBlock::CachingBlock(size_t size, size_t origin_size, CachingBlockPool* blocks, void* memory_ptr) : + size(size), origin_size(origin_size), id(0), share_times(0), memory_ptr(memory_ptr), blocks(blocks), prev(nullptr), next(nullptr), occupied(false) {} + +//CachingBlockPool +CachingBlockPool::CachingBlockPool() { + +} + +CachingBlockPool::~CachingBlockPool() { + for (auto it = blocks.begin(); it != blocks.end(); ++it) { + delete it->second; + } +} + +pair CachingBlockPool::get_key(CachingBlock* block) { + return std::make_pair((size_t)block->size, (size_t)(block->origin_size * ID_LIMIT + block->id)); +} + +void CachingBlockPool::insert(CachingBlock* block) { + size_t id; + if (!block_ids.empty()) { + id = block_ids.back(); + block_ids.pop_back(); + } else { + ASSERT(tot_block_id < ID_LIMIT - 1) << "block id limit extended."; + id = ++tot_block_id; + } + block->id = id; + blocks[get_key(block)] = block; +} + +void CachingBlockPool::erase(CachingBlock* block) { + block_ids.push_back(block->id); + blocks.erase(get_key(block)); +} + +size_t CachingBlockPool::insert_occupied(CachingBlock* block) { + size_t id; + if (!block_ids.empty()) { + id = block_ids.back(); + block_ids.pop_back(); + } else { + ASSERT(tot_block_id < ID_LIMIT - 1) << "block id limit extended."; + id = ++tot_block_id; + } + block->id = id; + occupied_id_mapper[id] = block; + return id; +} + +CachingBlock* CachingBlockPool::erase_occupied(size_t allocation) { + ASSERT(occupied_id_mapper[allocation] != nullptr) << "allocation not found"; + block_ids.push_back(allocation); + CachingBlock* block = occupied_id_mapper[allocation]; + occupied_id_mapper[allocation] = nullptr; + return block; +} + +CachingBlock* CachingBlockPool::get_occupied(size_t allocation) { + ASSERT(occupied_id_mapper[allocation] != nullptr) << "allocation not found"; + CachingBlock* block = occupied_id_mapper[allocation]; + return block; +} + +CachingBlock* CachingBlockPool::pop_block(size_t size) { + auto temp = CachingBlock(size, 0); + auto it = blocks.lower_bound(get_key(&temp)); + CachingBlock* block = nullptr; + if (it != blocks.end()) { + block = it->second; + block_ids.push_back(block->id); + blocks.erase(it); + } + return block; +} + +list SFRLAllocator::sfrl_allocators; +//SFRLAllocator +SFRLAllocator::~SFRLAllocator() { + sfrl_allocators.erase(iter); + for (auto it = occupied_blocks.begin(); it != occupied_blocks.end(); ++it) { + delete it->second; + } +} + +const char* SFRLAllocator::name() const {return "sfrl";} + +size_t SFRLAllocator::align_size(size_t size) { + return (size + ALIGN_SIZE - 1) / ALIGN_SIZE * ALIGN_SIZE; +} + +void SFRLAllocator::setup(Allocator* underlying) { + this->underlying = underlying; +} + +size_t SFRLAllocator::allocation_size(size_t size) { + // #ifdef HAS_CUDA + // if (is_cuda() && size >= SMALL_BLOCK_SIZE) { + // // just take all free mem + // size_t gpu_free = 0, _gpu_total = 0; + // cudaMemGetInfo(&gpu_free, &_gpu_total); + // // left 512MB + // size_t left = 1<<29; + // if (gpu_free >= left) { + // gpu_free = (gpu_free - left) / LARGE_ALIGN_SIZE * LARGE_ALIGN_SIZE; + // if (gpu_free >= size) + // return gpu_free; + // } + // } + // #endif + if (size <= SMALL_BLOCK_SIZE) + return SMALL_BLOCK_SIZE; + int64 large_block_size = is_cuda() ? sfrl_large_block_size_device : sfrl_large_block_size_cpu; + int64 align_size = (size + LARGE_ALIGN_SIZE - 1) / LARGE_ALIGN_SIZE * LARGE_ALIGN_SIZE; + if (size <= large_block_size) { + #ifdef HAS_CUDA + if (is_cuda()) { + // just take all free mem + int64 gpu_free = 0, _gpu_total = 0; + cudaMemGetInfo((size_t*)&gpu_free, (size_t*)&_gpu_total); + // left 512MB + int64 left = 1<<29; + gpu_free = (gpu_free - left) / LARGE_ALIGN_SIZE * LARGE_ALIGN_SIZE; + gpu_free = std::min(gpu_free, large_block_size); + if (gpu_free >= align_size) + return gpu_free; + else + return align_size; + } + #endif + return large_block_size; + } else + return align_size; +} + +bool SFRLAllocator::should_split(CachingBlock* block, size_t size) { + size_t rest = block->size - size; + if (block->blocks == &small_blocks) { + return rest >= ALIGN_SIZE; + } else { + return rest > SMALL_BLOCK_SIZE; + } +} + +size_t CachingBlockPool::free_all_cached_blocks(Allocator* underlying, long long free_size) { + auto it = blocks.begin(); + size_t freed_memory = 0; + while (it != blocks.end()) { + if (free_size != -1 && freed_memory >= free_size) + break; + CachingBlock* block = it->second; + if (!block->prev && !block->next) { + underlying->free((void*)block->memory_ptr, block->size, 0); + freed_memory += block->size; + auto cur = it; + ++it; + block_ids.push_back(cur->second->id); + blocks.erase(cur); + delete block; + } else { + ++it; + } + } + return freed_memory; +} + +void SFRLAllocator::try_merge_two_blocks(CachingBlock* dst, CachingBlock* src, CachingBlockPool& blocks) { + if (!src || src->occupied) { + return; + } + if (dst->prev == src) { + dst->memory_ptr = src->memory_ptr; + dst->prev = src->prev; + if (dst->prev) { + dst->prev->next = dst; + } + } else { + dst->next = src->next; + if (dst->next) { + dst->next->prev = dst; + } + } + dst->size += src->size; + blocks.erase(src); + delete src; +} + +CachingBlockPool* SFRLAllocator::get_blocks(size_t size) { + if (size <= SMALL_BLOCK_SIZE) + return &small_blocks; + else + return &large_blocks; +} + +void SFRLAllocator::free_all_sfrl_allocators() { + for (auto i : sfrl_allocators) { + if (float(i->unused_memory) > i->free_ratio * float(i->unused_memory + i->used_memory) && i->unused_memory > i->min_free_size) { + i->unused_memory -= i->large_blocks.free_all_cached_blocks(i->underlying, i->unused_memory - i->min_free_size); + i->unused_memory -= i->small_blocks.free_all_cached_blocks(i->underlying, i->unused_memory - i->min_free_size); + } + } +} + +inline void SFRLAllocator::try_free_this_allocators() { + if (float(unused_memory) > free_ratio * float(unused_memory + used_memory)) { + unused_memory -= large_blocks.free_all_cached_blocks(underlying); + unused_memory -= small_blocks.free_all_cached_blocks(underlying); + } +} + +std::mutex sfrl_allocator_mutex; + +void* SFRLAllocator::alloc(size_t size, size_t& allocation) { + std::unique_lock lock(sfrl_allocator_mutex); + #ifdef IS_ACL + // output of acl op need additional 32 bytes + size = align_size(size+32); + #else + size = align_size(size); + #endif + CachingBlockPool* blocks = get_blocks(size); + //search cached block + CachingBlock* block = blocks->pop_block(size); + //alloc from GPU + if (block == nullptr) { + free_all_sfrl_allocators(); + size_t alloc_size = allocation_size(size); + void* ptr = nullptr; + try { + ptr = underlying->alloc(alloc_size, allocation); + } catch (...) { + unused_memory -= large_blocks.free_all_cached_blocks(underlying); + unused_memory -= small_blocks.free_all_cached_blocks(underlying); + gc_all(); + ptr = underlying->alloc(alloc_size, allocation); + } + block = new CachingBlock(alloc_size, alloc_size, blocks, ptr); + } else { + unused_memory -= block->size; + } + if (should_split(block, size)) { + CachingBlock* rest = new CachingBlock(block->size - size, block->origin_size, block->blocks, static_cast(block->memory_ptr) + size); + block->size = size; + if (block->next) { + block->next->prev = rest; + } + rest->next = block->next; + rest->prev = block; + block->next = rest; + blocks->insert(rest); + unused_memory += rest->size; + } + block->occupied = true; + allocation = blocks->insert_occupied(block); + used_memory += block->size; + return block->memory_ptr; +} + +void SFRLAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + std::unique_lock lock(sfrl_allocator_mutex); + auto* block = CachingBlockPool::occupied_id_mapper[allocation]; + auto* blocks = block->blocks; + if (block->share_times == 0) { + blocks->erase_occupied(allocation); + used_memory -= block->size; + unused_memory += block->size; + block->occupied = false; + auto& block_list = *block->blocks; + try_merge_two_blocks(block, block->prev, block_list); + try_merge_two_blocks(block, block->next, block_list); + block_list.insert(block); + } else { + --block->share_times; + } +} + +void SFRLAllocator::gc() { + unused_memory -= small_blocks.free_all_cached_blocks(underlying); + unused_memory -= large_blocks.free_all_cached_blocks(underlying); +} + +bool SFRLAllocator::share_with(size_t size, size_t allocation) { + std::unique_lock lock(sfrl_allocator_mutex); + auto* block = CachingBlockPool::occupied_id_mapper[allocation]; + ++block->share_times; + return true; +} + +} // jittor + diff --git a/python/jittor/src/mem/allocator/sfrl_allocator.h b/python/jittor/src/mem/allocator/sfrl_allocator.h new file mode 100644 index 00000000..2bd9a170 --- /dev/null +++ b/python/jittor/src/mem/allocator/sfrl_allocator.h @@ -0,0 +1,99 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "mem/allocator.h" + +namespace jittor { +struct CachingBlockPool; + +struct CachingBlock { + size_t size; + // origin size before split + size_t origin_size; + size_t id; + size_t share_times; + void* memory_ptr; + CachingBlockPool* blocks; + CachingBlock* prev; + CachingBlock* next; + bool occupied; + + CachingBlock(size_t size, size_t origin_size); + CachingBlock(size_t size, size_t origin_size, CachingBlockPool* blocks, void* memory_ptr); +}; + +struct CachingBlockPool { + std::map, CachingBlock*> blocks; + //for recycle block_id + static std::vector block_ids; + //start from 1 + static size_t tot_block_id; + static std::unique_ptr occupied_id_mapper; + static const size_t ID_LIMIT = 1 << 18; + + pair get_key(CachingBlock* block); + + CachingBlockPool(); + ~CachingBlockPool(); + // return a block whose size >= input size and delete it from pool, return nullptr if no block is found. + CachingBlock* pop_block(size_t size); + // insert a block, id of this block will be obtanined in this function. + void insert(CachingBlock* block); + // delete a block from pool and recycle id. + void erase(CachingBlock* block); + // insert a block, id of this block will be obtanined and returned in this function. + size_t insert_occupied(CachingBlock* block); + // delete and return a block from pool and recycle id. + CachingBlock* erase_occupied(size_t allocation); + // return a block from pool + CachingBlock* get_occupied(size_t allocation); + // free all unsplit unoccupied blocks and recycle id. + size_t free_all_cached_blocks(Allocator* underlying, long long free_size = -1); +}; + +// Segregate fit range list allocator +struct SFRLAllocator : Allocator { + CachingBlockPool small_blocks, large_blocks; + std::map occupied_blocks; + Allocator* underlying; + + static const size_t ALIGN_SIZE = 512; + static const size_t SMALL_BLOCK_SIZE = 1048576; + static const size_t LARGE_ALIGN_SIZE = 2097152; + float free_ratio, min_free_size; + static list sfrl_allocators; + list::iterator iter; + CachingBlockPool* get_blocks(size_t size); + size_t align_size(size_t size); + size_t allocation_size(size_t size); + bool should_split(CachingBlock* block, size_t size); + void try_merge_two_blocks(CachingBlock* b1, CachingBlock* b2, CachingBlockPool& blocks); + + inline SFRLAllocator(float free_ratio = 1, float min_free_size=0) : free_ratio(free_ratio), min_free_size(min_free_size) { sfrl_allocators.push_front(this); iter = sfrl_allocators.begin(); } + inline SFRLAllocator(Allocator* underlying, float free_ratio = 1, float min_free_size=0) : SFRLAllocator(free_ratio, min_free_size) { + setup(underlying); + } + ~SFRLAllocator(); + // free all unused memory of all sfrl allocators. + static void free_all_sfrl_allocators(); + void try_free_this_allocators(); + void setup(Allocator* underlying); + uint64 flags() const override { return underlying->flags(); } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; + void gc() override; + virtual bool share_with(size_t size, size_t allocation) override; +}; + +DECLARE_FLAG(int, use_sfrl_allocator); + +}//jittor + diff --git a/python/jittor/src/mem/allocator/stat_allocator.cc b/python/jittor/src/mem/allocator/stat_allocator.cc new file mode 100644 index 00000000..a8d68ed9 --- /dev/null +++ b/python/jittor/src/mem/allocator/stat_allocator.cc @@ -0,0 +1,46 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "mem/allocator/stat_allocator.h" +#include "var.h" + +namespace jittor { + +DEFINE_FLAG_WITH_SETTER(int, use_stat_allocator, 0, "Enable stat allocator"); +DEFINE_FLAG(size_t, stat_allocator_total_alloc_call, 0, "Number of alloc function call"); +DEFINE_FLAG(size_t, stat_allocator_total_alloc_byte, 0, "Total alloc byte"); +DEFINE_FLAG(size_t, stat_allocator_total_free_call, 0, "Number of alloc function call"); +DEFINE_FLAG(size_t, stat_allocator_total_free_byte, 0, "Total alloc byte"); + +void setter_use_stat_allocator(int value) { + // if enabled, clean prev records + if (!use_stat_allocator && value) { + stat_allocator_total_alloc_call = 0; + stat_allocator_total_alloc_byte = 0; + stat_allocator_total_free_call = 0; + stat_allocator_total_free_byte = 0; + } +} + +void StatAllocator::setup(Allocator* underlying) { + this->underlying = underlying; +} + +const char* StatAllocator::name() const {return "stat";} + +void* StatAllocator::alloc(size_t size, size_t& allocation) { + stat_allocator_total_alloc_call++; + stat_allocator_total_alloc_byte += size; + return underlying->alloc(size, allocation); +} + +void StatAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + stat_allocator_total_free_call++; + stat_allocator_total_free_byte += size; + underlying->free(mem_ptr, size, allocation); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/stat_allocator.h b/python/jittor/src/mem/allocator/stat_allocator.h new file mode 100644 index 00000000..548590c2 --- /dev/null +++ b/python/jittor/src/mem/allocator/stat_allocator.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "mem/allocator.h" + +namespace jittor { + +struct StatAllocator : Allocator { + Allocator* underlying; + + void setup(Allocator* underlying); + uint64 flags() const override { return underlying->flags(); } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; +}; + +DECLARE_FLAG(int, use_stat_allocator); +DECLARE_FLAG(size_t, stat_allocator_total_alloc_call); +DECLARE_FLAG(size_t, stat_allocator_total_alloc_byte); +DECLARE_FLAG(size_t, stat_allocator_total_free_call); +DECLARE_FLAG(size_t, stat_allocator_total_free_byte); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/allocator/temp_allocator.cc b/python/jittor/src/mem/allocator/temp_allocator.cc new file mode 100644 index 00000000..7dae4625 --- /dev/null +++ b/python/jittor/src/mem/allocator/temp_allocator.cc @@ -0,0 +1,121 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "mem/allocator/temp_allocator.h" + +namespace jittor { + +DEFINE_FLAG(int, use_temp_allocator, 1, "Enable temp allocator"); +vector TempAllocator::temp_allocators; + +TempAllocator::~TempAllocator() { + while (!cached_blocks.empty()) { + auto it = cached_blocks.begin(); + TempCachingBlock* block = it->second; + cached_blocks.erase(it); + delete block; + } +} + +const char* TempAllocator::name() const {return "temp";} + +void TempAllocator::setup(Allocator* underlying) { + this->underlying = underlying; +} + +size_t TempAllocator::align_size(size_t size) { + return (size + ALIGN_SIZE - 1) / ALIGN_SIZE * ALIGN_SIZE; +} + +unsigned long long TempAllocator::get_key(TempCachingBlock* block) { + return ((unsigned long long)block->size) * ID_LIMIT + block->id; +} + +void* TempAllocator::alloc(size_t size, size_t& allocation) { + size = align_size(size); + + auto temp = TempCachingBlock(size); + auto it = cached_blocks.lower_bound(get_key(&temp)); + TempCachingBlock* block = nullptr; + if (it != cached_blocks.end()) { + block = it->second; + cached_blocks.erase(it); + unused_memory -= block->size; + } else { + void* ptr = underlying->alloc(size, allocation); + block = new TempCachingBlock(size, ptr); + size_t id; + if (!block_ids.empty()) { + id = block_ids.back(); + block_ids.pop_back(); + } else { + ASSERT(tot_block_id < ID_LIMIT - 1) << "block id limit extended."; + id = ++tot_block_id; + } + block->id = id; + } + + used_memory += block->size; + occupied_id_mapper[block->id] = block; + allocation = block->id; + return block->memory_ptr; +} + +void TempAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) { + size = align_size(size); + ASSERT(occupied_id_mapper[allocation] != nullptr) << "allocation not found"; + TempCachingBlock* block = occupied_id_mapper[allocation]; + occupied_id_mapper[allocation] = nullptr; + used_memory -= block->size; + unused_memory += block->size; + bool can_add = true; + if (cached_blocks.size() > cache_blocks_limit-1) { + ASSERT(cached_blocks.size() == cache_blocks_limit); + auto it = cached_blocks.lower_bound(get_key(block)); + if (it == cached_blocks.begin()) { + can_add = false; + underlying->free((void*)block->memory_ptr, block->size, 0); + unused_memory -= block->size; + block_ids.push_back(block->id); + delete block; + } else { + --it; + TempCachingBlock* block = it->second; + underlying->free((void*)block->memory_ptr, block->size, 0); + unused_memory -= block->size; + block_ids.push_back(block->id); + cached_blocks.erase(it); + delete block; + } + } + if (can_add) { + cached_blocks[get_key(block)] = block; + } +} + +void TempAllocator::gc() { + while (!cached_blocks.empty()) { + auto it = cached_blocks.begin(); + TempCachingBlock* block = it->second; + underlying->free((void*)block->memory_ptr, block->size, 0); + unused_memory -= block->size; + block_ids.push_back(block->id); + cached_blocks.erase(it); + delete block; + } +} + +bool TempAllocator::share_with(size_t size, size_t allocation) { + ASSERT(false); + return true; +} + +} // jittor + diff --git a/python/jittor/src/mem/allocator/temp_allocator.h b/python/jittor/src/mem/allocator/temp_allocator.h new file mode 100644 index 00000000..08dc0994 --- /dev/null +++ b/python/jittor/src/mem/allocator/temp_allocator.h @@ -0,0 +1,59 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "mem/allocator.h" + +namespace jittor { + +struct TempCachingBlock { + size_t size; + size_t id; + void* memory_ptr; + + TempCachingBlock(size_t size):size(size),id(0) {} + TempCachingBlock(size_t size, void* memory_ptr):size(size),id(0), memory_ptr(memory_ptr) {} +}; + +struct TempAllocator : Allocator { + static const size_t ALIGN_SIZE = 512; + static const size_t ID_LIMIT = 1 << 18; + static vector temp_allocators; + Allocator* underlying; + size_t cache_blocks_limit, used_memory, unused_memory; + std::map cached_blocks; + std::vector block_ids; + size_t tot_block_id; + std::unique_ptr occupied_id_mapper; + + + inline TempAllocator(size_t cache_blocks_limit=2) : cache_blocks_limit(cache_blocks_limit), used_memory(0), unused_memory(0), tot_block_id(0), occupied_id_mapper(new TempCachingBlock*[ID_LIMIT]) { + temp_allocators.push_back(this); + } + inline TempAllocator(Allocator* underlying, size_t cache_blocks_limit=2) : TempAllocator(cache_blocks_limit) { + setup(underlying); + } + ~TempAllocator(); + + size_t align_size(size_t size); + unsigned long long get_key(TempCachingBlock* block); + // free all unused memory of all sfrl allocators. + void setup(Allocator* underlying); + uint64 flags() const override { return underlying->flags(); } + const char* name() const override; + void* alloc(size_t size, size_t& allocation) override; + void free(void* mem_ptr, size_t size, const size_t& allocation) override; + void gc() override; + virtual bool share_with(size_t size, size_t allocation) override; +}; + +DECLARE_FLAG(int, use_temp_allocator); + +}//jittor + diff --git a/python/jittor/src/mem/mem_info.cc b/python/jittor/src/mem/mem_info.cc new file mode 100644 index 00000000..5de25836 --- /dev/null +++ b/python/jittor/src/mem/mem_info.cc @@ -0,0 +1,322 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#if defined(__linux__) +#include +#elif defined(__APPLE__) +#include +#include +#include +#include +#elif defined(_WIN32) +#include +#endif +#ifndef _WIN32 +#include +#endif + +#include "var.h" +#include "op.h" +#include "var_holder.h" +#include "graph.h" +#include "misc/cuda_flags.h" +#include "mem/allocator/sfrl_allocator.h" +#include "mem/allocator/stat_allocator.h" +#include "mem/allocator/temp_allocator.h" +#include "mem/mem_info.h" +#include "mem/swap.h" +#include "executor.h" + +namespace jittor { + +struct FloatOutput { + double value; + string scale; + int base; + string suffix; + int p=4; +}; + +std::ostream& operator<<(std::ostream& os, const FloatOutput& o) { + int w = 8; + os << std::setw(w-2-o.suffix.size()); + os << std::setprecision(o.p); + uint i=0; + double k = o.value; + for (; i+1* _grad_backup_ptr = nullptr; + +void display_memory_info(const char* fileline, bool dump_var, bool red_color) { + int p = 3; + Log log(fileline, red_color?'e':'i', 0); + log << "\n=== display_memory_info ===\n"; + log << "total_cpu_ram:" << + FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"}; + log << "total_device_ram:" << + FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n"; + log << "hold_vars:" << hold_vars.size() + << "lived_vars:" << Var::number_of_lived_vars + << "lived_ops:" << Op::number_of_lived_ops >> '\n'; + if (_grad_backup_ptr) + log << "jtorch_grad_vars:" << _grad_backup_ptr->size() >> '\n'; + + // get the oldest var + if (trace_py_var) { + vector queue; + auto t = ++tflag_count; + for (auto& vh : hold_vars) + if (vh->var->tflag != t) { + vh->var->tflag = t; + queue.push_back(vh->var); + } + bfs_both(queue, [](Node*){return true;}); + static unordered_map cnt; + auto cnt_bk = cnt; + map stat; + for (auto* node : queue) { + auto &x = cnt[node->id]; + x++; + if (x == 3 && node->is_var()) { + LOGe << node; + } + stat[x]++; + } + for (auto x : cnt_bk) { + if (x.second == cnt[x.first]) { + cnt.erase(x.first); + } + } + LOGe << "appear time -> node cnt:" << stat; + if (lived_nodes_id.size()) { + LOGe << "lived_nodes cnt:" << lived_nodes_id.size(); + Node* not_found=nullptr; + int not_found_cnt = 0; + for (auto nid : lived_nodes_id) { + if (!cnt.count(nid.first)) { + not_found_cnt ++; + if (!not_found) not_found = nid.second; + } + } + LOGe << "Total not_found:" << not_found_cnt; + if (not_found) + LOGe << "not found node:" << not_found; + if (_grad_backup_ptr) { + Node* not_found_grad=nullptr; + int parent_id = 0; + int not_found_grad_cnt = 0; + for (auto& gid : *_grad_backup_ptr) { + if (!lived_nodes_id.count(gid.first)) { + not_found_grad_cnt ++; + if (!not_found_grad) { + not_found_grad = gid.second.ptr; + parent_id = gid.first; + } + } + } + LOGe << "Grad not found cnt:" << not_found_grad_cnt; + if (not_found_grad) { + LOGe << "grad not found node" << not_found_grad; + LOGe << "parent id:" << parent_id; + } + } + } + } + + if (use_stat_allocator) { + log << "stat:" << use_stat_allocator; + log << "total alloc:" << FloatOutput{(double)(stat_allocator_total_alloc_byte + - stat_allocator_total_free_byte), " KMG", 1024, "B"}; + log << "total alloc call:" << FloatOutput{(double)(stat_allocator_total_alloc_call + - stat_allocator_total_free_call), " KMG", 1000, ""} + >> '(' >> stat_allocator_total_alloc_call >> '/' >> + stat_allocator_total_free_call >> ")\n"; + } + int64 all_total = 0, gpu_total = 0, cpu_total = 0; + for (auto& a : SFRLAllocator::sfrl_allocators) { + auto total = a->used_memory + a->unused_memory; + all_total += total; + a->is_cuda() ? gpu_total += total : cpu_total += total; + log << "name:" << a->name() << "is_device:" << a->is_cuda() + << "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"} + >> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)" + << "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"} + >> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"; + + if (a->large_blocks.blocks.size()) { + size_t largest_block = 0; + auto block = a->large_blocks.blocks.rbegin()->second; + largest_block = a->large_blocks.blocks.rbegin()->first.first; + // unused largest block + log << "ULB:" << FloatOutput{(double)largest_block, " KMG", 1024, "B"}; + log << "ULBO:" << FloatOutput{(double)block->origin_size, " KMG", 1024, "B"}; + int dump_block_info = 0; + if (a->is_cuda() && dump_block_info) { + unordered_map visited; + for (auto& kv : a->large_blocks.blocks) { + LOGir << "dump block info" << kv.second->size << kv.second->origin_size << kv.second->id; + auto s = kv.second; + while (s->prev != nullptr) s = s->prev; + if (visited.count((void*)s)) continue; + visited[s] = 1; + while (s) { + LOGir << " " << s->id << s->size << s->occupied; + s = s->next; + } + } + } + }; + + log << "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n"; + } + if (use_temp_allocator && exe.temp_allocator) { + for (auto& a : TempAllocator::temp_allocators) { + auto total = a->used_memory + a->unused_memory; + all_total += total; + a->is_cuda() ? gpu_total += total : cpu_total += total; + log << "name:" << a->name() << "is_device:" << a->is_cuda() + << "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"} + >> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)" + << "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"} + >> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)" + << "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n"; + } + } + log << "cpu&gpu:" << FloatOutput{(double)all_total, " KMG", 1024, "B"} + << "gpu:" << FloatOutput{(double)gpu_total, " KMG", 1024, "B"} + << "cpu:" << FloatOutput{(double)cpu_total, " KMG", 1024, "B"} >> '\n'; + + size_t cpu_free = 0; +#if defined(__linux__) + cpu_free = get_avphys_pages() * sysconf(_SC_PAGESIZE); +#elif defined(__APPLE__) + { + mach_msg_type_number_t count = HOST_VM_INFO_COUNT; + vm_statistics_data_t vmstat; + if (KERN_SUCCESS == host_statistics(mach_host_self(), HOST_VM_INFO, (host_info_t)&vmstat, &count)) { + cpu_free = vmstat.free_count * sysconf(_SC_PAGESIZE); + } + } +#endif + size_t gpu_free = 0, _gpu_total = 0; + (void)gpu_free; (void)_gpu_total; + #ifdef HAS_CUDA + cudaMemGetInfo(&gpu_free, &_gpu_total); + #endif + log << "free: cpu(">>FloatOutput{(double)cpu_free, " KMG", 1024, "B"} + >> ") gpu(">>FloatOutput{(double)gpu_free, " KMG", 1024, "B"} >> ")\n"; + static int64 swap_total_last = 0; + log << "swap: total(">>FloatOutput{(double)swap_total, " KMG", 1024, "B"} + >> ") last(">>FloatOutput{(double)(swap_total-swap_total_last), " KMG", 1024, "B"} >> ")\n"; + swap_total_last = swap_total; + if (dump_var) { + vector queue; + unordered_set visited; + for (auto& vh : hold_vars) + if (!visited.count(vh->var)) { + queue.push_back(vh->var); + visited.insert(vh->var); + } + int64 cum = 0; + for (int i=0; iinputs()) + if (!visited.count(n)) { + queue.push_back(n); + visited.insert(n); + } + for (auto* n : queue[i]->outputs()) + if (!visited.count(n)) { + queue.push_back(n); + visited.insert(n); + } + if (queue[i]->is_var()) { + auto v = (Var*)queue[i]; + if (v->size>=0 && v->mem_ptr) { + cum += v->size; + log << FloatOutput{(double)v->size, " KMG", 1024, "B"} + >> "(" >> std::setprecision(p) >> v->size*100.0 / all_total >> "%)" + << FloatOutput{(double)cum, " KMG", 1024, "B"} + >> "(" >> std::setprecision(p) >> cum*100.0 / all_total >> "%)" + << v >> "\n"; + if (v->size == 100*64*112*112*4) { + for (auto op : v->outputs()) + log << "\t" << op << '\n'; + } + } + } + } + } + log >> "===========================\n"; + + if (red_color) { + bool gpu_overflow = (double)gpu_total>(double)mem_info.total_cuda_ram*0.95; + bool cpu_overflow = (double)cpu_total>(double)mem_info.total_cpu_ram*0.95; + // cpu total too small, not possible + if (mem_info.total_cpu_ram < 100000) + cpu_overflow = false; + if(gpu_overflow || cpu_overflow) { + double used = gpu_overflow ? (double)gpu_total : (double)cpu_total; + double total = gpu_overflow ? (double)mem_info.total_cuda_ram : (double)mem_info.total_cpu_ram; + log.end(); + LOGf << "\n*******************\n" + >> (gpu_overflow?"GPU":"CPU") << "memory is overflow, please reduce your batch_size or data size!\nTotal:" << FloatOutput{(double)total, " KMG", 1024, "B"} << "Used:" << FloatOutput{(double)used, " KMG", 1024, "B"}; + } else + return; + } + + log.end(); +} + +EXTERN_LIB vector sigquit_callback; + +void meminfo_callback() { + display_memory_info(); +} + +MemInfo::MemInfo() { + +#if defined(__linux__) + struct sysinfo info = {0}; + sysinfo(&info); + total_cpu_ram = info.totalram; +#elif defined(__APPLE__) + int mib[] = {CTL_HW, HW_MEMSIZE}; + size_t len=sizeof(total_cpu_ram); + sysctl(mib, 2, &total_cpu_ram, &len, NULL, 0); +#elif defined(_WIN32) + MEMORYSTATUSEX statex; + GlobalMemoryStatusEx (&statex); + total_cpu_ram = statex.ullTotalPhys; +#endif + + total_cuda_ram = 0; +#ifdef HAS_CUDA + size_t gpu_free = 0, _gpu_total = 0; + cudaMemGetInfo(&gpu_free, &_gpu_total); + total_cuda_ram = _gpu_total; +#endif + // sigquit_callback.push_back(&meminfo_callback); + + int64 all_total = 0, gpu_total = 0, cpu_total = 0; + for (auto& a : SFRLAllocator::sfrl_allocators) { + auto total = a->used_memory + a->unused_memory; + a->is_cuda() ? gpu_total += total : cpu_total += total; + } + total_cpu_used = cpu_total; + total_cuda_used = gpu_total; +} + +MemInfo mem_info; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/mem_info.h b/python/jittor/src/mem/mem_info.h new file mode 100644 index 00000000..f973a842 --- /dev/null +++ b/python/jittor/src/mem/mem_info.h @@ -0,0 +1,36 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +// @pyjt(display_memory_info) +void display_memory_info(const char* fileline="", bool dump_var=false, bool red_color=false); + +// @pyjt(MemInfo) +struct MemInfo { + // @pyjt(total_cpu_ram) + int64 total_cpu_ram; + // @pyjt(total_cuda_ram) + int64 total_cuda_ram; + // @pyjt(total_cpu_used) + int64 total_cpu_used; + // @pyjt(total_cuda_used) + int64 total_cuda_used; + + inline MemInfo(const MemInfo&) = default; + + MemInfo(); +}; + +EXTERN_LIB MemInfo mem_info; + +// @pyjt(get_mem_info) +inline MemInfo get_mem_info() { return MemInfo(); } + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/mem/swap.cc b/python/jittor/src/mem/swap.cc new file mode 100644 index 00000000..4293e235 --- /dev/null +++ b/python/jittor/src/mem/swap.cc @@ -0,0 +1,233 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#endif +#include +#include +#ifndef _MSC_VER +#include +#endif +#include "var.h" +#include "mem/swap.h" +#include "mem/mem_info.h" + +namespace jittor { + +int64 swap_timestamp; +int64 swap_total; +constexpr int64 SWAP_BUF_SIZE = 1<<23; // 8M +extern string cache_path; +static int _pid = getpid(); + +DEFINE_FLAG(int64, cpu_mem_limit, -1, "cpu_mem_limit"); +DEFINE_FLAG(int64, device_mem_limit, -1, "device_mem_limit"); + +struct Swap { + map, Var*> lived; +}; + +unordered_map swaps; + +void swap_to_disk(Var* x, Swap& swap) { + swap_total += x->size; + ASSERT(!x->flags.get(NodeFlags::_is_swapped)); + string path = cache_path + "/tmp/" + S(_pid) + "-" + S(x->id) + ".bin"; + #ifdef HAS_CUDA + if (x->allocator->is_cuda()) { + static char* buffer = new char[SWAP_BUF_SIZE]; + auto* memptr = (char*)x->mem_ptr; + auto* fd = fopen(path.c_str(), "wb"); + CHECK(fd) << "swap file open failed:" << path << x; + for (int64 i=0; isize; i+=SWAP_BUF_SIZE) { + int64 cp_size = std::min(x->size-i, SWAP_BUF_SIZE); + cudaMemcpy(buffer, memptr+i, cp_size, cudaMemcpyDeviceToHost); + auto res = fwrite(buffer, cp_size, 1, fd); + if (res==1) { + fclose(fd); + LOGf << "swap file write failed" << path << x; + } + } + fclose(fd); + } else + #endif + { + auto* fd = fopen(path.c_str(), "wb"); + auto res = fwrite(x->mem_ptr, x->size, 1, fd); + CHECK(res==1) << "failed to write swap file" << path << res << x->size << x; + fclose(fd); + } + auto iter = swap.lived.find({x->size, x->id}); + ASSERT(iter != swap.lived.end()); + swap.lived.erase(iter); + x->allocator->free(x->mem_ptr, x->size, x->allocation); + x->mem_ptr = nullptr; + x->allocator = nullptr; + x->allocation = 0; + x->flags.set(NodeFlags::_is_swapped); +} + +bool alloc_with_swap(Var* x, Allocator* allocator, bool force) { + + auto& swap = swaps[allocator]; + if (x->allocator) { + // shared memory, no need alloc + if (x->alloc(allocator)) { + swap.lived[{x->size, x->id}] = x; + return true; + } + } + bool is_cpu = !allocator->is_cuda(); + int64 limit = is_cpu ? cpu_mem_limit : device_mem_limit; + if (limit < 0) limit = 1ll<<60; + if (allocator->used_memory + allocator->unused_memory + x->size > limit) + allocator->gc(); + if (force && allocator->used_memory + allocator->unused_memory + x->size > limit) { + auto iter = swap.lived.upper_bound({x->size, -1}); + auto unused_target = allocator->unused_memory + x->size; + while (iter != swap.lived.end()) { + auto* var = iter->second; + iter++; + if (var->tflag == swap_timestamp) + continue; + ASSERT(var->mem_ptr) << var->exist() << (void*)var << iter->first << (display_memory_info(), 1); + if (!is_cpu) { + // try move to cpu + if (!move_with_swap(var, cpu_allocator, false)) + swap_to_disk(var, swap); + } else + swap_to_disk(var, swap); + if (allocator->used_memory + allocator->unused_memory + x->size <= limit || allocator->unused_memory >= unused_target) break; + } + // if still no space, swap other smaller var + if (!(allocator->used_memory + allocator->unused_memory + x->size <= limit || allocator->unused_memory >= unused_target)) { + auto iter = swap.lived.end(); + if (swap.lived.size()) iter = std::prev(iter); + while (iter != swap.lived.end()) { + auto var = iter->second; + iter = iter==swap.lived.begin() ? swap.lived.end() : std::prev(iter); + if (var->tflag == swap_timestamp) + continue; + ASSERT(var->mem_ptr) << x << var; + if (!is_cpu) { + // try move to cpu + if (!move_with_swap(var, cpu_allocator, false)) + swap_to_disk(var, swap); + } else + swap_to_disk(var, swap); + allocator->gc(); + if (allocator->used_memory + allocator->unused_memory + x->size <= limit || allocator->unused_memory >= unused_target) break; + } + if (!(allocator->used_memory + allocator->unused_memory + x->size <= limit || allocator->unused_memory >= unused_target)) { + display_memory_info(); + LOGw << "unable to alloc var" << x; + } + } + } + if (x->alloc(allocator)) { + swap.lived[{x->size, x->id}] = x; + return true; + } + return false; +} + +void free_with_swap(Var* x) { + if (x->flags.get(NodeFlags::_is_swapped)) { + string path = cache_path + "/tmp/" + S(_pid) + "-" + S(x->id) + ".bin"; + if (remove(path.c_str()) != 0) + LOGe << "failed to remove swap file" << path << x->shape << x->dtype(); + } else { + if (!x->mem_ptr) return; + auto& swap = swaps[x->allocator]; + auto iter = swap.lived.find({x->size, x->id}); + if (iter != swap.lived.end()) + swap.lived.erase(iter); + x->allocator->free(x->mem_ptr, x->size, x->allocation); + x->mem_ptr = nullptr; + x->allocator = nullptr; + x->allocation = 0; + } +} + +bool move_with_swap(Var* x, Allocator* allocator, bool force) { + if (allocator == x->allocator) return true; + swap_total += x->size; + Allocation allocation(x->mem_ptr, x->allocation, x->size, x->allocator); + x->mem_ptr = nullptr; + x->allocator = nullptr; + x->allocation = 0; + if (!alloc_with_swap(x, allocator, force)) { + x->mem_ptr = allocation.ptr; + x->allocator = allocation.allocator; + x->allocation = allocation.allocation; + allocation.ptr = nullptr; + allocation.allocation = 0; + return false; + } + if (x->flags.get(NodeFlags::_is_swapped)) { + string path = cache_path + "/tmp/" + S(_pid) + "-" + S(x->id) + ".bin"; + #ifdef HAS_CUDA + if (x->allocator->is_cuda()) { + static char* buffer = new char[SWAP_BUF_SIZE]; + auto* memptr = (char*)x->mem_ptr; + auto* fd = fopen(path.c_str(), "rb"); + CHECK(fd) << "swap file open failed:" << path << x; + for (int64 i=0; isize; i+=SWAP_BUF_SIZE) { + int64 cp_size = std::min(x->size-i, SWAP_BUF_SIZE); + auto res = fread(buffer, cp_size, 1, fd); + cudaMemcpy(memptr+i, buffer, cp_size, cudaMemcpyHostToDevice); + if (res != 1) { + fclose(fd); + LOGf << "swap file read failed" << path << x; + } + } + fclose(fd); + } else + #endif + { + auto* fd = fopen(path.c_str(), "rb"); + auto res = fread(x->mem_ptr, x->size, 1, fd); + CHECK(res==1); + fclose(fd); + } + + if (remove(path.c_str()) != 0) + LOGe << "failed to remove swap file" << path << x->shape << x->dtype(); + x->flags.set(NodeFlags::_is_swapped, 0); + } else { + #ifdef HAS_CUDA + if (x->allocator->is_cuda()) { + if (allocation.allocator->is_cuda()) + cudaMemcpy(x->mem_ptr, allocation.ptr, x->size, cudaMemcpyDeviceToDevice); + else + cudaMemcpy(x->mem_ptr, allocation.ptr, x->size, cudaMemcpyHostToDevice); + } else + if (allocation.allocator->is_cuda()) { + cudaMemcpy(x->mem_ptr, allocation.ptr, x->size, cudaMemcpyDeviceToHost); + } else + #endif + { + std::memcpy(x->mem_ptr, allocation.ptr, x->size); + } + } + if (allocation.ptr) { + auto& swap = swaps[allocation.allocator]; + auto iter = swap.lived.find({x->size, x->id}); + if (iter != swap.lived.end()) + swap.lived.erase(iter); + } + return true; +} + +void registe_swap(Var* x) { + auto& swap = swaps[x->allocator]; + swap.lived[{x->size, x->id}] = x; +} + +} diff --git a/python/jittor/src/mem/swap.h b/python/jittor/src/mem/swap.h new file mode 100644 index 00000000..ef8987c8 --- /dev/null +++ b/python/jittor/src/mem/swap.h @@ -0,0 +1,100 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +/* +var heap + map allocator + lived: map size, var + swaped: set var + + if gpu is full, + if cpu is ok: + swap to cpu + else + swap to disk + + operation + take over all alloc, + mark time stamp + get_allocation(allocator, var) + + alloc(allocator, size) -> Allocation + add_ts, mark_ts + signin, signout + move_to(own alloc) + +global + mem_save_mode + cpu_mem_limit_n + device_mem_limit_n + +share_with handle: + free var, until allocator reduce size + +TODO: + change exe.allocator->alloc to exe.temp_allocator->alloc + handle cutt jt_alloc + handle cupy jittor_cuda_malloc + search share_with + search migrate + check Allocation move + migrate_to_cpu + migrate_to_gpu + array op + fetch op + !!disable dual allocator, reuse array + handle foreign allocator, only handle cpu allocator and gpu allocator + +code change: + free var + alloc var + executor mark timestamp + migrate_to_cpu + migrate_to_gpu + array op: finish imm + fetch op + if is cached, access? + item, data, numpy, all calling migrate_to_cpu, handle in migrate_to_cpu + JT_SAVE_MEM env, global env for + +*/ +#pragma once +#include "common.h" +#include "mem/allocator.h" +#include "var.h" + +namespace jittor { + +#ifdef JT_SAVE_MEM +#if JT_SAVE_MEM != 0 +constexpr int save_mem = 1; +#else +constexpr int save_mem = 0; +#endif +#else +constexpr int save_mem = 0; +#endif +extern int64 swap_timestamp; +extern int64 swap_total; + +DECLARE_FLAG(int64, cpu_mem_limit); +DECLARE_FLAG(int64, device_mem_limit); + +bool alloc_with_swap(Var* x, Allocator* allocator, bool force); +void free_with_swap(Var* x); +bool move_with_swap(Var* x, Allocator* allocator, bool force); +void registe_swap(Var* x); + +inline void check_and_swap_out(Var* x, Allocator* allocator) { + if (x->flags.get(NodeFlags::_is_swapped)) + move_with_swap(x, allocator, true); +} + + +} diff --git a/python/jittor/src/memory_profiler.cc b/python/jittor/src/memory_profiler.cc new file mode 100644 index 00000000..062050ba --- /dev/null +++ b/python/jittor/src/memory_profiler.cc @@ -0,0 +1,180 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "memory_profiler.h" +#include "graph.h" +#include "var_holder.h" +#include "var.h" +#include "mem/allocator/sfrl_allocator.h" +#include +#include +#include +#include "pybind/py_var_tracer.h" + +namespace jittor { + +//TODO reuse from mem_info.cc +struct FloatOutput_ { + double value; + string scale; + int base; + string suffix; + int p=4; +}; + +inline std::ostream& operator<<(std::ostream& os, const FloatOutput_& o) { + int w = 8; + os << std::setw(w-2-o.suffix.size()); + os << std::setprecision(o.p); + uint i=0; + double k = o.value; + for (; i+1 MemoryProfiler::get_memory_info() { + ASSERT(profile_memory_enable); + size_t used = 0; + size_t unused = 0; + //TODO add mssfrl allocator + for (auto& a : SFRLAllocator::sfrl_allocators) { + used += a->used_memory; + unused += a->unused_memory; + } + return std::make_pair(used, unused); +} + +void MemoryProfiler::check() { + ASSERT(profile_memory_enable); + std::pair mem_info = get_memory_info(); + if (mem_info.first > max_used_memory_size) { + max_used_memory_size = mem_info.first; + + allocations.clear(); + size_t memory_size = 0; + std::vector>, size_t>> live_vars; + vector queue, queue2; + + auto t = ++tflag_count; + for (auto& vh : hold_vars) + if (vh->var->tflag != t) { + vh->var->tflag = t; + queue.push_back(vh->var); + } + bfs_both(queue, [](Node*){return true;}); + vector backup_custom_data; + backup_custom_data.resize(queue.size()); + for (int i=0; icustom_data; + toplogical_sort_forward(queue, queue2, [](Node*){}); + for (int i=0; icustom_data = backup_custom_data[i]; + queue.swap(queue2); + int64 cpu_sum = 0; + for (Node* node : queue) { + if (node->is_var()) { + Var* var = (Var*)node; + if (var->mem_ptr != nullptr) { + if (profile_memory_enable == 2 && !var->allocator->is_cuda()) { + cpu_sum += var->size; + continue; + } + vector stacks = get_node_trace(var); + if (stacks.size() == 0) { + stacks.push_back(Stack()); + } + auto alloc = std::make_pair((void*)var->allocator, (void*)var->allocation); + if (!allocations.count(alloc)) { + std::stringstream stream; + stream << var; + live_vars.push_back(std::make_pair(std::make_pair(stream.str(), stacks), var->size)); + allocations[alloc] = 1; + memory_size += var->size; + } + } + } + } + max_live_vars = live_vars; + max_memory_size = memory_size; + } +} + +bool MemoryProfiler::cmp(const std::pair>, size_t>& a, const std::pair>, size_t>& b) { + return a.second > b.second; +} + +void MemoryProfiler::display_max_memory_info() { + ASSERT(profile_memory_enable); + Log log("", 'i', 0); + std::sort(max_live_vars.begin(), max_live_vars.end(), cmp); + log << "\n=====display_max_memory_info=====\n"; + log << "max used memory" << FloatOutput_{(double)max_used_memory_size, " KMG", 1024, "B"} << "\n"; + log << "max var memory" << FloatOutput_{(double)max_memory_size, " KMG", 1024, "B"} << "\n\n"; + log << "[Size]" << "[Percent]" << "[Var Info]" << "\n"; + for (int i = 0; i < max_live_vars.size(); ++i) { + log << FloatOutput_{(double)max_live_vars[i].second, " KMG", 1024, "B"} + << double(max_live_vars[i].second) / max_memory_size * 100 << "%" + << max_live_vars[i].first.first + << max_live_vars[i].first.second[0].file_path + ":" + std::to_string(max_live_vars[i].first.second[0].lineno) + << "\n\n"; + } + log << "=========================\n"; + log.end(); +} + +void display_max_memory_info() { + ASSERT(profile_memory_enable); + memory_profiler.display_max_memory_info(); +} + +string MemoryProfiler::get_max_memory_info() { + ASSERT(profile_memory_enable); + std::stringstream out; + string div1 = "[!@#div1!@#]"; + string div2 = "[!@#div2!@#]"; + string div3 = "[!@#div3!@#]"; + + std::sort(max_live_vars.begin(), max_live_vars.end(), cmp); + out << max_memory_size; + for (int i = 0; i < max_live_vars.size(); ++i) { + out << div1; + out << max_live_vars[i].first.first << div2; + out << max_live_vars[i].second << div2; + for (int j = 0; j < max_live_vars[i].first.second.size(); ++j) { + out << max_live_vars[i].first.second[j].file_path + ":" + std::to_string(max_live_vars[i].first.second[j].lineno) << div3 + << max_live_vars[i].first.second[j].module_name << div3 + << max_live_vars[i].first.second[j].module_type << div2; + } + } + return out.str(); +} + +string get_max_memory_info() { + ASSERT(profile_memory_enable); + return memory_profiler.get_max_memory_info(); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/memory_profiler.h b/python/jittor/src/memory_profiler.h new file mode 100644 index 00000000..7bc555e1 --- /dev/null +++ b/python/jittor/src/memory_profiler.h @@ -0,0 +1,46 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "mem/allocator.h" +#include +#include +#include +#include "var.h" +#include "pybind/py_var_tracer.h" +namespace jittor { + +// @pyjt(display_max_memory_info) +void display_max_memory_info(); +// @pyjt(get_max_memory_info) +string get_max_memory_info(); + +struct MemoryProfiler { + std::map, size_t> allocations; + // Max Infos + vector>, size_t>> max_live_vars; + size_t max_used_memory_size; + size_t max_memory_size; + + + MemoryProfiler(); + static bool cmp(const std::pair>, size_t>& a, const std::pair>, size_t>& b); + void clear(); + void check(); + std::pair get_memory_info(); + void display_max_memory_info(); + string get_max_memory_info(); +}; + +EXTERN_LIB MemoryProfiler memory_profiler; + +DECLARE_FLAG(int, profile_memory_enable); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/misc/cpu_atomic.cc b/python/jittor/src/misc/cpu_atomic.cc new file mode 100644 index 00000000..e99d05d9 --- /dev/null +++ b/python/jittor/src/misc/cpu_atomic.cc @@ -0,0 +1,13 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "misc/cpu_atomic.h" + +namespace jittor { + +std::atomic_flag lock = ATOMIC_FLAG_INIT;; + +} // jittor diff --git a/python/jittor/src/misc/cpu_atomic.h b/python/jittor/src/misc/cpu_atomic.h new file mode 100644 index 00000000..26ae212b --- /dev/null +++ b/python/jittor/src/misc/cpu_atomic.h @@ -0,0 +1,88 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "common.h" + +namespace jittor { + +EXTERN_LIB std::atomic_flag lock; + +struct spin_lock_guard { + inline spin_lock_guard() { + while (lock.test_and_set(std::memory_order_acquire)); + } + inline ~spin_lock_guard() { + lock.clear(std::memory_order_release); + } +}; + +template +T cpu_atomic_add(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] += b; + return old; +} + +template +T cpu_atomic_mul(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] *= b; + return old; +} + +template +T cpu_atomic_sub(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] -= b; + return old; +} + +template +T cpu_atomic_min(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] = std::min(old, b); + return old; +} + +template +T cpu_atomic_max(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] = std::max(old, b); + return old; +} + +template +T cpu_atomic_and(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] = old & b; + return old; +} + +template +T cpu_atomic_or(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] = old | b; + return old; +} + +template +T cpu_atomic_xor(T* a, T b) { + spin_lock_guard _; + auto old = *a; + a[0] = old ^ b; + return old; +} + +} // jittor diff --git a/python/jittor/src/misc/cpu_math.cc b/python/jittor/src/misc/cpu_math.cc new file mode 100644 index 00000000..4dc68f91 --- /dev/null +++ b/python/jittor/src/misc/cpu_math.cc @@ -0,0 +1,58 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#define _USE_MATH_DEFINES +#include +#include +#include "misc/cpu_math.h" + +namespace jittor { + +#define CENTRAL_RANGE 0.7 + +template +static inline typename std::enable_if::value, T>::type +calc_erfinv(T y) { +/* Function to calculate inverse error function. Rational approximation +is used to generate an initial approximation, which is then improved to +full accuracy by two steps of Newton's method. Code is a direct +translation of the erfinv m file in matlab version 2.0. +Author: Gary L. Pavlis, Indiana University +Date: February 1996 +*/ + T x, z, num, dem; /*working variables */ + /* coefficients in rational expansion */ + T a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331}; + T b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801}; + T c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311}; + T d[2]={ 3.543889200, 1.637067800}; + T y_abs = std::abs(y); + if(y_abs > 1.0) return std::numeric_limits::quiet_NaN(); + if(y_abs == 1.0) return std::copysign(std::numeric_limits::infinity(), y); + if(y_abs <= static_cast(CENTRAL_RANGE)) { + z = y * y; + num = (((a[3]*z + a[2])*z + a[1])*z + a[0]); + dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + static_cast(1.0)); + x = y * num / dem; + } + else{ + z = std::sqrt(-std::log((static_cast(1.0)-y_abs)/static_cast(2.0))); + num = ((c[3]*z + c[2])*z + c[1]) * z + c[0]; + dem = (d[1]*z + d[0])*z + static_cast(1.0); + x = std::copysign(num, y) / dem; + } + /* Two steps of Newton-Raphson correction */ + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(3.14159265358979323846)))*std::exp(-x*x)); + x = x - (std::erf(x) - y) / ((static_cast(2.0)/static_cast(std::sqrt(3.14159265358979323846)))*std::exp(-x*x)); + + return x; +} + +float _erfinv(float y) { return calc_erfinv(y); }; +double _erfinv(double y) { return calc_erfinv(y); }; + +} + diff --git a/python/jittor/src/misc/cpu_math.h b/python/jittor/src/misc/cpu_math.h new file mode 100644 index 00000000..2a0f71cd --- /dev/null +++ b/python/jittor/src/misc/cpu_math.h @@ -0,0 +1,16 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +float _erfinv(float y); +double _erfinv(double y); + +} + diff --git a/python/jittor/src/misc/cstr.h b/python/jittor/src/misc/cstr.h new file mode 100644 index 00000000..f6fd886f --- /dev/null +++ b/python/jittor/src/misc/cstr.h @@ -0,0 +1,56 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "common.h" + +namespace jittor { + +struct cstr { + unique_ptr ptr; + + inline const char* c_str() const { return ptr ? ptr.get() : ""; } + inline cstr& operator=(const char* s) { + auto len = std::strlen(s); + ptr.reset(new char[len+1]); + std::memcpy(ptr.get(), s, len+1); + return *this; + } + inline cstr& operator=(const string& s) { + auto len = s.size(); + ptr.reset(new char[len+1]); + std::memcpy(ptr.get(), s.c_str(), len+1); + return *this; + } + inline cstr& operator=(cstr&& s) { + ptr = move(s.ptr); + return *this; + } + inline cstr& operator=(const cstr& s) { + *this = s.c_str(); + return *this; + } + inline cstr(const cstr& s) { + *this = s.c_str(); + } + inline cstr() {} + inline size_t size() const { return ptr ? std::strlen(ptr.get()) : 0; } +}; + +inline std::ostream& operator<<(std::ostream& os, const cstr& p) { + return os << p.c_str(); +} + + +inline std::istream& operator>>(std::istream& is, cstr& p) { + string s; + is >> s; + p = s; + return is; +} + +} // jittor diff --git a/python/jittor/src/misc/cuda_atomic.h b/python/jittor/src/misc/cuda_atomic.h new file mode 100644 index 00000000..15b39412 --- /dev/null +++ b/python/jittor/src/misc/cuda_atomic.h @@ -0,0 +1,286 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#ifndef IS_ROCM +#include +#endif +#include "common.h" + +namespace jittor { + +__device__ inline static int floatToOrderedInt(float floatVal) { + int intVal = __float_as_int( floatVal ); + return (intVal >= 0 ) ? intVal : intVal ^ 0x7FFFFFFF; +} +__device__ inline static float orderedIntToFloat(int intVal) { + return __int_as_float((intVal >= 0) ? intVal : intVal ^ 0x7FFFFFFF); +} + +__global__ inline static void fix_float_kernel(float* x, int num) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = gridDim.x * blockDim.x; + for (int i=tid; i= 0 ) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF; +} +__device__ inline static double orderedIntToFloat(long long intVal) { + return __longlong_as_double((intVal >= 0) ? intVal : intVal ^ 0x7FFFFFFFFFFFFFFF); +} + +__global__ inline static void fix_float_kernel(double* x, int num) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = gridDim.x * blockDim.x; + for (int i=tid; i +inline static void fix_float(T* x, int num) { + fix_float_kernel<<>>(x, num); +} + +template __device__ +T cuda_atomic_max(T* a, T b) { + return atomicMax(a, b); +} + +template<> __device__ +inline float cuda_atomic_max(float* a, float b) { + return orderedIntToFloat(atomicMax((int *)a, floatToOrderedInt(b))); +} + +#ifndef NO_ATOMIC64 +template<> __device__ +inline double cuda_atomic_max(double* a, double b) { + return orderedIntToFloat(atomicMax((long long *)a, floatToOrderedInt(b))); +} +#endif + +template __device__ +T cuda_atomic_min(T* a, T b) { + return atomicMin(a, b); +} + +template<> __device__ +inline float cuda_atomic_min(float* a, float b) { + return orderedIntToFloat(atomicMin((int *)a, floatToOrderedInt(b))); +} + +#ifndef NO_ATOMIC64 +template<> __device__ +inline double cuda_atomic_min(double* a, double b) { + return orderedIntToFloat(atomicMin((long long *)a, floatToOrderedInt(b))); +} +#endif + +template struct int_mapper { + typedef T src; + typedef T target; + inline static __device__ target to_int(src a) { return a; } + inline static __device__ target* to_intp(src* a) { return a; } + inline static __device__ src from_int(target a) { return a; } +}; + +template <> struct int_mapper { + typedef float src; + typedef int target; + inline static __device__ target to_int(src a) { return __float_as_int(a); } + inline static __device__ target* to_intp(src* a) { return (target*)a; } + inline static __device__ src from_int(target a) { return __int_as_float(a); } +}; + +template <> struct int_mapper<__half> { + typedef __half src; + typedef unsigned short target; + inline static __device__ target to_int(src a) { return __half_as_ushort(a); } + inline static __device__ target* to_intp(src* a) { return (target*)a; } + inline static __device__ src from_int(target a) { return __ushort_as_half(a); } +}; +#if CUDA_ARCH >= 800 +template <> struct int_mapper<__nv_bfloat16> { + typedef __nv_bfloat16 src; + typedef unsigned short target; + inline static __device__ target to_int(src a) { return __bfloat16_as_ushort(a); } + inline static __device__ target* to_intp(src* a) { return (target*)a; } + inline static __device__ src from_int(target a) { return __ushort_as_bfloat16(a); } +}; +#endif + +template <> struct int_mapper { + typedef double src; + typedef long long target; + inline static __device__ target to_int(src a) { return __double_as_longlong(a); } + inline static __device__ target* to_intp(src* a) { return (target*)a; } + inline static __device__ src from_int(target a) { return __longlong_as_double(a); } +}; + +template __device__ +T cuda_atomic_mul(T* a, T b) { + auto old_f = *a; + auto old = int_mapper::to_int(old_f); + auto a_i = int_mapper::to_intp(a); + while (1) { + auto assume = old; + old = atomicCAS(a_i, assume, int_mapper::to_int(old_f*b)); + old_f = int_mapper::from_int(old); + if (assume==old) break; + } + return old_f; +} + +#if CUDA_ARCH >= 800 +template<> __device__ +__half cuda_atomic_max(__half* a, __half b) { + auto old_f = *a; + auto old = int_mapper<__half>::to_int(old_f); + auto a_i = int_mapper<__half>::to_intp(a); + while (1) { + auto assume = old; + if (old_f>=b) break; + old = atomicCAS(a_i, assume, int_mapper<__half>::to_int(b)); + old_f = int_mapper<__half>::from_int(old); + if (assume==old) break; + } + return old_f; +} + +template<> __device__ +__half cuda_atomic_min(__half* a, __half b) { + auto old_f = *a; + auto old = int_mapper<__half>::to_int(old_f); + auto a_i = int_mapper<__half>::to_intp(a); + while (1) { + auto assume = old; + if (old_f<=b) break; + old = atomicCAS(a_i, assume, int_mapper<__half>::to_int(b)); + old_f = int_mapper<__half>::from_int(old); + if (assume==old) break; + } + return old_f; +} +#endif +#if CUDA_ARCH >= 800 +template<> __device__ +__nv_bfloat16 cuda_atomic_max(__nv_bfloat16* a, __nv_bfloat16 b) { + auto old_f = *a; + auto old = int_mapper<__nv_bfloat16>::to_int(old_f); + auto a_i = int_mapper<__nv_bfloat16>::to_intp(a); + while (1) { + auto assume = old; + if (old_f>=b) break; + old = atomicCAS(a_i, assume, int_mapper<__nv_bfloat16>::to_int(b)); + old_f = int_mapper<__nv_bfloat16>::from_int(old); + if (assume==old) break; + } + return old_f; +} + +template<> __device__ +__nv_bfloat16 cuda_atomic_min(__nv_bfloat16* a, __nv_bfloat16 b) { + auto old_f = *a; + auto old = int_mapper<__nv_bfloat16>::to_int(old_f); + auto a_i = int_mapper<__nv_bfloat16>::to_intp(a); + while (1) { + auto assume = old; + if (old_f<=b) break; + old = atomicCAS(a_i, assume, int_mapper<__nv_bfloat16>::to_int(b)); + old_f = int_mapper<__nv_bfloat16>::from_int(old); + if (assume==old) break; + } + return old_f; +} +#endif + +template +__device__ inline T shared_reduce_add(T a, T b) { + return a + b; +} + +template +__device__ inline T shared_reduce_mul(T a, T b) { + return a * b; +} + +template +__device__ inline T shared_reduce_max(T a, T b) { + return a > b ? a : b; +} + +template +__device__ inline T shared_reduce_min(T a, T b) { + return a < b ? a : b; +} + +template +__device__ inline T shared_reduce_and(T a, T b) { + return a & b; +} + +template +__device__ inline T shared_reduce_or(T a, T b) { + return a | b; +} + +template +__device__ inline T shared_reduce_xor(T a, T b) { + return a ^ b; +} + + +template +__device__ inline void warpReduce(volatile T* sdata, int tid) { + if (blockDim.x >= 64) + sdata[tid] = op(sdata[tid], sdata[tid + 32]); + sdata[tid] = op(sdata[tid], sdata[tid + 16]); + sdata[tid] = op(sdata[tid], sdata[tid + 8]); + sdata[tid] = op(sdata[tid], sdata[tid + 4]); + sdata[tid] = op(sdata[tid], sdata[tid + 2]); + sdata[tid] = op(sdata[tid], sdata[tid + 1]); +} + +template +__device__ inline static T shared_reduce(T u) { + __shared__ T sdata[1024]; + + int tid = threadIdx.x; + + sdata[tid] = u; + __syncthreads(); + + if (blockDim.x >= 1024 && tid < 512) { + sdata[tid] = u = op(u, sdata[tid + 512]); + } + __syncthreads(); + + if (blockDim.x >= 512 && tid < 256) { + sdata[tid] = u = op(u, sdata[tid + 256]); + } + __syncthreads(); + + if (blockDim.x >= 256 && tid < 128) { + sdata[tid] = u = op(u, sdata[tid + 128]); + } + __syncthreads(); + + if (blockDim.x >= 128 && tid < 64) { + sdata[tid] = u = op(u, sdata[tid + 64]); + } + __syncthreads(); + + if (tid < 32) + warpReduce(sdata, tid); + + return sdata[0]; +} + +} // jittor diff --git a/python/jittor/src/misc/cuda_flags.cc b/python/jittor/src/misc/cuda_flags.cc new file mode 100644 index 00000000..999385fc --- /dev/null +++ b/python/jittor/src/misc/cuda_flags.cc @@ -0,0 +1,103 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "common.h" +#ifdef HAS_CUDA +#include +#ifdef __linux__ +#include +#include +#endif +#endif + +namespace jittor { + +DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0, + "Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda."); +DEFINE_FLAG_WITH_SETTER(int, device_id, -1, + "number of the device to used"); + +EXTERN_LIB void sync_all(bool device_sync); + +#ifdef HAS_CUDA +int get_device_count() { + static int count=-1; + if (count==-1) + cudaGetDeviceCount(&count); + return count; +} +#endif + + +void setter_use_cuda(int value) { + if (use_cuda == value) return; +#ifdef HAS_CUDA + if (value) { + int count=0; + cudaGetDeviceCount(&count); + if (count == 0) { + if (getenv("CUDA_VISIBLE_DEVICES")) { + LOGf << "No device found, please unset your " + "enviroment variable 'CUDA_VISIBLE_DEVICES'"; + } else + LOGf << "No device found"; + } + LOGi << "CUDA enabled."; + } else { + LOGv << "CUDA disabled."; + } +#else + CHECK(value==0) << "No CUDA found."; +#endif + if (use_cuda != value) + sync_all(0); + // jtorch will call this directly + use_cuda = value; +} + +void setter_device_id(int value) { +#if defined(HAS_CUDA) && defined(__linux__) + // case1: set env device_id, not restart + // case2: set in python, restart + // case3: restart, device id and CUDA env set both + if (value<0) + return; + int count=0; + cudaGetDeviceCount(&count); + auto s = getenv("CUDA_VISIBLE_DEVICES"); + auto s2 = getenv("device_id"); + auto sv = std::to_string(value); + if (s2 && s2 == sv && (!s || count!=1)) { + // only handle case1 and case3(not cuda) + LOGi << "change to device #" >> value; + cudaSetDevice(value); + return; + } + if (s && s == sv) + return; + setenv("CUDA_VISIBLE_DEVICES", sv.c_str(), 1); + setenv("device_id", sv.c_str(), 1); + std::ifstream ifs("/proc/self/cmdline"); + if (!(ifs && ifs.good())) return; + string cmd((std::istreambuf_iterator(ifs)), + (std::istreambuf_iterator())); + vector ss; + auto cstr = (char*)cmd.c_str(); + ss.push_back(cstr); + for (int i=0; i> value; + execvp(ss[0], &ss[0]); + ss.pop_back(); + LOGe << "restart failed" << ss; +#endif +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/misc/cuda_flags.h b/python/jittor/src/misc/cuda_flags.h new file mode 100644 index 00000000..4e06897d --- /dev/null +++ b/python/jittor/src/misc/cuda_flags.h @@ -0,0 +1,42 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + + +#ifdef HAS_CUDA +#include + +namespace jittor { + +DECLARE_FLAG(int, use_cuda); + +// @pyjt(get_device_count) +int get_device_count(); + +} // jittor + +#if defined(CUDART_VERSION) && CUDART_VERSION < 10000 + #define _cudaLaunchHostFunc(a,b,c) \ + cudaStreamAddCallback(a,b,c,0) + #define CUDA_HOST_FUNC_ARGS cudaStream_t stream, cudaError_t status, void* +#else + #define _cudaLaunchHostFunc(a,b,c) \ + cudaLaunchHostFunc(a,b,c) + #define CUDA_HOST_FUNC_ARGS void* +#endif + +#else + +namespace jittor { + +constexpr int use_cuda = 0; + +inline int get_device_count() { return 0; } + +} // jittor +#endif diff --git a/python/jittor/src/misc/cuda_limits.h b/python/jittor/src/misc/cuda_limits.h new file mode 100644 index 00000000..1e6c5849 --- /dev/null +++ b/python/jittor/src/misc/cuda_limits.h @@ -0,0 +1,47 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once + +#ifdef IS_CUDA +#include +#include +#else +#include +#define NPP_MIN_32U ( 0 ) +#define NPP_MAX_32U ( 4294967295U ) +#define NPP_MIN_32S (-2147483647 - 1 ) +#define NPP_MAX_32S ( 2147483647 ) +#define NPP_MIN_64U ( 0 ) +#define NPP_MAX_64U ( 18446744073709551615ULL ) +#define NPP_MIN_64S (-9223372036854775807LL - 1) +#define NPP_MAX_64S ( 9223372036854775807LL ) +#define CUDART_INF_F std::numeric_limits::infinity() +#define CUDART_INF std::numeric_limits::infinity() +#endif + + +template __device__ T numeric_min(); +template __device__ T numeric_max(); + +template<> __device__ __inline__ int numeric_max() { return NPP_MAX_32S; }; +template<> __device__ __inline__ int numeric_min() { return NPP_MIN_32S; }; + +template<> __device__ __inline__ unsigned int numeric_max() { return NPP_MAX_32U; }; +template<> __device__ __inline__ unsigned int numeric_min() { return NPP_MIN_32U; }; + +template<> __device__ __inline__ long long numeric_max() { return NPP_MAX_64S; }; +template<> __device__ __inline__ long long numeric_min() { return NPP_MIN_64S; }; + +template<> __device__ __inline__ unsigned long long numeric_max() { return NPP_MAX_64U; }; +template<> __device__ __inline__ unsigned long long numeric_min() { return NPP_MIN_64U; }; + + +template<> __device__ __inline__ float numeric_max() { return CUDART_INF_F; }; +template<> __device__ __inline__ float numeric_min() { return -CUDART_INF_F; }; + +template<> __device__ __inline__ double numeric_max() { return CUDART_INF; }; +template<> __device__ __inline__ double numeric_min() { return -CUDART_INF; }; \ No newline at end of file diff --git a/python/jittor/src/misc/deleter.h b/python/jittor/src/misc/deleter.h new file mode 100644 index 00000000..14e84b97 --- /dev/null +++ b/python/jittor/src/misc/deleter.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include + +namespace jittor { + +struct Deleter { + std::function del; + inline Deleter(std::function&& func) : del(move(func)) {} + inline Deleter() {} + inline ~Deleter() { if (del) del(); } +}; + +} // jittor diff --git a/python/jittor/src/misc/fast_shared_ptr.h b/python/jittor/src/misc/fast_shared_ptr.h new file mode 100644 index 00000000..b075ba9d --- /dev/null +++ b/python/jittor/src/misc/fast_shared_ptr.h @@ -0,0 +1,95 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +template +struct fast_shared_ptr { + typedef T value_type; + pair* ptr; + inline fast_shared_ptr() { ptr = nullptr; } + inline fast_shared_ptr(std::nullptr_t) { ptr = nullptr; } + + inline fast_shared_ptr(T&& a) { + if (a.size()) { + ptr = new pair(1, move(a)); + } else { + ptr = nullptr; + } + } + + inline fast_shared_ptr(const fast_shared_ptr& other) { + ptr = other.ptr; + if (ptr) ptr->first++; + } + + inline ~fast_shared_ptr() { + if (ptr) { + ptr->first--; + if (!ptr->first) + delete ptr; + } + } + + inline void clear() { + this->~fast_shared_ptr(); + ptr = nullptr; + } + + inline fast_shared_ptr& operator=(std::nullptr_t) { + clear(); + return *this; + } + + inline fast_shared_ptr& operator=(T&& a) { + this->~fast_shared_ptr(); + new(this) fast_shared_ptr(move(a)); + return *this; + } + + inline fast_shared_ptr& operator=(const fast_shared_ptr& other) { + this->~fast_shared_ptr(); + new(this) fast_shared_ptr(other); + return *this; + } + + inline operator bool() const { return ptr; } + inline operator T() const { return ptr ? ptr->second : T(); } + inline T& data() const { return ptr->second; } + inline uint64 ref_cnt() const { return ptr ? ptr->first : 0; } +}; + +template +inline std::ostream& operator<<(std::ostream& os, const fast_shared_ptr& p) { + if (p) + return os << p.ptr->second; + return os << "null"; +} + + +template +inline std::istream& operator>>(std::istream& is, fast_shared_ptr& p) { + T a; + is >> a; + p = move(a); + return is; +} + + +template +struct Maybe { + typedef T value_type; + T* ptr; + inline Maybe() { ptr = nullptr; } + inline Maybe(std::nullptr_t) { ptr = nullptr; } + inline Maybe(T* ptr) : ptr(ptr) {} + inline operator bool() const { return ptr; } +}; + +} // jittor diff --git a/python/jittor/src/misc/hash.h b/python/jittor/src/misc/hash.h new file mode 100644 index 00000000..06b27aad --- /dev/null +++ b/python/jittor/src/misc/hash.h @@ -0,0 +1,40 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +uint constexpr const_hash(const char *input) { + return *input ? + static_cast(*input) + 55 * const_hash(input + 1) : + 0; +} + +/* simple hash function */ +// @pyjt(hash) +inline uint hash(const char* input) { + uint v=0, mul=1; + while (*input) { + v += mul * (uint)*input; + mul *= 55; + input++; + } + return v; +} + + +inline uint64 hash64(const string& input) { + uint64 v=0, mul=1; + for (char c : input) { + v += mul * (uint64)c; + mul *= 257; + } + return v; +} + +} // jittor diff --git a/python/jittor/src/misc/intrin.h b/python/jittor/src/misc/intrin.h new file mode 100644 index 00000000..568ca750 --- /dev/null +++ b/python/jittor/src/misc/intrin.h @@ -0,0 +1,30 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +static inline int lzcnt(int64 v) { + #ifdef __clang__ + #if __has_feature(__builtin_ia32_lzcnt_u64) + return __builtin_ia32_lzcnt_u64(v); + #else + return v ? __builtin_clzll(v) : 64; + #endif + #else + #ifdef _MSC_VER + unsigned long index; + _BitScanReverse64(&index, v); + return v ? 63-index : 64; + #else + return __builtin_clzll(v); + #endif + #endif +} + +} \ No newline at end of file diff --git a/python/jittor/src/misc/miniz.cc b/python/jittor/src/misc/miniz.cc new file mode 100755 index 00000000..8c98d1b8 --- /dev/null +++ b/python/jittor/src/misc/miniz.cc @@ -0,0 +1,7801 @@ +/************************************************************************** + * + * Copyright 2013-2014 RAD Game Tools and Valve Software + * Copyright 2010-2014 Rich Geldreich and Tenacious Software LLC + * All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + **************************************************************************/ +#include +#include "misc/miniz.h" + +namespace jittor { + +typedef unsigned char mz_validate_uint16[sizeof(mz_uint16) == 2 ? 1 : -1]; +typedef unsigned char mz_validate_uint32[sizeof(mz_uint32) == 4 ? 1 : -1]; +typedef unsigned char mz_validate_uint64[sizeof(mz_uint64) == 8 ? 1 : -1]; + +/* ------------------- zlib-style API's */ + +mz_ulong mz_adler32(mz_ulong adler, const unsigned char *ptr, size_t buf_len) +{ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-init-variables) + mz_uint32 i, s1 = (mz_uint32)(adler & 0xffff), s2 = (mz_uint32)(adler >> 16); + size_t block_len = buf_len % 5552; + if (!ptr) + return MZ_ADLER32_INIT; + while (buf_len) + { + for (i = 0; i + 7 < block_len; i += 8, ptr += 8) + { + s1 += ptr[0], s2 += s1; + s1 += ptr[1], s2 += s1; + s1 += ptr[2], s2 += s1; + s1 += ptr[3], s2 += s1; + s1 += ptr[4], s2 += s1; + s1 += ptr[5], s2 += s1; + s1 += ptr[6], s2 += s1; + s1 += ptr[7], s2 += s1; + } + for (; i < block_len; ++i) + s1 += *ptr++, s2 += s1; + s1 %= 65521U, s2 %= 65521U; + buf_len -= block_len; + block_len = 5552; + } + return (s2 << 16) + s1; +} + +/* Karl Malbrain's compact CRC-32. See "A compact CCITT crc16 and crc32 C implementation that balances processor cache usage against speed": http://www.geocities.com/malbrain/ */ +#if 0 + mz_ulong mz_crc32(mz_ulong crc, const mz_uint8 *ptr, size_t buf_len) + { + static const mz_uint32 s_crc32[16] = { 0, 0x1db71064, 0x3b6e20c8, 0x26d930ac, 0x76dc4190, 0x6b6b51f4, 0x4db26158, 0x5005713c, + 0xedb88320, 0xf00f9344, 0xd6d6a3e8, 0xcb61b38c, 0x9b64c2b0, 0x86d3d2d4, 0xa00ae278, 0xbdbdf21c }; + mz_uint32 crcu32 = (mz_uint32)crc; + if (!ptr) + return MZ_CRC32_INIT; + crcu32 = ~crcu32; + while (buf_len--) + { + mz_uint8 b = *ptr++; + crcu32 = (crcu32 >> 4) ^ s_crc32[(crcu32 & 0xF) ^ (b & 0xF)]; + crcu32 = (crcu32 >> 4) ^ s_crc32[(crcu32 & 0xF) ^ (b >> 4)]; + } + return ~crcu32; + } +#elif defined(USE_EXTERNAL_MZCRC) +/* If USE_EXTERNAL_CRC is defined, an external module will export the + * mz_crc32() symbol for us to use, e.g. an SSE-accelerated version. + * Depending on the impl, it may be necessary to ~ the input/output crc values. + */ +mz_ulong mz_crc32(mz_ulong crc, const mz_uint8 *ptr, size_t buf_len); +#else +/* Faster, but larger CPU cache footprint. + */ +mz_ulong mz_crc32(mz_ulong crc, const mz_uint8 *ptr, size_t buf_len) +{ + static const mz_uint32 s_crc_table[256] = + { + 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, + 0x9E6495A3, 0x0EDB8832, 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, + 0xE7B82D07, 0x90BF1D91, 0x1DB71064, 0x6AB020F2, 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, + 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, 0x136C9856, 0x646BA8C0, 0xFD62F97A, 0x8A65C9EC, + 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5, 0x3B6E20C8, 0x4C69105E, 0xD56041E4, + 0xA2677172, 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, 0x35B5A8FA, 0x42B2986C, + 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, 0x26D930AC, + 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, 0xCFBA9599, 0xB8BDA50F, + 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, + 0xB6662D3D, 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, + 0x9FBFE4A5, 0xE8B8D433, 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, + 0x086D3D2D, 0x91646C97, 0xE6635C01, 0x6B6B51F4, 0x1C6C6162, 0x856530D8, 0xF262004E, + 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, 0x65B0D9C6, 0x12B7E950, 0x8BBEB8EA, + 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, 0x4DB26158, 0x3AB551CE, + 0xA3BC0074, 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB, 0x4369E96A, + 0x346ED9FC, 0xAD678846, 0xDA60B8D0, 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, + 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, 0x206F85B3, 0xB966D409, + 0xCE61E49F, 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, + 0xB7BD5C3B, 0xC0BA6CAD, 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, + 0x9DD277AF, 0x04DB2615, 0x73DC1683, 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, + 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, 0xF00F9344, 0x8708A3D2, 0x1E01F268, + 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7, 0xFED41B76, 0x89D32BE0, + 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, 0xD6D6A3E8, + 0xA1D1937E, 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, + 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, + 0x4669BE79, 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, 0xCC0C7795, 0xBB0B4703, + 0x220216B9, 0x5505262F, 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, + 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, + 0x9C0906A9, 0xEB0E363F, 0x72076785, 0x05005713, 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, + 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, 0x0BDBDF21, 0x86D3D2D4, 0xF1D4E242, + 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, 0x88085AE6, + 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45, + 0xA00AE278, 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, + 0x3E6E77DB, 0xAED16A4A, 0xD9D65ADC, 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, + 0x47B2CF7F, 0x30B5FFE9, 0xBDBDF21C, 0xCABAC28A, 0x53B39330, 0x24B4A3A6, 0xBAD03605, + 0xCDD70693, 0x54DE5729, 0x23D967BF, 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, + 0xB40BBE37, 0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D + }; + + mz_uint32 crc32 = (mz_uint32)crc ^ 0xFFFFFFFF; + const mz_uint8 *pByte_buf = (const mz_uint8 *)ptr; + + while (buf_len >= 4) + { + crc32 = (crc32 >> 8) ^ s_crc_table[(crc32 ^ pByte_buf[0]) & 0xFF]; + crc32 = (crc32 >> 8) ^ s_crc_table[(crc32 ^ pByte_buf[1]) & 0xFF]; + crc32 = (crc32 >> 8) ^ s_crc_table[(crc32 ^ pByte_buf[2]) & 0xFF]; + crc32 = (crc32 >> 8) ^ s_crc_table[(crc32 ^ pByte_buf[3]) & 0xFF]; + pByte_buf += 4; + buf_len -= 4; + } + + while (buf_len) + { + crc32 = (crc32 >> 8) ^ s_crc_table[(crc32 ^ pByte_buf[0]) & 0xFF]; + ++pByte_buf; + --buf_len; + } + + return ~crc32; +} +#endif + +void mz_free(void *p) +{ + MZ_FREE(p); +} + +void *miniz_def_alloc_func(void *opaque, size_t items, size_t size) +{ + (void)opaque, (void)items, (void)size; + return MZ_MALLOC(items * size); +} +void miniz_def_free_func(void *opaque, void *address) +{ + (void)opaque, (void)address; + MZ_FREE(address); +} +void *miniz_def_realloc_func(void *opaque, void *address, size_t items, size_t size) +{ + (void)opaque, (void)address, (void)items, (void)size; + return MZ_REALLOC(address, items * size); +} + +const char *mz_version(void) +{ + return MZ_VERSION; +} + +#ifndef MINIZ_NO_ZLIB_APIS + +int mz_deflateInit(mz_streamp pStream, int level) +{ + return mz_deflateInit2(pStream, level, MZ_DEFLATED, MZ_DEFAULT_WINDOW_BITS, 9, MZ_DEFAULT_STRATEGY); +} + +int mz_deflateInit2(mz_streamp pStream, int level, int method, int window_bits, int mem_level, int strategy) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tdefl_compressor *pComp; + mz_uint comp_flags = TDEFL_COMPUTE_ADLER32 | tdefl_create_comp_flags_from_zip_params(level, window_bits, strategy); + + if (!pStream) + return MZ_STREAM_ERROR; + if ((method != MZ_DEFLATED) || ((mem_level < 1) || (mem_level > 9)) || ((window_bits != MZ_DEFAULT_WINDOW_BITS) && (-window_bits != MZ_DEFAULT_WINDOW_BITS))) + return MZ_PARAM_ERROR; + + pStream->data_type = 0; + pStream->adler = MZ_ADLER32_INIT; + pStream->msg = NULL; + pStream->reserved = 0; + pStream->total_in = 0; + pStream->total_out = 0; + if (!pStream->zalloc) + pStream->zalloc = miniz_def_alloc_func; + if (!pStream->zfree) + pStream->zfree = miniz_def_free_func; + + pComp = (tdefl_compressor *)pStream->zalloc(pStream->opaque, 1, sizeof(tdefl_compressor)); + if (!pComp) + return MZ_MEM_ERROR; + + pStream->state = (struct mz_internal_state *)pComp; + + if (tdefl_init(pComp, NULL, NULL, comp_flags) != TDEFL_STATUS_OKAY) + { + mz_deflateEnd(pStream); + return MZ_PARAM_ERROR; + } + + return MZ_OK; +} + +int mz_deflateReset(mz_streamp pStream) +{ + if ((!pStream) || (!pStream->state) || (!pStream->zalloc) || (!pStream->zfree)) + return MZ_STREAM_ERROR; + pStream->total_in = pStream->total_out = 0; + tdefl_init((tdefl_compressor *)pStream->state, NULL, NULL, ((tdefl_compressor *)pStream->state)->m_flags); + return MZ_OK; +} + +int mz_deflate(mz_streamp pStream, int flush) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t in_bytes, out_bytes; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_ulong orig_total_in, orig_total_out; + int mz_status = MZ_OK; + + if ((!pStream) || (!pStream->state) || (flush < 0) || (flush > MZ_FINISH) || (!pStream->next_out)) + return MZ_STREAM_ERROR; + if (!pStream->avail_out) + return MZ_BUF_ERROR; + + if (flush == MZ_PARTIAL_FLUSH) + flush = MZ_SYNC_FLUSH; + + if (((tdefl_compressor *)pStream->state)->m_prev_return_status == TDEFL_STATUS_DONE) + return (flush == MZ_FINISH) ? MZ_STREAM_END : MZ_BUF_ERROR; + + orig_total_in = pStream->total_in; + orig_total_out = pStream->total_out; + for (;;) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tdefl_status defl_status; + in_bytes = pStream->avail_in; + out_bytes = pStream->avail_out; + + defl_status = tdefl_compress((tdefl_compressor *)pStream->state, pStream->next_in, &in_bytes, pStream->next_out, &out_bytes, (tdefl_flush)flush); + pStream->next_in += (mz_uint)in_bytes; + pStream->avail_in -= (mz_uint)in_bytes; + pStream->total_in += (mz_uint)in_bytes; + pStream->adler = tdefl_get_adler32((tdefl_compressor *)pStream->state); + + pStream->next_out += (mz_uint)out_bytes; + pStream->avail_out -= (mz_uint)out_bytes; + pStream->total_out += (mz_uint)out_bytes; + + if (defl_status < 0) + { + mz_status = MZ_STREAM_ERROR; + break; + } + else if (defl_status == TDEFL_STATUS_DONE) + { + mz_status = MZ_STREAM_END; + break; + } + else if (!pStream->avail_out) + break; + else if ((!pStream->avail_in) && (flush != MZ_FINISH)) + { + if ((flush) || (pStream->total_in != orig_total_in) || (pStream->total_out != orig_total_out)) + break; + return MZ_BUF_ERROR; /* Can't make forward progress without some input. + */ + } + } + return mz_status; +} + +int mz_deflateEnd(mz_streamp pStream) +{ + if (!pStream) + return MZ_STREAM_ERROR; + if (pStream->state) + { + pStream->zfree(pStream->opaque, pStream->state); + pStream->state = NULL; + } + return MZ_OK; +} + +mz_ulong mz_deflateBound(mz_streamp pStream, mz_ulong source_len) +{ + (void)pStream; + /* This is really over conservative. (And lame, but it's actually pretty tricky to compute a true upper bound given the way tdefl's blocking works.) */ + return MZ_MAX(128 + (source_len * 110) / 100, 128 + source_len + ((source_len / (31 * 1024)) + 1) * 5); +} + +int mz_compress2(unsigned char *pDest, mz_ulong *pDest_len, const unsigned char *pSource, mz_ulong source_len, int level) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int status; + mz_stream stream; + memset(&stream, 0, sizeof(stream)); + + /* In case mz_ulong is 64-bits (argh I hate longs). */ + if ((source_len | *pDest_len) > 0xFFFFFFFFU) + return MZ_PARAM_ERROR; + + stream.next_in = pSource; + stream.avail_in = (mz_uint32)source_len; + stream.next_out = pDest; + stream.avail_out = (mz_uint32)*pDest_len; + + status = mz_deflateInit(&stream, level); + if (status != MZ_OK) + return status; + + status = mz_deflate(&stream, MZ_FINISH); + if (status != MZ_STREAM_END) + { + mz_deflateEnd(&stream); + return (status == MZ_OK) ? MZ_BUF_ERROR : status; + } + + *pDest_len = stream.total_out; + return mz_deflateEnd(&stream); +} + +int mz_compress(unsigned char *pDest, mz_ulong *pDest_len, const unsigned char *pSource, mz_ulong source_len) +{ + return mz_compress2(pDest, pDest_len, pSource, source_len, MZ_DEFAULT_COMPRESSION); +} + +mz_ulong mz_compressBound(mz_ulong source_len) +{ + return mz_deflateBound(NULL, source_len); +} + +typedef struct +{ + tinfl_decompressor m_decomp; + mz_uint m_dict_ofs, m_dict_avail, m_first_call, m_has_flushed; + int m_window_bits; + mz_uint8 m_dict[TINFL_LZ_DICT_SIZE]; + tinfl_status m_last_status; +} inflate_state; + +int mz_inflateInit2(mz_streamp pStream, int window_bits) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + inflate_state *pDecomp; + if (!pStream) + return MZ_STREAM_ERROR; + if ((window_bits != MZ_DEFAULT_WINDOW_BITS) && (-window_bits != MZ_DEFAULT_WINDOW_BITS)) + return MZ_PARAM_ERROR; + + pStream->data_type = 0; + pStream->adler = 0; + pStream->msg = NULL; + pStream->total_in = 0; + pStream->total_out = 0; + pStream->reserved = 0; + if (!pStream->zalloc) + pStream->zalloc = miniz_def_alloc_func; + if (!pStream->zfree) + pStream->zfree = miniz_def_free_func; + + pDecomp = (inflate_state *)pStream->zalloc(pStream->opaque, 1, sizeof(inflate_state)); + if (!pDecomp) + return MZ_MEM_ERROR; + + pStream->state = (struct mz_internal_state *)pDecomp; + + tinfl_init(&pDecomp->m_decomp); + pDecomp->m_dict_ofs = 0; + pDecomp->m_dict_avail = 0; + pDecomp->m_last_status = TINFL_STATUS_NEEDS_MORE_INPUT; + pDecomp->m_first_call = 1; + pDecomp->m_has_flushed = 0; + pDecomp->m_window_bits = window_bits; + + return MZ_OK; +} + +int mz_inflateInit(mz_streamp pStream) +{ + return mz_inflateInit2(pStream, MZ_DEFAULT_WINDOW_BITS); +} + +int mz_inflateReset(mz_streamp pStream) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + inflate_state *pDecomp; + if (!pStream) + return MZ_STREAM_ERROR; + + pStream->data_type = 0; + pStream->adler = 0; + pStream->msg = NULL; + pStream->total_in = 0; + pStream->total_out = 0; + pStream->reserved = 0; + + pDecomp = (inflate_state *)pStream->state; + + tinfl_init(&pDecomp->m_decomp); + pDecomp->m_dict_ofs = 0; + pDecomp->m_dict_avail = 0; + pDecomp->m_last_status = TINFL_STATUS_NEEDS_MORE_INPUT; + pDecomp->m_first_call = 1; + pDecomp->m_has_flushed = 0; + /* pDecomp->m_window_bits = window_bits */; + + return MZ_OK; +} + +int mz_inflate(mz_streamp pStream, int flush) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + inflate_state *pState; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint n, first_call, decomp_flags = TINFL_FLAG_COMPUTE_ADLER32; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t in_bytes, out_bytes, orig_avail_in; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tinfl_status status; + + if ((!pStream) || (!pStream->state)) + return MZ_STREAM_ERROR; + if (flush == MZ_PARTIAL_FLUSH) + flush = MZ_SYNC_FLUSH; + if ((flush) && (flush != MZ_SYNC_FLUSH) && (flush != MZ_FINISH)) + return MZ_STREAM_ERROR; + + pState = (inflate_state *)pStream->state; + if (pState->m_window_bits > 0) + decomp_flags |= TINFL_FLAG_PARSE_ZLIB_HEADER; + orig_avail_in = pStream->avail_in; + + first_call = pState->m_first_call; + pState->m_first_call = 0; + if (pState->m_last_status < 0) + return MZ_DATA_ERROR; + + if (pState->m_has_flushed && (flush != MZ_FINISH)) + return MZ_STREAM_ERROR; + pState->m_has_flushed |= (flush == MZ_FINISH); + + if ((flush == MZ_FINISH) && (first_call)) + { + /* MZ_FINISH on the first call implies that the input and output buffers are large enough to hold the entire compressed/decompressed file. */ + decomp_flags |= TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF; + in_bytes = pStream->avail_in; + out_bytes = pStream->avail_out; + status = tinfl_decompress(&pState->m_decomp, pStream->next_in, &in_bytes, pStream->next_out, pStream->next_out, &out_bytes, decomp_flags); + pState->m_last_status = status; + pStream->next_in += (mz_uint)in_bytes; + pStream->avail_in -= (mz_uint)in_bytes; + pStream->total_in += (mz_uint)in_bytes; + pStream->adler = tinfl_get_adler32(&pState->m_decomp); + pStream->next_out += (mz_uint)out_bytes; + pStream->avail_out -= (mz_uint)out_bytes; + pStream->total_out += (mz_uint)out_bytes; + + if (status < 0) + return MZ_DATA_ERROR; + else if (status != TINFL_STATUS_DONE) + { + pState->m_last_status = TINFL_STATUS_FAILED; + return MZ_BUF_ERROR; + } + return MZ_STREAM_END; + } + /* flush != MZ_FINISH then we must assume there's more input. */ + if (flush != MZ_FINISH) + decomp_flags |= TINFL_FLAG_HAS_MORE_INPUT; + + if (pState->m_dict_avail) + { + n = MZ_MIN(pState->m_dict_avail, pStream->avail_out); + memcpy(pStream->next_out, pState->m_dict + pState->m_dict_ofs, n); + pStream->next_out += n; + pStream->avail_out -= n; + pStream->total_out += n; + pState->m_dict_avail -= n; + pState->m_dict_ofs = (pState->m_dict_ofs + n) & (TINFL_LZ_DICT_SIZE - 1); + return ((pState->m_last_status == TINFL_STATUS_DONE) && (!pState->m_dict_avail)) ? MZ_STREAM_END : MZ_OK; + } + + for (;;) + { + in_bytes = pStream->avail_in; + out_bytes = TINFL_LZ_DICT_SIZE - pState->m_dict_ofs; + + status = tinfl_decompress(&pState->m_decomp, pStream->next_in, &in_bytes, pState->m_dict, pState->m_dict + pState->m_dict_ofs, &out_bytes, decomp_flags); + pState->m_last_status = status; + + pStream->next_in += (mz_uint)in_bytes; + pStream->avail_in -= (mz_uint)in_bytes; + pStream->total_in += (mz_uint)in_bytes; + pStream->adler = tinfl_get_adler32(&pState->m_decomp); + + pState->m_dict_avail = (mz_uint)out_bytes; + + n = MZ_MIN(pState->m_dict_avail, pStream->avail_out); + memcpy(pStream->next_out, pState->m_dict + pState->m_dict_ofs, n); + pStream->next_out += n; + pStream->avail_out -= n; + pStream->total_out += n; + pState->m_dict_avail -= n; + pState->m_dict_ofs = (pState->m_dict_ofs + n) & (TINFL_LZ_DICT_SIZE - 1); + + if (status < 0) + return MZ_DATA_ERROR; /* Stream is corrupted (there could be some uncompressed data left in the output dictionary - oh well). */ + else if ((status == TINFL_STATUS_NEEDS_MORE_INPUT) && (!orig_avail_in)) + return MZ_BUF_ERROR; /* Signal caller that we can't make forward progress without supplying more input or by setting flush to MZ_FINISH. */ + else if (flush == MZ_FINISH) + { + /* The output buffer MUST be large to hold the remaining uncompressed data when flush==MZ_FINISH. */ + if (status == TINFL_STATUS_DONE) + return pState->m_dict_avail ? MZ_BUF_ERROR : MZ_STREAM_END; + /* status here must be TINFL_STATUS_HAS_MORE_OUTPUT, which means there's at least 1 more byte on the way. If there's no more room left in the output buffer then something is wrong. */ + else if (!pStream->avail_out) + return MZ_BUF_ERROR; + } + else if ((status == TINFL_STATUS_DONE) || (!pStream->avail_in) || (!pStream->avail_out) || (pState->m_dict_avail)) + break; + } + + return ((status == TINFL_STATUS_DONE) && (!pState->m_dict_avail)) ? MZ_STREAM_END : MZ_OK; +} + +int mz_inflateEnd(mz_streamp pStream) +{ + if (!pStream) + return MZ_STREAM_ERROR; + if (pStream->state) + { + pStream->zfree(pStream->opaque, pStream->state); + pStream->state = NULL; + } + return MZ_OK; +} + +int mz_uncompress(unsigned char *pDest, mz_ulong *pDest_len, const unsigned char *pSource, mz_ulong source_len) +{ + mz_stream stream; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int status; + memset(&stream, 0, sizeof(stream)); + + /* In case mz_ulong is 64-bits (argh I hate longs). */ + if ((source_len | *pDest_len) > 0xFFFFFFFFU) + return MZ_PARAM_ERROR; + + stream.next_in = pSource; + stream.avail_in = (mz_uint32)source_len; + stream.next_out = pDest; + stream.avail_out = (mz_uint32)*pDest_len; + + status = mz_inflateInit(&stream); + if (status != MZ_OK) + return status; + + status = mz_inflate(&stream, MZ_FINISH); + if (status != MZ_STREAM_END) + { + mz_inflateEnd(&stream); + return ((status == MZ_BUF_ERROR) && (!stream.avail_in)) ? MZ_DATA_ERROR : status; + } + *pDest_len = stream.total_out; + + return mz_inflateEnd(&stream); +} + +const char *mz_error(int err) +{ + static struct + { + int m_err; + const char *m_pDesc; + } s_error_descs[] = + { + { MZ_OK, "" }, { MZ_STREAM_END, "stream end" }, { MZ_NEED_DICT, "need dictionary" }, { MZ_ERRNO, "file error" }, { MZ_STREAM_ERROR, "stream error" }, { MZ_DATA_ERROR, "data error" }, { MZ_MEM_ERROR, "out of memory" }, { MZ_BUF_ERROR, "buf error" }, { MZ_VERSION_ERROR, "version error" }, { MZ_PARAM_ERROR, "parameter error" } + }; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i; + for (i = 0; i < sizeof(s_error_descs) / sizeof(s_error_descs[0]); ++i) + if (s_error_descs[i].m_err == err) + return s_error_descs[i].m_pDesc; + return NULL; +} + +#endif /*MINIZ_NO_ZLIB_APIS */ + + +/* + This is free and unencumbered software released into the public domain. + + Anyone is free to copy, modify, publish, use, compile, sell, or + distribute this software, either in source code form or as a compiled + binary, for any purpose, commercial or non-commercial, and by any + means. + + In jurisdictions that recognize copyright laws, the author or authors + of this software dedicate any and all copyright interest in the + software to the public domain. We make this dedication for the benefit + of the public at large and to the detriment of our heirs and + successors. We intend this dedication to be an overt act of + relinquishment in perpetuity of all present and future rights to this + software under copyright law. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + + For more information, please refer to +*/ +/************************************************************************** + * + * Copyright 2013-2014 RAD Game Tools and Valve Software + * Copyright 2010-2014 Rich Geldreich and Tenacious Software LLC + * All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + **************************************************************************/ + + + +/* ------------------- Low-level Compression (independent from all decompression API's) */ + +/* Purposely making these tables static for faster init and thread safety. */ +static const mz_uint16 s_tdefl_len_sym[256] = + { + 257, 258, 259, 260, 261, 262, 263, 264, 265, 265, 266, 266, 267, 267, 268, 268, 269, 269, 269, 269, 270, 270, 270, 270, 271, 271, 271, 271, 272, 272, 272, 272, + 273, 273, 273, 273, 273, 273, 273, 273, 274, 274, 274, 274, 274, 274, 274, 274, 275, 275, 275, 275, 275, 275, 275, 275, 276, 276, 276, 276, 276, 276, 276, 276, + 277, 277, 277, 277, 277, 277, 277, 277, 277, 277, 277, 277, 277, 277, 277, 277, 278, 278, 278, 278, 278, 278, 278, 278, 278, 278, 278, 278, 278, 278, 278, 278, + 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 279, 280, 280, 280, 280, 280, 280, 280, 280, 280, 280, 280, 280, 280, 280, 280, 280, + 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, 281, + 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, 282, + 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, 283, + 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 284, 285 + }; + +static const mz_uint8 s_tdefl_len_extra[256] = + { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 0 + }; + +static const mz_uint8 s_tdefl_small_dist_sym[512] = + { + 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, + 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17, 17 + }; + +static const mz_uint8 s_tdefl_small_dist_extra[512] = + { + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7 + }; + +static const mz_uint8 s_tdefl_large_dist_sym[128] = + { + 0, 0, 18, 19, 20, 20, 21, 21, 22, 22, 22, 22, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, + 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29 + }; + +static const mz_uint8 s_tdefl_large_dist_extra[128] = + { + 0, 0, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13 + }; + +/* Radix sorts tdefl_sym_freq[] array by 16-bit key m_key. Returns ptr to sorted values. */ +typedef struct +{ + mz_uint16 m_key, m_sym_index; +} tdefl_sym_freq; +static tdefl_sym_freq *tdefl_radix_sort_syms(mz_uint num_syms, tdefl_sym_freq *pSyms0, tdefl_sym_freq *pSyms1) +{ + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-init-variables) + mz_uint32 total_passes = 2, pass_shift, pass, i, hist[256 * 2]; + tdefl_sym_freq *pCur_syms = pSyms0, *pNew_syms = pSyms1; + MZ_CLEAR_OBJ(hist); + for (i = 0; i < num_syms; i++) + { + mz_uint freq = pSyms0[i].m_key; + hist[freq & 0xFF]++; + hist[256 + ((freq >> 8) & 0xFF)]++; + } + while ((total_passes > 1) && (num_syms == hist[(total_passes - 1) * 256])) + total_passes--; + for (pass_shift = 0, pass = 0; pass < total_passes; pass++, pass_shift += 8) + { + const mz_uint32 *pHist = &hist[pass << 8]; + mz_uint offsets[256], cur_ofs = 0; + for (i = 0; i < 256; i++) + { + offsets[i] = cur_ofs; + cur_ofs += pHist[i]; + } + for (i = 0; i < num_syms; i++) + pNew_syms[offsets[(pCur_syms[i].m_key >> pass_shift) & 0xFF]++] = pCur_syms[i]; + { + tdefl_sym_freq *t = pCur_syms; + pCur_syms = pNew_syms; + pNew_syms = t; + } + } + return pCur_syms; +} + +/* tdefl_calculate_minimum_redundancy() originally written by: Alistair Moffat, alistair@cs.mu.oz.au, Jyrki Katajainen, jyrki@diku.dk, November 1996. */ +static void tdefl_calculate_minimum_redundancy(tdefl_sym_freq *A, int n) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int root, leaf, next, avbl, used, dpth; + if (n == 0) + return; + else if (n == 1) + { + A[0].m_key = 1; + return; + } + A[0].m_key += A[1].m_key; + root = 0; + leaf = 2; + for (next = 1; next < n - 1; next++) + { + if (leaf >= n || A[root].m_key < A[leaf].m_key) + { + A[next].m_key = A[root].m_key; + A[root++].m_key = (mz_uint16)next; + } + else + A[next].m_key = A[leaf++].m_key; + if (leaf >= n || (root < next && A[root].m_key < A[leaf].m_key)) + { + A[next].m_key = (mz_uint16)(A[next].m_key + A[root].m_key); + A[root++].m_key = (mz_uint16)next; + } + else + A[next].m_key = (mz_uint16)(A[next].m_key + A[leaf++].m_key); + } + A[n - 2].m_key = 0; + for (next = n - 3; next >= 0; next--) + A[next].m_key = A[A[next].m_key].m_key + 1; + avbl = 1; + used = dpth = 0; + root = n - 2; + next = n - 1; + while (avbl > 0) + { + while (root >= 0 && (int)A[root].m_key == dpth) + { + used++; + root--; + } + while (avbl > used) + { + A[next--].m_key = (mz_uint16)(dpth); + avbl--; + } + avbl = 2 * used; + dpth++; + used = 0; + } +} + +/* Limits canonical Huffman code table's max code size. */ +enum +{ + TDEFL_MAX_SUPPORTED_HUFF_CODESIZE = 32 +}; +static void tdefl_huffman_enforce_max_code_size(int *pNum_codes, int code_list_len, int max_code_size) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int i; + mz_uint32 total = 0; + if (code_list_len <= 1) + return; + for (i = max_code_size + 1; i <= TDEFL_MAX_SUPPORTED_HUFF_CODESIZE; i++) + pNum_codes[max_code_size] += pNum_codes[i]; + for (i = max_code_size; i > 0; i--) + total += (((mz_uint32)pNum_codes[i]) << (max_code_size - i)); + while (total != (1UL << max_code_size)) + { + pNum_codes[max_code_size]--; + for (i = max_code_size - 1; i > 0; i--) + if (pNum_codes[i]) + { + pNum_codes[i]--; + pNum_codes[i + 1] += 2; + break; + } + total--; + } +} + +static void tdefl_optimize_huffman_table(tdefl_compressor *d, int table_num, int table_len, int code_size_limit, int static_table) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int i, j, l, num_codes[1 + TDEFL_MAX_SUPPORTED_HUFF_CODESIZE]; + mz_uint next_code[TDEFL_MAX_SUPPORTED_HUFF_CODESIZE + 1]; + MZ_CLEAR_OBJ(num_codes); + if (static_table) + { + for (i = 0; i < table_len; i++) + num_codes[d->m_huff_code_sizes[table_num][i]]++; + } + else + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tdefl_sym_freq syms0[TDEFL_MAX_HUFF_SYMBOLS], syms1[TDEFL_MAX_HUFF_SYMBOLS], *pSyms; + int num_used_syms = 0; + const mz_uint16 *pSym_count = &d->m_huff_count[table_num][0]; + for (i = 0; i < table_len; i++) + if (pSym_count[i]) + { + syms0[num_used_syms].m_key = (mz_uint16)pSym_count[i]; + syms0[num_used_syms++].m_sym_index = (mz_uint16)i; + } + + pSyms = tdefl_radix_sort_syms(num_used_syms, syms0, syms1); + tdefl_calculate_minimum_redundancy(pSyms, num_used_syms); + + for (i = 0; i < num_used_syms; i++) + num_codes[pSyms[i].m_key]++; + + tdefl_huffman_enforce_max_code_size(num_codes, num_used_syms, code_size_limit); + + MZ_CLEAR_OBJ(d->m_huff_code_sizes[table_num]); + MZ_CLEAR_OBJ(d->m_huff_codes[table_num]); + for (i = 1, j = num_used_syms; i <= code_size_limit; i++) + for (l = num_codes[i]; l > 0; l--) + d->m_huff_code_sizes[table_num][pSyms[--j].m_sym_index] = (mz_uint8)(i); + } + + next_code[1] = 0; + for (j = 0, i = 2; i <= code_size_limit; i++) + next_code[i] = j = ((j + num_codes[i - 1]) << 1); + + for (i = 0; i < table_len; i++) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint rev_code = 0, code, code_size; + if ((code_size = d->m_huff_code_sizes[table_num][i]) == 0) + continue; + code = next_code[code_size]++; + for (l = code_size; l > 0; l--, code >>= 1) + rev_code = (rev_code << 1) | (code & 1); + d->m_huff_codes[table_num][i] = (mz_uint16)rev_code; + } +} + +#define TDEFL_PUT_BITS(b, l) \ + do \ + { \ + mz_uint bits = b; \ + mz_uint len = l; \ + MZ_ASSERT(bits <= ((1U << len) - 1U)); \ + d->m_bit_buffer |= (bits << d->m_bits_in); \ + d->m_bits_in += len; \ + while (d->m_bits_in >= 8) \ + { \ + if (d->m_pOutput_buf < d->m_pOutput_buf_end) \ + *d->m_pOutput_buf++ = (mz_uint8)(d->m_bit_buffer); \ + d->m_bit_buffer >>= 8; \ + d->m_bits_in -= 8; \ + } \ + } \ + MZ_MACRO_END + +#define TDEFL_RLE_PREV_CODE_SIZE() \ + { \ + if (rle_repeat_count) \ + { \ + if (rle_repeat_count < 3) \ + { \ + d->m_huff_count[2][prev_code_size] = (mz_uint16)(d->m_huff_count[2][prev_code_size] + rle_repeat_count); \ + while (rle_repeat_count--) \ + packed_code_sizes[num_packed_code_sizes++] = prev_code_size; \ + } \ + else \ + { \ + d->m_huff_count[2][16] = (mz_uint16)(d->m_huff_count[2][16] + 1); \ + packed_code_sizes[num_packed_code_sizes++] = 16; \ + packed_code_sizes[num_packed_code_sizes++] = (mz_uint8)(rle_repeat_count - 3); \ + } \ + rle_repeat_count = 0; \ + } \ + } + +#define TDEFL_RLE_ZERO_CODE_SIZE() \ + { \ + if (rle_z_count) \ + { \ + if (rle_z_count < 3) \ + { \ + d->m_huff_count[2][0] = (mz_uint16)(d->m_huff_count[2][0] + rle_z_count); \ + while (rle_z_count--) \ + packed_code_sizes[num_packed_code_sizes++] = 0; \ + } \ + else if (rle_z_count <= 10) \ + { \ + d->m_huff_count[2][17] = (mz_uint16)(d->m_huff_count[2][17] + 1); \ + packed_code_sizes[num_packed_code_sizes++] = 17; \ + packed_code_sizes[num_packed_code_sizes++] = (mz_uint8)(rle_z_count - 3); \ + } \ + else \ + { \ + d->m_huff_count[2][18] = (mz_uint16)(d->m_huff_count[2][18] + 1); \ + packed_code_sizes[num_packed_code_sizes++] = 18; \ + packed_code_sizes[num_packed_code_sizes++] = (mz_uint8)(rle_z_count - 11); \ + } \ + rle_z_count = 0; \ + } \ + } + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-magic-numbers) +static mz_uint8 s_tdefl_packed_code_size_syms_swizzle[] = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 }; + +static void tdefl_start_dynamic_block(tdefl_compressor *d) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int num_lit_codes, num_dist_codes, num_bit_lengths; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i, total_code_sizes_to_pack, num_packed_code_sizes, rle_z_count, rle_repeat_count, packed_code_sizes_index; + mz_uint8 code_sizes_to_pack[TDEFL_MAX_HUFF_SYMBOLS_0 + TDEFL_MAX_HUFF_SYMBOLS_1], packed_code_sizes[TDEFL_MAX_HUFF_SYMBOLS_0 + TDEFL_MAX_HUFF_SYMBOLS_1], prev_code_size = 0xFF; + + d->m_huff_count[0][256] = 1; + + tdefl_optimize_huffman_table(d, 0, TDEFL_MAX_HUFF_SYMBOLS_0, 15, MZ_FALSE); + tdefl_optimize_huffman_table(d, 1, TDEFL_MAX_HUFF_SYMBOLS_1, 15, MZ_FALSE); + + for (num_lit_codes = 286; num_lit_codes > 257; num_lit_codes--) + if (d->m_huff_code_sizes[0][num_lit_codes - 1]) + break; + for (num_dist_codes = 30; num_dist_codes > 1; num_dist_codes--) + if (d->m_huff_code_sizes[1][num_dist_codes - 1]) + break; + + memcpy(code_sizes_to_pack, &d->m_huff_code_sizes[0][0], num_lit_codes); + memcpy(code_sizes_to_pack + num_lit_codes, &d->m_huff_code_sizes[1][0], num_dist_codes); + total_code_sizes_to_pack = num_lit_codes + num_dist_codes; + num_packed_code_sizes = 0; + rle_z_count = 0; + rle_repeat_count = 0; + + memset(&d->m_huff_count[2][0], 0, sizeof(d->m_huff_count[2][0]) * TDEFL_MAX_HUFF_SYMBOLS_2); + for (i = 0; i < total_code_sizes_to_pack; i++) + { + mz_uint8 code_size = code_sizes_to_pack[i]; + if (!code_size) + { + TDEFL_RLE_PREV_CODE_SIZE(); + if (++rle_z_count == 138) + { + TDEFL_RLE_ZERO_CODE_SIZE(); + } + } + else + { + TDEFL_RLE_ZERO_CODE_SIZE(); + if (code_size != prev_code_size) + { + TDEFL_RLE_PREV_CODE_SIZE(); + d->m_huff_count[2][code_size] = (mz_uint16)(d->m_huff_count[2][code_size] + 1); + packed_code_sizes[num_packed_code_sizes++] = code_size; + } + else if (++rle_repeat_count == 6) + { + TDEFL_RLE_PREV_CODE_SIZE(); + } + } + prev_code_size = code_size; + } + if (rle_repeat_count) + { + TDEFL_RLE_PREV_CODE_SIZE(); + } + else + { + TDEFL_RLE_ZERO_CODE_SIZE(); + } + + tdefl_optimize_huffman_table(d, 2, TDEFL_MAX_HUFF_SYMBOLS_2, 7, MZ_FALSE); + + TDEFL_PUT_BITS(2, 2); + + TDEFL_PUT_BITS(num_lit_codes - 257, 5); + TDEFL_PUT_BITS(num_dist_codes - 1, 5); + + for (num_bit_lengths = 18; num_bit_lengths >= 0; num_bit_lengths--) + if (d->m_huff_code_sizes[2][s_tdefl_packed_code_size_syms_swizzle[num_bit_lengths]]) + break; + num_bit_lengths = MZ_MAX(4, (num_bit_lengths + 1)); + TDEFL_PUT_BITS(num_bit_lengths - 4, 4); + for (i = 0; (int)i < num_bit_lengths; i++) + TDEFL_PUT_BITS(d->m_huff_code_sizes[2][s_tdefl_packed_code_size_syms_swizzle[i]], 3); + + for (packed_code_sizes_index = 0; packed_code_sizes_index < num_packed_code_sizes;) + { + mz_uint code = packed_code_sizes[packed_code_sizes_index++]; + MZ_ASSERT(code < TDEFL_MAX_HUFF_SYMBOLS_2); + TDEFL_PUT_BITS(d->m_huff_codes[2][code], d->m_huff_code_sizes[2][code]); + if (code >= 16) + // NOLINTNEXTLINE(bugprone-signed-char-misuse) + TDEFL_PUT_BITS(packed_code_sizes[packed_code_sizes_index++], "\02\03\07"[code - 16]); + } +} + +static void tdefl_start_static_block(tdefl_compressor *d) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i; + mz_uint8 *p = &d->m_huff_code_sizes[0][0]; + + for (i = 0; i <= 143; ++i) + *p++ = 8; + for (; i <= 255; ++i) + *p++ = 9; + for (; i <= 279; ++i) + *p++ = 7; + for (; i <= 287; ++i) + *p++ = 8; + + memset(d->m_huff_code_sizes[1], 5, 32); + + tdefl_optimize_huffman_table(d, 0, 288, 15, MZ_TRUE); + tdefl_optimize_huffman_table(d, 1, 32, 15, MZ_TRUE); + + TDEFL_PUT_BITS(1, 2); +} + +static const mz_uint mz_bitmasks[17] = { 0x0000, 0x0001, 0x0003, 0x0007, 0x000F, 0x001F, 0x003F, 0x007F, 0x00FF, 0x01FF, 0x03FF, 0x07FF, 0x0FFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF }; + +#if MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN && MINIZ_HAS_64BIT_REGISTERS +static mz_bool tdefl_compress_lz_codes(tdefl_compressor *d) +{ + mz_uint flags; + mz_uint8 *pLZ_codes; + mz_uint8 *pOutput_buf = d->m_pOutput_buf; + mz_uint8 *pLZ_code_buf_end = d->m_pLZ_code_buf; + mz_uint64 bit_buffer = d->m_bit_buffer; + mz_uint bits_in = d->m_bits_in; + +#define TDEFL_PUT_BITS_FAST(b, l) \ + { \ + bit_buffer |= (((mz_uint64)(b)) << bits_in); \ + bits_in += (l); \ + } + + flags = 1; + for (pLZ_codes = d->m_lz_code_buf; pLZ_codes < pLZ_code_buf_end; flags >>= 1) + { + if (flags == 1) + flags = *pLZ_codes++ | 0x100; + + if (flags & 1) + { + mz_uint s0, s1, n0, n1, sym, num_extra_bits; + mz_uint match_len = pLZ_codes[0], match_dist = *(const mz_uint16 *)(pLZ_codes + 1); + pLZ_codes += 3; + + MZ_ASSERT(d->m_huff_code_sizes[0][s_tdefl_len_sym[match_len]]); + TDEFL_PUT_BITS_FAST(d->m_huff_codes[0][s_tdefl_len_sym[match_len]], d->m_huff_code_sizes[0][s_tdefl_len_sym[match_len]]); + TDEFL_PUT_BITS_FAST(match_len & mz_bitmasks[s_tdefl_len_extra[match_len]], s_tdefl_len_extra[match_len]); + + /* This sequence coaxes MSVC into using cmov's vs. jmp's. */ + s0 = s_tdefl_small_dist_sym[match_dist & 511]; + n0 = s_tdefl_small_dist_extra[match_dist & 511]; + s1 = s_tdefl_large_dist_sym[match_dist >> 8]; + n1 = s_tdefl_large_dist_extra[match_dist >> 8]; + sym = (match_dist < 512) ? s0 : s1; + num_extra_bits = (match_dist < 512) ? n0 : n1; + + MZ_ASSERT(d->m_huff_code_sizes[1][sym]); + TDEFL_PUT_BITS_FAST(d->m_huff_codes[1][sym], d->m_huff_code_sizes[1][sym]); + TDEFL_PUT_BITS_FAST(match_dist & mz_bitmasks[num_extra_bits], num_extra_bits); + } + else + { + mz_uint lit = *pLZ_codes++; + MZ_ASSERT(d->m_huff_code_sizes[0][lit]); + TDEFL_PUT_BITS_FAST(d->m_huff_codes[0][lit], d->m_huff_code_sizes[0][lit]); + + if (((flags & 2) == 0) && (pLZ_codes < pLZ_code_buf_end)) + { + flags >>= 1; + lit = *pLZ_codes++; + MZ_ASSERT(d->m_huff_code_sizes[0][lit]); + TDEFL_PUT_BITS_FAST(d->m_huff_codes[0][lit], d->m_huff_code_sizes[0][lit]); + + if (((flags & 2) == 0) && (pLZ_codes < pLZ_code_buf_end)) + { + flags >>= 1; + lit = *pLZ_codes++; + MZ_ASSERT(d->m_huff_code_sizes[0][lit]); + TDEFL_PUT_BITS_FAST(d->m_huff_codes[0][lit], d->m_huff_code_sizes[0][lit]); + } + } + } + + if (pOutput_buf >= d->m_pOutput_buf_end) + return MZ_FALSE; + + *(mz_uint64 *)pOutput_buf = bit_buffer; + pOutput_buf += (bits_in >> 3); + bit_buffer >>= (bits_in & ~7); + bits_in &= 7; + } + +#undef TDEFL_PUT_BITS_FAST + + d->m_pOutput_buf = pOutput_buf; + d->m_bits_in = 0; + d->m_bit_buffer = 0; + + while (bits_in) + { + mz_uint32 n = MZ_MIN(bits_in, 16); + TDEFL_PUT_BITS((mz_uint)bit_buffer & mz_bitmasks[n], n); + bit_buffer >>= n; + bits_in -= n; + } + + TDEFL_PUT_BITS(d->m_huff_codes[0][256], d->m_huff_code_sizes[0][256]); + + return (d->m_pOutput_buf < d->m_pOutput_buf_end); +} +#else +static mz_bool tdefl_compress_lz_codes(tdefl_compressor *d) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint flags; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint8 *pLZ_codes; + + flags = 1; + for (pLZ_codes = d->m_lz_code_buf; pLZ_codes < d->m_pLZ_code_buf; flags >>= 1) + { + if (flags == 1) + flags = *pLZ_codes++ | 0x100; + if (flags & 1) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint sym, num_extra_bits; + mz_uint match_len = pLZ_codes[0], match_dist = (pLZ_codes[1] | (pLZ_codes[2] << 8)); + pLZ_codes += 3; + + MZ_ASSERT(d->m_huff_code_sizes[0][s_tdefl_len_sym[match_len]]); + TDEFL_PUT_BITS(d->m_huff_codes[0][s_tdefl_len_sym[match_len]], d->m_huff_code_sizes[0][s_tdefl_len_sym[match_len]]); + TDEFL_PUT_BITS(match_len & mz_bitmasks[s_tdefl_len_extra[match_len]], s_tdefl_len_extra[match_len]); + + if (match_dist < 512) + { + sym = s_tdefl_small_dist_sym[match_dist]; + num_extra_bits = s_tdefl_small_dist_extra[match_dist]; + } + else + { + sym = s_tdefl_large_dist_sym[match_dist >> 8]; + num_extra_bits = s_tdefl_large_dist_extra[match_dist >> 8]; + } + MZ_ASSERT(d->m_huff_code_sizes[1][sym]); + TDEFL_PUT_BITS(d->m_huff_codes[1][sym], d->m_huff_code_sizes[1][sym]); + TDEFL_PUT_BITS(match_dist & mz_bitmasks[num_extra_bits], num_extra_bits); + } + else + { + mz_uint lit = *pLZ_codes++; + MZ_ASSERT(d->m_huff_code_sizes[0][lit]); + TDEFL_PUT_BITS(d->m_huff_codes[0][lit], d->m_huff_code_sizes[0][lit]); + } + } + + TDEFL_PUT_BITS(d->m_huff_codes[0][256], d->m_huff_code_sizes[0][256]); + + return (d->m_pOutput_buf < d->m_pOutput_buf_end); +} +#endif /* MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN && MINIZ_HAS_64BIT_REGISTERS */ + +static mz_bool tdefl_compress_block(tdefl_compressor *d, mz_bool static_block) +{ + if (static_block) + tdefl_start_static_block(d); + else + tdefl_start_dynamic_block(d); + return tdefl_compress_lz_codes(d); +} + +static int tdefl_flush_block(tdefl_compressor *d, int flush) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint saved_bit_buf, saved_bits_in; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint8 *pSaved_output_buf; + mz_bool comp_block_succeeded = MZ_FALSE; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int n, use_raw_block = ((d->m_flags & TDEFL_FORCE_ALL_RAW_BLOCKS) != 0) && (d->m_lookahead_pos - d->m_lz_code_buf_dict_pos) <= d->m_dict_size; + mz_uint8 *pOutput_buf_start = ((d->m_pPut_buf_func == NULL) && ((*d->m_pOut_buf_size - d->m_out_buf_ofs) >= TDEFL_OUT_BUF_SIZE)) ? ((mz_uint8 *)d->m_pOut_buf + d->m_out_buf_ofs) : d->m_output_buf; + + d->m_pOutput_buf = pOutput_buf_start; + d->m_pOutput_buf_end = d->m_pOutput_buf + TDEFL_OUT_BUF_SIZE - 16; + + MZ_ASSERT(!d->m_output_flush_remaining); + d->m_output_flush_ofs = 0; + d->m_output_flush_remaining = 0; + + *d->m_pLZ_flags = (mz_uint8)(*d->m_pLZ_flags >> d->m_num_flags_left); + d->m_pLZ_code_buf -= (d->m_num_flags_left == 8); + + if ((d->m_flags & TDEFL_WRITE_ZLIB_HEADER) && (!d->m_block_index)) + { + TDEFL_PUT_BITS(0x78, 8); + TDEFL_PUT_BITS(0x01, 8); + } + + TDEFL_PUT_BITS(flush == TDEFL_FINISH, 1); + + pSaved_output_buf = d->m_pOutput_buf; + saved_bit_buf = d->m_bit_buffer; + saved_bits_in = d->m_bits_in; + + if (!use_raw_block) + comp_block_succeeded = tdefl_compress_block(d, (d->m_flags & TDEFL_FORCE_ALL_STATIC_BLOCKS) || (d->m_total_lz_bytes < 48)); + + /* If the block gets expanded, forget the current contents of the output buffer and send a raw block instead. */ + if (((use_raw_block) || ((d->m_total_lz_bytes) && ((d->m_pOutput_buf - pSaved_output_buf + 1U) >= d->m_total_lz_bytes))) && + ((d->m_lookahead_pos - d->m_lz_code_buf_dict_pos) <= d->m_dict_size)) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i; + d->m_pOutput_buf = pSaved_output_buf; + d->m_bit_buffer = saved_bit_buf, d->m_bits_in = saved_bits_in; + TDEFL_PUT_BITS(0, 2); + if (d->m_bits_in) + { + TDEFL_PUT_BITS(0, 8 - d->m_bits_in); + } + for (i = 2; i; --i, d->m_total_lz_bytes ^= 0xFFFF) + { + TDEFL_PUT_BITS(d->m_total_lz_bytes & 0xFFFF, 16); + } + for (i = 0; i < d->m_total_lz_bytes; ++i) + { + TDEFL_PUT_BITS(d->m_dict[(d->m_lz_code_buf_dict_pos + i) & TDEFL_LZ_DICT_SIZE_MASK], 8); + } + } + /* Check for the extremely unlikely (if not impossible) case of the compressed block not fitting into the output buffer when using dynamic codes. */ + else if (!comp_block_succeeded) + { + d->m_pOutput_buf = pSaved_output_buf; + d->m_bit_buffer = saved_bit_buf, d->m_bits_in = saved_bits_in; + tdefl_compress_block(d, MZ_TRUE); + } + + if (flush) + { + if (flush == TDEFL_FINISH) + { + if (d->m_bits_in) + { + TDEFL_PUT_BITS(0, 8 - d->m_bits_in); + } + if (d->m_flags & TDEFL_WRITE_ZLIB_HEADER) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i, a = d->m_adler32; + for (i = 0; i < 4; i++) + { + TDEFL_PUT_BITS((a >> 24) & 0xFF, 8); + a <<= 8; + } + } + } + else + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i, z = 0; + TDEFL_PUT_BITS(0, 3); + if (d->m_bits_in) + { + TDEFL_PUT_BITS(0, 8 - d->m_bits_in); + } + for (i = 2; i; --i, z ^= 0xFFFF) + { + TDEFL_PUT_BITS(z & 0xFFFF, 16); + } + } + } + + MZ_ASSERT(d->m_pOutput_buf < d->m_pOutput_buf_end); + + memset(&d->m_huff_count[0][0], 0, sizeof(d->m_huff_count[0][0]) * TDEFL_MAX_HUFF_SYMBOLS_0); + memset(&d->m_huff_count[1][0], 0, sizeof(d->m_huff_count[1][0]) * TDEFL_MAX_HUFF_SYMBOLS_1); + + d->m_pLZ_code_buf = d->m_lz_code_buf + 1; + d->m_pLZ_flags = d->m_lz_code_buf; + d->m_num_flags_left = 8; + d->m_lz_code_buf_dict_pos += d->m_total_lz_bytes; + d->m_total_lz_bytes = 0; + d->m_block_index++; + + if ((n = (int)(d->m_pOutput_buf - pOutput_buf_start)) != 0) + { + if (d->m_pPut_buf_func) + { + *d->m_pIn_buf_size = d->m_pSrc - (const mz_uint8 *)d->m_pIn_buf; + if (!(*d->m_pPut_buf_func)(d->m_output_buf, n, d->m_pPut_buf_user)) + return (d->m_prev_return_status = TDEFL_STATUS_PUT_BUF_FAILED); + } + else if (pOutput_buf_start == d->m_output_buf) + { + int bytes_to_copy = (int)MZ_MIN((size_t)n, (size_t)(*d->m_pOut_buf_size - d->m_out_buf_ofs)); + memcpy((mz_uint8 *)d->m_pOut_buf + d->m_out_buf_ofs, d->m_output_buf, bytes_to_copy); + d->m_out_buf_ofs += bytes_to_copy; + if ((n -= bytes_to_copy) != 0) + { + d->m_output_flush_ofs = bytes_to_copy; + d->m_output_flush_remaining = n; + } + } + else + { + d->m_out_buf_ofs += n; + } + } + + return d->m_output_flush_remaining; +} + +#if MINIZ_USE_UNALIGNED_LOADS_AND_STORES +#ifdef MINIZ_UNALIGNED_USE_MEMCPY +static mz_uint16 TDEFL_READ_UNALIGNED_WORD(const mz_uint8* p) +{ + mz_uint16 ret; + memcpy(&ret, p, sizeof(mz_uint16)); + return ret; +} +static mz_uint16 TDEFL_READ_UNALIGNED_WORD2(const mz_uint16* p) +{ + mz_uint16 ret; + memcpy(&ret, p, sizeof(mz_uint16)); + return ret; +} +#else +#define TDEFL_READ_UNALIGNED_WORD(p) *(const mz_uint16 *)(p) +#define TDEFL_READ_UNALIGNED_WORD2(p) *(const mz_uint16 *)(p) +#endif +static MZ_FORCEINLINE void tdefl_find_match(tdefl_compressor *d, mz_uint lookahead_pos, mz_uint max_dist, mz_uint max_match_len, mz_uint *pMatch_dist, mz_uint *pMatch_len) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint dist, pos = lookahead_pos & TDEFL_LZ_DICT_SIZE_MASK, match_len = *pMatch_len, probe_pos = pos, next_probe_pos, probe_len; + mz_uint num_probes_left = d->m_max_probes[match_len >= 32]; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const mz_uint16 *s = (const mz_uint16 *)(d->m_dict + pos), *p, *q; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint16 c01 = TDEFL_READ_UNALIGNED_WORD(&d->m_dict[pos + match_len - 1]), s01 = TDEFL_READ_UNALIGNED_WORD2(s); + MZ_ASSERT(max_match_len <= TDEFL_MAX_MATCH_LEN); + if (max_match_len <= match_len) + return; + for (;;) + { + for (;;) + { + if (--num_probes_left == 0) + return; +#define TDEFL_PROBE \ + next_probe_pos = d->m_next[probe_pos]; \ + if ((!next_probe_pos) || ((dist = (mz_uint16)(lookahead_pos - next_probe_pos)) > max_dist)) \ + return; \ + probe_pos = next_probe_pos & TDEFL_LZ_DICT_SIZE_MASK; \ + if (TDEFL_READ_UNALIGNED_WORD(&d->m_dict[probe_pos + match_len - 1]) == c01) \ + break; + TDEFL_PROBE; + TDEFL_PROBE; + TDEFL_PROBE; + } + if (!dist) + break; + q = (const mz_uint16 *)(d->m_dict + probe_pos); + if (TDEFL_READ_UNALIGNED_WORD2(q) != s01) + continue; + p = s; + probe_len = 32; + do + { + } while ((TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && (TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && + (TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && (TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && (--probe_len > 0)); + if (!probe_len) + { + *pMatch_dist = dist; + *pMatch_len = MZ_MIN(max_match_len, (mz_uint)TDEFL_MAX_MATCH_LEN); + break; + } + else if ((probe_len = ((mz_uint)(p - s) * 2) + (mz_uint)(*(const mz_uint8 *)p == *(const mz_uint8 *)q)) > match_len) + { + *pMatch_dist = dist; + if ((*pMatch_len = match_len = MZ_MIN(max_match_len, probe_len)) == max_match_len) + break; + c01 = TDEFL_READ_UNALIGNED_WORD(&d->m_dict[pos + match_len - 1]); + } + } +} +#else +static MZ_FORCEINLINE void tdefl_find_match(tdefl_compressor *d, mz_uint lookahead_pos, mz_uint max_dist, mz_uint max_match_len, mz_uint *pMatch_dist, mz_uint *pMatch_len) +{ + mz_uint dist, pos = lookahead_pos & TDEFL_LZ_DICT_SIZE_MASK, match_len = *pMatch_len, probe_pos = pos, next_probe_pos, probe_len; + mz_uint num_probes_left = d->m_max_probes[match_len >= 32]; + const mz_uint8 *s = d->m_dict + pos, *p, *q; + mz_uint8 c0 = d->m_dict[pos + match_len], c1 = d->m_dict[pos + match_len - 1]; + MZ_ASSERT(max_match_len <= TDEFL_MAX_MATCH_LEN); + if (max_match_len <= match_len) + return; + for (;;) + { + for (;;) + { + if (--num_probes_left == 0) + return; +#define TDEFL_PROBE \ + next_probe_pos = d->m_next[probe_pos]; \ + if ((!next_probe_pos) || ((dist = (mz_uint16)(lookahead_pos - next_probe_pos)) > max_dist)) \ + return; \ + probe_pos = next_probe_pos & TDEFL_LZ_DICT_SIZE_MASK; \ + if ((d->m_dict[probe_pos + match_len] == c0) && (d->m_dict[probe_pos + match_len - 1] == c1)) \ + break; + TDEFL_PROBE; + TDEFL_PROBE; + TDEFL_PROBE; + } + if (!dist) + break; + p = s; + q = d->m_dict + probe_pos; + for (probe_len = 0; probe_len < max_match_len; probe_len++) + if (*p++ != *q++) + break; + if (probe_len > match_len) + { + *pMatch_dist = dist; + if ((*pMatch_len = match_len = probe_len) == max_match_len) + return; + c0 = d->m_dict[pos + match_len]; + c1 = d->m_dict[pos + match_len - 1]; + } + } +} +#endif /* #if MINIZ_USE_UNALIGNED_LOADS_AND_STORES */ + +#if MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN +#ifdef MINIZ_UNALIGNED_USE_MEMCPY +static mz_uint32 TDEFL_READ_UNALIGNED_WORD32(const mz_uint8* p) +{ + mz_uint32 ret; + memcpy(&ret, p, sizeof(mz_uint32)); + return ret; +} +#else +#define TDEFL_READ_UNALIGNED_WORD32(p) *(const mz_uint32 *)(p) +#endif +static mz_bool tdefl_compress_fast(tdefl_compressor *d) +{ + /* Faster, minimally featured LZRW1-style match+parse loop with better register utilization. Intended for applications where raw throughput is valued more highly than ratio. */ + mz_uint lookahead_pos = d->m_lookahead_pos, lookahead_size = d->m_lookahead_size, dict_size = d->m_dict_size, total_lz_bytes = d->m_total_lz_bytes, num_flags_left = d->m_num_flags_left; + mz_uint8 *pLZ_code_buf = d->m_pLZ_code_buf, *pLZ_flags = d->m_pLZ_flags; + mz_uint cur_pos = lookahead_pos & TDEFL_LZ_DICT_SIZE_MASK; + + while ((d->m_src_buf_left) || ((d->m_flush) && (lookahead_size))) + { + const mz_uint TDEFL_COMP_FAST_LOOKAHEAD_SIZE = 4096; + mz_uint dst_pos = (lookahead_pos + lookahead_size) & TDEFL_LZ_DICT_SIZE_MASK; + mz_uint num_bytes_to_process = (mz_uint)MZ_MIN(d->m_src_buf_left, TDEFL_COMP_FAST_LOOKAHEAD_SIZE - lookahead_size); + d->m_src_buf_left -= num_bytes_to_process; + lookahead_size += num_bytes_to_process; + + while (num_bytes_to_process) + { + mz_uint32 n = MZ_MIN(TDEFL_LZ_DICT_SIZE - dst_pos, num_bytes_to_process); + memcpy(d->m_dict + dst_pos, d->m_pSrc, n); + if (dst_pos < (TDEFL_MAX_MATCH_LEN - 1)) + memcpy(d->m_dict + TDEFL_LZ_DICT_SIZE + dst_pos, d->m_pSrc, MZ_MIN(n, (TDEFL_MAX_MATCH_LEN - 1) - dst_pos)); + d->m_pSrc += n; + dst_pos = (dst_pos + n) & TDEFL_LZ_DICT_SIZE_MASK; + num_bytes_to_process -= n; + } + + dict_size = MZ_MIN(TDEFL_LZ_DICT_SIZE - lookahead_size, dict_size); + if ((!d->m_flush) && (lookahead_size < TDEFL_COMP_FAST_LOOKAHEAD_SIZE)) + break; + + while (lookahead_size >= 4) + { + mz_uint cur_match_dist, cur_match_len = 1; + mz_uint8 *pCur_dict = d->m_dict + cur_pos; + mz_uint first_trigram = TDEFL_READ_UNALIGNED_WORD32(pCur_dict) & 0xFFFFFF; + mz_uint hash = (first_trigram ^ (first_trigram >> (24 - (TDEFL_LZ_HASH_BITS - 8)))) & TDEFL_LEVEL1_HASH_SIZE_MASK; + mz_uint probe_pos = d->m_hash[hash]; + d->m_hash[hash] = (mz_uint16)lookahead_pos; + + if (((cur_match_dist = (mz_uint16)(lookahead_pos - probe_pos)) <= dict_size) && ((TDEFL_READ_UNALIGNED_WORD32(d->m_dict + (probe_pos &= TDEFL_LZ_DICT_SIZE_MASK)) & 0xFFFFFF) == first_trigram)) + { + const mz_uint16 *p = (const mz_uint16 *)pCur_dict; + const mz_uint16 *q = (const mz_uint16 *)(d->m_dict + probe_pos); + mz_uint32 probe_len = 32; + do + { + } while ((TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && (TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && + (TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && (TDEFL_READ_UNALIGNED_WORD2(++p) == TDEFL_READ_UNALIGNED_WORD2(++q)) && (--probe_len > 0)); + cur_match_len = ((mz_uint)(p - (const mz_uint16 *)pCur_dict) * 2) + (mz_uint)(*(const mz_uint8 *)p == *(const mz_uint8 *)q); + if (!probe_len) + cur_match_len = cur_match_dist ? TDEFL_MAX_MATCH_LEN : 0; + + if ((cur_match_len < TDEFL_MIN_MATCH_LEN) || ((cur_match_len == TDEFL_MIN_MATCH_LEN) && (cur_match_dist >= 8U * 1024U))) + { + cur_match_len = 1; + *pLZ_code_buf++ = (mz_uint8)first_trigram; + *pLZ_flags = (mz_uint8)(*pLZ_flags >> 1); + d->m_huff_count[0][(mz_uint8)first_trigram]++; + } + else + { + mz_uint32 s0, s1; + cur_match_len = MZ_MIN(cur_match_len, lookahead_size); + + MZ_ASSERT((cur_match_len >= TDEFL_MIN_MATCH_LEN) && (cur_match_dist >= 1) && (cur_match_dist <= TDEFL_LZ_DICT_SIZE)); + + cur_match_dist--; + + pLZ_code_buf[0] = (mz_uint8)(cur_match_len - TDEFL_MIN_MATCH_LEN); +#ifdef MINIZ_UNALIGNED_USE_MEMCPY + memcpy(&pLZ_code_buf[1], &cur_match_dist, sizeof(cur_match_dist)); +#else + *(mz_uint16 *)(&pLZ_code_buf[1]) = (mz_uint16)cur_match_dist; +#endif + pLZ_code_buf += 3; + *pLZ_flags = (mz_uint8)((*pLZ_flags >> 1) | 0x80); + + s0 = s_tdefl_small_dist_sym[cur_match_dist & 511]; + s1 = s_tdefl_large_dist_sym[cur_match_dist >> 8]; + d->m_huff_count[1][(cur_match_dist < 512) ? s0 : s1]++; + + d->m_huff_count[0][s_tdefl_len_sym[cur_match_len - TDEFL_MIN_MATCH_LEN]]++; + } + } + else + { + *pLZ_code_buf++ = (mz_uint8)first_trigram; + *pLZ_flags = (mz_uint8)(*pLZ_flags >> 1); + d->m_huff_count[0][(mz_uint8)first_trigram]++; + } + + if (--num_flags_left == 0) + { + num_flags_left = 8; + pLZ_flags = pLZ_code_buf++; + } + + total_lz_bytes += cur_match_len; + lookahead_pos += cur_match_len; + dict_size = MZ_MIN(dict_size + cur_match_len, (mz_uint)TDEFL_LZ_DICT_SIZE); + cur_pos = (cur_pos + cur_match_len) & TDEFL_LZ_DICT_SIZE_MASK; + MZ_ASSERT(lookahead_size >= cur_match_len); + lookahead_size -= cur_match_len; + + if (pLZ_code_buf > &d->m_lz_code_buf[TDEFL_LZ_CODE_BUF_SIZE - 8]) + { + int n; + d->m_lookahead_pos = lookahead_pos; + d->m_lookahead_size = lookahead_size; + d->m_dict_size = dict_size; + d->m_total_lz_bytes = total_lz_bytes; + d->m_pLZ_code_buf = pLZ_code_buf; + d->m_pLZ_flags = pLZ_flags; + d->m_num_flags_left = num_flags_left; + if ((n = tdefl_flush_block(d, 0)) != 0) + return (n < 0) ? MZ_FALSE : MZ_TRUE; + total_lz_bytes = d->m_total_lz_bytes; + pLZ_code_buf = d->m_pLZ_code_buf; + pLZ_flags = d->m_pLZ_flags; + num_flags_left = d->m_num_flags_left; + } + } + + while (lookahead_size) + { + mz_uint8 lit = d->m_dict[cur_pos]; + + total_lz_bytes++; + *pLZ_code_buf++ = lit; + *pLZ_flags = (mz_uint8)(*pLZ_flags >> 1); + if (--num_flags_left == 0) + { + num_flags_left = 8; + pLZ_flags = pLZ_code_buf++; + } + + d->m_huff_count[0][lit]++; + + lookahead_pos++; + dict_size = MZ_MIN(dict_size + 1, (mz_uint)TDEFL_LZ_DICT_SIZE); + cur_pos = (cur_pos + 1) & TDEFL_LZ_DICT_SIZE_MASK; + lookahead_size--; + + if (pLZ_code_buf > &d->m_lz_code_buf[TDEFL_LZ_CODE_BUF_SIZE - 8]) + { + int n; + d->m_lookahead_pos = lookahead_pos; + d->m_lookahead_size = lookahead_size; + d->m_dict_size = dict_size; + d->m_total_lz_bytes = total_lz_bytes; + d->m_pLZ_code_buf = pLZ_code_buf; + d->m_pLZ_flags = pLZ_flags; + d->m_num_flags_left = num_flags_left; + if ((n = tdefl_flush_block(d, 0)) != 0) + return (n < 0) ? MZ_FALSE : MZ_TRUE; + total_lz_bytes = d->m_total_lz_bytes; + pLZ_code_buf = d->m_pLZ_code_buf; + pLZ_flags = d->m_pLZ_flags; + num_flags_left = d->m_num_flags_left; + } + } + } + + d->m_lookahead_pos = lookahead_pos; + d->m_lookahead_size = lookahead_size; + d->m_dict_size = dict_size; + d->m_total_lz_bytes = total_lz_bytes; + d->m_pLZ_code_buf = pLZ_code_buf; + d->m_pLZ_flags = pLZ_flags; + d->m_num_flags_left = num_flags_left; + return MZ_TRUE; +} +#endif /* MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN */ + +static MZ_FORCEINLINE void tdefl_record_literal(tdefl_compressor *d, mz_uint8 lit) +{ + d->m_total_lz_bytes++; + *d->m_pLZ_code_buf++ = lit; + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + *d->m_pLZ_flags = (mz_uint8)(*d->m_pLZ_flags >> 1); + if (--d->m_num_flags_left == 0) + { + d->m_num_flags_left = 8; + d->m_pLZ_flags = d->m_pLZ_code_buf++; + } + d->m_huff_count[0][lit]++; +} + +static MZ_FORCEINLINE void tdefl_record_match(tdefl_compressor *d, mz_uint match_len, mz_uint match_dist) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 s0, s1; + + MZ_ASSERT((match_len >= TDEFL_MIN_MATCH_LEN) && (match_dist >= 1) && (match_dist <= TDEFL_LZ_DICT_SIZE)); + + d->m_total_lz_bytes += match_len; + + d->m_pLZ_code_buf[0] = (mz_uint8)(match_len - TDEFL_MIN_MATCH_LEN); + + match_dist -= 1; + d->m_pLZ_code_buf[1] = (mz_uint8)(match_dist & 0xFF); + d->m_pLZ_code_buf[2] = (mz_uint8)(match_dist >> 8); + d->m_pLZ_code_buf += 3; + + *d->m_pLZ_flags = (mz_uint8)((*d->m_pLZ_flags >> 1) | 0x80); + if (--d->m_num_flags_left == 0) + { + d->m_num_flags_left = 8; + d->m_pLZ_flags = d->m_pLZ_code_buf++; + } + + s0 = s_tdefl_small_dist_sym[match_dist & 511]; + s1 = s_tdefl_large_dist_sym[(match_dist >> 8) & 127]; + d->m_huff_count[1][(match_dist < 512) ? s0 : s1]++; + + if (match_len >= TDEFL_MIN_MATCH_LEN) + d->m_huff_count[0][s_tdefl_len_sym[match_len - TDEFL_MIN_MATCH_LEN]]++; +} + +static mz_bool tdefl_compress_normal(tdefl_compressor *d) +{ + const mz_uint8 *pSrc = d->m_pSrc; + size_t src_buf_left = d->m_src_buf_left; + tdefl_flush flush = d->m_flush; + + while ((src_buf_left) || ((flush) && (d->m_lookahead_size))) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint len_to_move, cur_match_dist, cur_match_len, cur_pos; + /* Update dictionary and hash chains. Keeps the lookahead size equal to TDEFL_MAX_MATCH_LEN. */ + if ((d->m_lookahead_size + d->m_dict_size) >= (TDEFL_MIN_MATCH_LEN - 1)) + { + mz_uint dst_pos = (d->m_lookahead_pos + d->m_lookahead_size) & TDEFL_LZ_DICT_SIZE_MASK, ins_pos = d->m_lookahead_pos + d->m_lookahead_size - 2; + mz_uint hash = (d->m_dict[ins_pos & TDEFL_LZ_DICT_SIZE_MASK] << TDEFL_LZ_HASH_SHIFT) ^ d->m_dict[(ins_pos + 1) & TDEFL_LZ_DICT_SIZE_MASK]; + mz_uint num_bytes_to_process = (mz_uint)MZ_MIN(src_buf_left, TDEFL_MAX_MATCH_LEN - d->m_lookahead_size); + const mz_uint8 *pSrc_end = pSrc + num_bytes_to_process; + src_buf_left -= num_bytes_to_process; + d->m_lookahead_size += num_bytes_to_process; + while (pSrc != pSrc_end) + { + mz_uint8 c = *pSrc++; + d->m_dict[dst_pos] = c; + if (dst_pos < (TDEFL_MAX_MATCH_LEN - 1)) + d->m_dict[TDEFL_LZ_DICT_SIZE + dst_pos] = c; + hash = ((hash << TDEFL_LZ_HASH_SHIFT) ^ c) & (TDEFL_LZ_HASH_SIZE - 1); + d->m_next[ins_pos & TDEFL_LZ_DICT_SIZE_MASK] = d->m_hash[hash]; + d->m_hash[hash] = (mz_uint16)(ins_pos); + dst_pos = (dst_pos + 1) & TDEFL_LZ_DICT_SIZE_MASK; + ins_pos++; + } + } + else + { + while ((src_buf_left) && (d->m_lookahead_size < TDEFL_MAX_MATCH_LEN)) + { + mz_uint8 c = *pSrc++; + mz_uint dst_pos = (d->m_lookahead_pos + d->m_lookahead_size) & TDEFL_LZ_DICT_SIZE_MASK; + src_buf_left--; + d->m_dict[dst_pos] = c; + if (dst_pos < (TDEFL_MAX_MATCH_LEN - 1)) + d->m_dict[TDEFL_LZ_DICT_SIZE + dst_pos] = c; + if ((++d->m_lookahead_size + d->m_dict_size) >= TDEFL_MIN_MATCH_LEN) + { + mz_uint ins_pos = d->m_lookahead_pos + (d->m_lookahead_size - 1) - 2; + mz_uint hash = ((d->m_dict[ins_pos & TDEFL_LZ_DICT_SIZE_MASK] << (TDEFL_LZ_HASH_SHIFT * 2)) ^ (d->m_dict[(ins_pos + 1) & TDEFL_LZ_DICT_SIZE_MASK] << TDEFL_LZ_HASH_SHIFT) ^ c) & (TDEFL_LZ_HASH_SIZE - 1); + d->m_next[ins_pos & TDEFL_LZ_DICT_SIZE_MASK] = d->m_hash[hash]; + d->m_hash[hash] = (mz_uint16)(ins_pos); + } + } + } + d->m_dict_size = MZ_MIN(TDEFL_LZ_DICT_SIZE - d->m_lookahead_size, d->m_dict_size); + if ((!flush) && (d->m_lookahead_size < TDEFL_MAX_MATCH_LEN)) + break; + + /* Simple lazy/greedy parsing state machine. */ + len_to_move = 1; + cur_match_dist = 0; + cur_match_len = d->m_saved_match_len ? d->m_saved_match_len : (TDEFL_MIN_MATCH_LEN - 1); + cur_pos = d->m_lookahead_pos & TDEFL_LZ_DICT_SIZE_MASK; + if (d->m_flags & (TDEFL_RLE_MATCHES | TDEFL_FORCE_ALL_RAW_BLOCKS)) + { + if ((d->m_dict_size) && (!(d->m_flags & TDEFL_FORCE_ALL_RAW_BLOCKS))) + { + mz_uint8 c = d->m_dict[(cur_pos - 1) & TDEFL_LZ_DICT_SIZE_MASK]; + cur_match_len = 0; + while (cur_match_len < d->m_lookahead_size) + { + if (d->m_dict[cur_pos + cur_match_len] != c) + break; + cur_match_len++; + } + if (cur_match_len < TDEFL_MIN_MATCH_LEN) + cur_match_len = 0; + else + cur_match_dist = 1; + } + } + else + { + tdefl_find_match(d, d->m_lookahead_pos, d->m_dict_size, d->m_lookahead_size, &cur_match_dist, &cur_match_len); + } + if (((cur_match_len == TDEFL_MIN_MATCH_LEN) && (cur_match_dist >= 8U * 1024U)) || (cur_pos == cur_match_dist) || ((d->m_flags & TDEFL_FILTER_MATCHES) && (cur_match_len <= 5))) + { + cur_match_dist = cur_match_len = 0; + } + if (d->m_saved_match_len) + { + if (cur_match_len > d->m_saved_match_len) + { + tdefl_record_literal(d, (mz_uint8)d->m_saved_lit); + if (cur_match_len >= 128) + { + tdefl_record_match(d, cur_match_len, cur_match_dist); + d->m_saved_match_len = 0; + len_to_move = cur_match_len; + } + else + { + d->m_saved_lit = d->m_dict[cur_pos]; + d->m_saved_match_dist = cur_match_dist; + d->m_saved_match_len = cur_match_len; + } + } + else + { + tdefl_record_match(d, d->m_saved_match_len, d->m_saved_match_dist); + len_to_move = d->m_saved_match_len - 1; + d->m_saved_match_len = 0; + } + } + else if (!cur_match_dist) + tdefl_record_literal(d, d->m_dict[MZ_MIN(cur_pos, sizeof(d->m_dict) - 1)]); + else if ((d->m_greedy_parsing) || (d->m_flags & TDEFL_RLE_MATCHES) || (cur_match_len >= 128)) + { + tdefl_record_match(d, cur_match_len, cur_match_dist); + len_to_move = cur_match_len; + } + else + { + d->m_saved_lit = d->m_dict[MZ_MIN(cur_pos, sizeof(d->m_dict) - 1)]; + d->m_saved_match_dist = cur_match_dist; + d->m_saved_match_len = cur_match_len; + } + /* Move the lookahead forward by len_to_move bytes. */ + d->m_lookahead_pos += len_to_move; + MZ_ASSERT(d->m_lookahead_size >= len_to_move); + d->m_lookahead_size -= len_to_move; + d->m_dict_size = MZ_MIN(d->m_dict_size + len_to_move, (mz_uint)TDEFL_LZ_DICT_SIZE); + /* Check if it's time to flush the current LZ codes to the internal output buffer. */ + if ((d->m_pLZ_code_buf > &d->m_lz_code_buf[TDEFL_LZ_CODE_BUF_SIZE - 8]) || + ((d->m_total_lz_bytes > 31 * 1024) && (((((mz_uint)(d->m_pLZ_code_buf - d->m_lz_code_buf) * 115) >> 7) >= d->m_total_lz_bytes) || (d->m_flags & TDEFL_FORCE_ALL_RAW_BLOCKS)))) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int n; + d->m_pSrc = pSrc; + d->m_src_buf_left = src_buf_left; + if ((n = tdefl_flush_block(d, 0)) != 0) + return (n < 0) ? MZ_FALSE : MZ_TRUE; + } + } + + d->m_pSrc = pSrc; + d->m_src_buf_left = src_buf_left; + return MZ_TRUE; +} + +static tdefl_status tdefl_flush_output_buffer(tdefl_compressor *d) +{ + if (d->m_pIn_buf_size) + { + *d->m_pIn_buf_size = d->m_pSrc - (const mz_uint8 *)d->m_pIn_buf; + } + + if (d->m_pOut_buf_size) + { + size_t n = MZ_MIN(*d->m_pOut_buf_size - d->m_out_buf_ofs, d->m_output_flush_remaining); + memcpy((mz_uint8 *)d->m_pOut_buf + d->m_out_buf_ofs, d->m_output_buf + d->m_output_flush_ofs, n); + d->m_output_flush_ofs += (mz_uint)n; + d->m_output_flush_remaining -= (mz_uint)n; + d->m_out_buf_ofs += n; + + *d->m_pOut_buf_size = d->m_out_buf_ofs; + } + + return (d->m_finished && !d->m_output_flush_remaining) ? TDEFL_STATUS_DONE : TDEFL_STATUS_OKAY; +} + +tdefl_status tdefl_compress(tdefl_compressor *d, const void *pIn_buf, size_t *pIn_buf_size, void *pOut_buf, size_t *pOut_buf_size, tdefl_flush flush) +{ + if (!d) + { + if (pIn_buf_size) + *pIn_buf_size = 0; + if (pOut_buf_size) + *pOut_buf_size = 0; + return TDEFL_STATUS_BAD_PARAM; + } + + d->m_pIn_buf = pIn_buf; + d->m_pIn_buf_size = pIn_buf_size; + d->m_pOut_buf = pOut_buf; + d->m_pOut_buf_size = pOut_buf_size; + d->m_pSrc = (const mz_uint8 *)(pIn_buf); + d->m_src_buf_left = pIn_buf_size ? *pIn_buf_size : 0; + d->m_out_buf_ofs = 0; + d->m_flush = flush; + + if (((d->m_pPut_buf_func != NULL) == ((pOut_buf != NULL) || (pOut_buf_size != NULL))) || (d->m_prev_return_status != TDEFL_STATUS_OKAY) || + (d->m_wants_to_finish && (flush != TDEFL_FINISH)) || (pIn_buf_size && *pIn_buf_size && !pIn_buf) || (pOut_buf_size && *pOut_buf_size && !pOut_buf)) + { + if (pIn_buf_size) + *pIn_buf_size = 0; + if (pOut_buf_size) + *pOut_buf_size = 0; + return (d->m_prev_return_status = TDEFL_STATUS_BAD_PARAM); + } + d->m_wants_to_finish |= (flush == TDEFL_FINISH); + + if ((d->m_output_flush_remaining) || (d->m_finished)) + return (d->m_prev_return_status = tdefl_flush_output_buffer(d)); + +#if MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN + if (((d->m_flags & TDEFL_MAX_PROBES_MASK) == 1) && + ((d->m_flags & TDEFL_GREEDY_PARSING_FLAG) != 0) && + ((d->m_flags & (TDEFL_FILTER_MATCHES | TDEFL_FORCE_ALL_RAW_BLOCKS | TDEFL_RLE_MATCHES)) == 0)) + { + if (!tdefl_compress_fast(d)) + return d->m_prev_return_status; + } + else +#endif /* #if MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN */ + { + if (!tdefl_compress_normal(d)) + return d->m_prev_return_status; + } + + if ((d->m_flags & (TDEFL_WRITE_ZLIB_HEADER | TDEFL_COMPUTE_ADLER32)) && (pIn_buf)) + d->m_adler32 = (mz_uint32)mz_adler32(d->m_adler32, (const mz_uint8 *)pIn_buf, d->m_pSrc - (const mz_uint8 *)pIn_buf); + + if ((flush) && (!d->m_lookahead_size) && (!d->m_src_buf_left) && (!d->m_output_flush_remaining)) + { + if (tdefl_flush_block(d, flush) < 0) + return d->m_prev_return_status; + d->m_finished = (flush == TDEFL_FINISH); + if (flush == TDEFL_FULL_FLUSH) + { + MZ_CLEAR_OBJ(d->m_hash); + MZ_CLEAR_OBJ(d->m_next); + d->m_dict_size = 0; + } + } + + return (d->m_prev_return_status = tdefl_flush_output_buffer(d)); +} + +tdefl_status tdefl_compress_buffer(tdefl_compressor *d, const void *pIn_buf, size_t in_buf_size, tdefl_flush flush) +{ + MZ_ASSERT(d->m_pPut_buf_func); + return tdefl_compress(d, pIn_buf, &in_buf_size, NULL, NULL, flush); +} + +tdefl_status tdefl_init(tdefl_compressor *d, tdefl_put_buf_func_ptr pPut_buf_func, void *pPut_buf_user, int flags) +{ + d->m_pPut_buf_func = pPut_buf_func; + d->m_pPut_buf_user = pPut_buf_user; + d->m_flags = (mz_uint)(flags); + d->m_max_probes[0] = 1 + ((flags & 0xFFF) + 2) / 3; + d->m_greedy_parsing = (flags & TDEFL_GREEDY_PARSING_FLAG) != 0; + d->m_max_probes[1] = 1 + (((flags & 0xFFF) >> 2) + 2) / 3; + if (!(flags & TDEFL_NONDETERMINISTIC_PARSING_FLAG)) + MZ_CLEAR_OBJ(d->m_hash); + d->m_lookahead_pos = d->m_lookahead_size = d->m_dict_size = d->m_total_lz_bytes = d->m_lz_code_buf_dict_pos = d->m_bits_in = 0; + d->m_output_flush_ofs = d->m_output_flush_remaining = d->m_finished = d->m_block_index = d->m_bit_buffer = d->m_wants_to_finish = 0; + d->m_pLZ_code_buf = d->m_lz_code_buf + 1; + d->m_pLZ_flags = d->m_lz_code_buf; + d->m_num_flags_left = 8; + d->m_pOutput_buf = d->m_output_buf; + d->m_pOutput_buf_end = d->m_output_buf; + d->m_prev_return_status = TDEFL_STATUS_OKAY; + d->m_saved_match_dist = d->m_saved_match_len = d->m_saved_lit = 0; + d->m_adler32 = 1; + d->m_pIn_buf = NULL; + d->m_pOut_buf = NULL; + d->m_pIn_buf_size = NULL; + d->m_pOut_buf_size = NULL; + d->m_flush = TDEFL_NO_FLUSH; + d->m_pSrc = NULL; + d->m_src_buf_left = 0; + d->m_out_buf_ofs = 0; + if (!(flags & TDEFL_NONDETERMINISTIC_PARSING_FLAG)) + MZ_CLEAR_OBJ(d->m_dict); + memset(&d->m_huff_count[0][0], 0, sizeof(d->m_huff_count[0][0]) * TDEFL_MAX_HUFF_SYMBOLS_0); + memset(&d->m_huff_count[1][0], 0, sizeof(d->m_huff_count[1][0]) * TDEFL_MAX_HUFF_SYMBOLS_1); + return TDEFL_STATUS_OKAY; +} + +tdefl_status tdefl_get_prev_return_status(tdefl_compressor *d) +{ + return d->m_prev_return_status; +} + +mz_uint32 tdefl_get_adler32(tdefl_compressor *d) +{ + return d->m_adler32; +} + +mz_bool tdefl_compress_mem_to_output(const void *pBuf, size_t buf_len, tdefl_put_buf_func_ptr pPut_buf_func, void *pPut_buf_user, int flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tdefl_compressor *pComp; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_bool succeeded; + if (((buf_len) && (!pBuf)) || (!pPut_buf_func)) + return MZ_FALSE; + pComp = (tdefl_compressor *)MZ_MALLOC(sizeof(tdefl_compressor)); + if (!pComp) + return MZ_FALSE; + succeeded = (tdefl_init(pComp, pPut_buf_func, pPut_buf_user, flags) == TDEFL_STATUS_OKAY); + succeeded = succeeded && (tdefl_compress_buffer(pComp, pBuf, buf_len, TDEFL_FINISH) == TDEFL_STATUS_DONE); + MZ_FREE(pComp); + return succeeded; +} + +typedef struct +{ + size_t m_size, m_capacity; + mz_uint8 *m_pBuf; + mz_bool m_expandable; +} tdefl_output_buffer; + +static mz_bool tdefl_output_buffer_putter(const void *pBuf, int len, void *pUser) +{ + tdefl_output_buffer *p = (tdefl_output_buffer *)pUser; + size_t new_size = p->m_size + len; + if (new_size > p->m_capacity) + { + size_t new_capacity = p->m_capacity; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint8 *pNew_buf; + if (!p->m_expandable) + return MZ_FALSE; + do + { + new_capacity = MZ_MAX(128U, new_capacity << 1U); + } while (new_size > new_capacity); + pNew_buf = (mz_uint8 *)MZ_REALLOC(p->m_pBuf, new_capacity); + if (!pNew_buf) + return MZ_FALSE; + p->m_pBuf = pNew_buf; + p->m_capacity = new_capacity; + } + memcpy((mz_uint8 *)p->m_pBuf + p->m_size, pBuf, len); + p->m_size = new_size; + return MZ_TRUE; +} + +void *tdefl_compress_mem_to_heap(const void *pSrc_buf, size_t src_buf_len, size_t *pOut_len, int flags) +{ + tdefl_output_buffer out_buf; + MZ_CLEAR_OBJ(out_buf); + if (!pOut_len) + return MZ_FALSE; + else + *pOut_len = 0; + out_buf.m_expandable = MZ_TRUE; + if (!tdefl_compress_mem_to_output(pSrc_buf, src_buf_len, tdefl_output_buffer_putter, &out_buf, flags)) + return NULL; + *pOut_len = out_buf.m_size; + return out_buf.m_pBuf; +} + +size_t tdefl_compress_mem_to_mem(void *pOut_buf, size_t out_buf_len, const void *pSrc_buf, size_t src_buf_len, int flags) +{ + tdefl_output_buffer out_buf; + MZ_CLEAR_OBJ(out_buf); + if (!pOut_buf) + return 0; + out_buf.m_pBuf = (mz_uint8 *)pOut_buf; + out_buf.m_capacity = out_buf_len; + if (!tdefl_compress_mem_to_output(pSrc_buf, src_buf_len, tdefl_output_buffer_putter, &out_buf, flags)) + return 0; + return out_buf.m_size; +} + +static const mz_uint s_tdefl_num_probes[11] = { 0, 1, 6, 32, 16, 32, 128, 256, 512, 768, 1500 }; + +/* level may actually range from [0,10] (10 is a "hidden" max level, where we want a bit more compression and it's fine if throughput to fall off a cliff on some files). */ +mz_uint tdefl_create_comp_flags_from_zip_params(int level, int window_bits, int strategy) +{ + mz_uint comp_flags = s_tdefl_num_probes[(level >= 0) ? MZ_MIN(10, level) : MZ_DEFAULT_LEVEL] | ((level <= 3) ? TDEFL_GREEDY_PARSING_FLAG : 0); + if (window_bits > 0) + comp_flags |= TDEFL_WRITE_ZLIB_HEADER; + + if (!level) + comp_flags |= TDEFL_FORCE_ALL_RAW_BLOCKS; + else if (strategy == MZ_FILTERED) + comp_flags |= TDEFL_FILTER_MATCHES; + else if (strategy == MZ_HUFFMAN_ONLY) + comp_flags &= ~TDEFL_MAX_PROBES_MASK; + else if (strategy == MZ_FIXED) + comp_flags |= TDEFL_FORCE_ALL_STATIC_BLOCKS; + else if (strategy == MZ_RLE) + comp_flags |= TDEFL_RLE_MATCHES; + + return comp_flags; +} + +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4204) /* nonstandard extension used : non-constant aggregate initializer (also supported by GNU C and C99, so no big deal) */ +#endif + +/* Simple PNG writer function by Alex Evans, 2011. Released into the public domain: https://gist.github.com/908299, more context at + http://altdevblogaday.org/2011/04/06/a-smaller-jpg-encoder/. + This is actually a modification of Alex's original code so PNG files generated by this function pass pngcheck. */ +void *tdefl_write_image_to_png_file_in_memory_ex(const void *pImage, int w, int h, int num_chans, size_t *pLen_out, mz_uint level, mz_bool flip) +{ + /* Using a local copy of this array here in case MINIZ_NO_ZLIB_APIS was defined. */ + static const mz_uint s_tdefl_png_num_probes[11] = { 0, 1, 6, 32, 16, 32, 128, 256, 512, 768, 1500 }; + tdefl_compressor *pComp = (tdefl_compressor *)MZ_MALLOC(sizeof(tdefl_compressor)); + tdefl_output_buffer out_buf; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int i, bpl = w * num_chans, y, z; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 c; + *pLen_out = 0; + if (!pComp) + return NULL; + MZ_CLEAR_OBJ(out_buf); + out_buf.m_expandable = MZ_TRUE; + out_buf.m_capacity = 57 + MZ_MAX(64, (1 + bpl) * h); + if (NULL == (out_buf.m_pBuf = (mz_uint8 *)MZ_MALLOC(out_buf.m_capacity))) + { + MZ_FREE(pComp); + return NULL; + } + /* write dummy header */ + for (z = 41; z; --z) + tdefl_output_buffer_putter(&z, 1, &out_buf); + /* compress image data */ + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + tdefl_init(pComp, tdefl_output_buffer_putter, &out_buf, s_tdefl_png_num_probes[MZ_MIN(10, level)] | TDEFL_WRITE_ZLIB_HEADER); + for (y = 0; y < h; ++y) + { + tdefl_compress_buffer(pComp, &z, 1, TDEFL_NO_FLUSH); + tdefl_compress_buffer(pComp, (mz_uint8 *)pImage + (flip ? (h - 1 - y) : y) * bpl, bpl, TDEFL_NO_FLUSH); + } + if (tdefl_compress_buffer(pComp, NULL, 0, TDEFL_FINISH) != TDEFL_STATUS_DONE) + { + MZ_FREE(pComp); + MZ_FREE(out_buf.m_pBuf); + return NULL; + } + /* write real header */ + *pLen_out = out_buf.m_size - 41; + { + static const mz_uint8 chans[] = { 0x00, 0x00, 0x04, 0x02, 0x06 }; + mz_uint8 pnghdr[41] = { 0x89, 0x50, 0x4e, 0x47, 0x0d, + 0x0a, 0x1a, 0x0a, 0x00, 0x00, + 0x00, 0x0d, 0x49, 0x48, 0x44, + 0x52, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x49, 0x44, 0x41, + 0x54 }; + pnghdr[18] = (mz_uint8)(w >> 8); + pnghdr[19] = (mz_uint8)w; + pnghdr[22] = (mz_uint8)(h >> 8); + pnghdr[23] = (mz_uint8)h; + pnghdr[25] = chans[num_chans]; + pnghdr[33] = (mz_uint8)(*pLen_out >> 24); + pnghdr[34] = (mz_uint8)(*pLen_out >> 16); + pnghdr[35] = (mz_uint8)(*pLen_out >> 8); + pnghdr[36] = (mz_uint8)*pLen_out; + c = (mz_uint32)mz_crc32(MZ_CRC32_INIT, pnghdr + 12, 17); + for (i = 0; i < 4; ++i, c <<= 8) + ((mz_uint8 *)(pnghdr + 29))[i] = (mz_uint8)(c >> 24); + memcpy(out_buf.m_pBuf, pnghdr, 41); + } + /* write footer (IDAT CRC-32, followed by IEND chunk) */ + if (!tdefl_output_buffer_putter("\0\0\0\0\0\0\0\0\x49\x45\x4e\x44\xae\x42\x60\x82", 16, &out_buf)) + { + *pLen_out = 0; + MZ_FREE(pComp); + MZ_FREE(out_buf.m_pBuf); + return NULL; + } + c = (mz_uint32)mz_crc32(MZ_CRC32_INIT, out_buf.m_pBuf + 41 - 4, *pLen_out + 4); + for (i = 0; i < 4; ++i, c <<= 8) + (out_buf.m_pBuf + out_buf.m_size - 16)[i] = (mz_uint8)(c >> 24); + /* compute final size of file, grab compressed data buffer and return */ + *pLen_out += 57; + MZ_FREE(pComp); + return out_buf.m_pBuf; +} +void *tdefl_write_image_to_png_file_in_memory(const void *pImage, int w, int h, int num_chans, size_t *pLen_out) +{ + /* Level 6 corresponds to TDEFL_DEFAULT_MAX_PROBES or MZ_DEFAULT_LEVEL (but we can't depend on MZ_DEFAULT_LEVEL being available in case the zlib API's where #defined out) */ + return tdefl_write_image_to_png_file_in_memory_ex(pImage, w, h, num_chans, pLen_out, 6, MZ_FALSE); +} + +#ifndef MINIZ_NO_MALLOC +/* Allocate the tdefl_compressor and tinfl_decompressor structures in C so that */ +/* non-C language bindings to tdefL_ and tinfl_ API don't need to worry about */ +/* structure size and allocation mechanism. */ +tdefl_compressor *tdefl_compressor_alloc() +{ + return (tdefl_compressor *)MZ_MALLOC(sizeof(tdefl_compressor)); +} + +void tdefl_compressor_free(tdefl_compressor *pComp) +{ + MZ_FREE(pComp); +} +#endif + +#ifdef _MSC_VER +#pragma warning(pop) +#endif + +/************************************************************************** + * + * Copyright 2013-2014 RAD Game Tools and Valve Software + * Copyright 2010-2014 Rich Geldreich and Tenacious Software LLC + * All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + **************************************************************************/ + + + +/* ------------------- Low-level Decompression (completely independent from all compression API's) */ + +#define TINFL_MEMCPY(d, s, l) memcpy(d, s, l) +#define TINFL_MEMSET(p, c, l) memset(p, c, l) + +#define TINFL_CR_BEGIN \ + switch (r->m_state) \ + { \ + case 0: +#define TINFL_CR_RETURN(state_index, result) \ + do \ + { \ + status = result; \ + r->m_state = state_index; \ + goto common_exit; \ + case state_index:; \ + } \ + MZ_MACRO_END +#define TINFL_CR_RETURN_FOREVER(state_index, result) \ + do \ + { \ + for (;;) \ + { \ + TINFL_CR_RETURN(state_index, result); \ + } \ + } \ + MZ_MACRO_END +#define TINFL_CR_FINISH } + +#define TINFL_GET_BYTE(state_index, c) \ + do \ + { \ + while (pIn_buf_cur >= pIn_buf_end) \ + { \ + TINFL_CR_RETURN(state_index, (decomp_flags & TINFL_FLAG_HAS_MORE_INPUT) ? TINFL_STATUS_NEEDS_MORE_INPUT : TINFL_STATUS_FAILED_CANNOT_MAKE_PROGRESS); \ + } \ + c = *pIn_buf_cur++; \ + } \ + MZ_MACRO_END + +#define TINFL_NEED_BITS(state_index, n) \ + do \ + { \ + mz_uint c; \ + TINFL_GET_BYTE(state_index, c); \ + bit_buf |= (((tinfl_bit_buf_t)c) << num_bits); \ + num_bits += 8; \ + } while (num_bits < (mz_uint)(n)) +#define TINFL_SKIP_BITS(state_index, n) \ + do \ + { \ + if (num_bits < (mz_uint)(n)) \ + { \ + TINFL_NEED_BITS(state_index, n); \ + } \ + bit_buf >>= (n); \ + num_bits -= (n); \ + } \ + MZ_MACRO_END +#define TINFL_GET_BITS(state_index, b, n) \ + do \ + { \ + if (num_bits < (mz_uint)(n)) \ + { \ + TINFL_NEED_BITS(state_index, n); \ + } \ + b = bit_buf & ((1 << (n)) - 1); \ + bit_buf >>= (n); \ + num_bits -= (n); \ + } \ + MZ_MACRO_END + +/* TINFL_HUFF_BITBUF_FILL() is only used rarely, when the number of bytes remaining in the input buffer falls below 2. */ +/* It reads just enough bytes from the input stream that are needed to decode the next Huffman code (and absolutely no more). It works by trying to fully decode a */ +/* Huffman code by using whatever bits are currently present in the bit buffer. If this fails, it reads another byte, and tries again until it succeeds or until the */ +/* bit buffer contains >=15 bits (deflate's max. Huffman code size). */ +#define TINFL_HUFF_BITBUF_FILL(state_index, pHuff) \ + do \ + { \ + temp = (pHuff)->m_look_up[bit_buf & (TINFL_FAST_LOOKUP_SIZE - 1)]; \ + if (temp >= 0) \ + { \ + code_len = temp >> 9; \ + if ((code_len) && (num_bits >= code_len)) \ + break; \ + } \ + else if (num_bits > TINFL_FAST_LOOKUP_BITS) \ + { \ + code_len = TINFL_FAST_LOOKUP_BITS; \ + do \ + { \ + temp = (pHuff)->m_tree[~temp + ((bit_buf >> code_len++) & 1)]; \ + } while ((temp < 0) && (num_bits >= (code_len + 1))); \ + if (temp >= 0) \ + break; \ + } \ + TINFL_GET_BYTE(state_index, c); \ + bit_buf |= (((tinfl_bit_buf_t)c) << num_bits); \ + num_bits += 8; \ + } while (num_bits < 15); + +/* TINFL_HUFF_DECODE() decodes the next Huffman coded symbol. It's more complex than you would initially expect because the zlib API expects the decompressor to never read */ +/* beyond the final byte of the deflate stream. (In other words, when this macro wants to read another byte from the input, it REALLY needs another byte in order to fully */ +/* decode the next Huffman code.) Handling this properly is particularly important on raw deflate (non-zlib) streams, which aren't followed by a byte aligned adler-32. */ +/* The slow path is only executed at the very end of the input buffer. */ +/* v1.16: The original macro handled the case at the very end of the passed-in input buffer, but we also need to handle the case where the user passes in 1+zillion bytes */ +/* following the deflate data and our non-conservative read-ahead path won't kick in here on this code. This is much trickier. */ +#define TINFL_HUFF_DECODE(state_index, sym, pHuff) \ + do \ + { \ + int temp; \ + mz_uint code_len, c; \ + if (num_bits < 15) \ + { \ + if ((pIn_buf_end - pIn_buf_cur) < 2) \ + { \ + TINFL_HUFF_BITBUF_FILL(state_index, pHuff); \ + } \ + else \ + { \ + bit_buf |= (((tinfl_bit_buf_t)pIn_buf_cur[0]) << num_bits) | (((tinfl_bit_buf_t)pIn_buf_cur[1]) << (num_bits + 8)); \ + pIn_buf_cur += 2; \ + num_bits += 16; \ + } \ + } \ + if ((temp = (pHuff)->m_look_up[bit_buf & (TINFL_FAST_LOOKUP_SIZE - 1)]) >= 0) \ + code_len = temp >> 9, temp &= 511; \ + else \ + { \ + code_len = TINFL_FAST_LOOKUP_BITS; \ + do \ + { \ + temp = (pHuff)->m_tree[~temp + ((bit_buf >> code_len++) & 1)]; \ + } while (temp < 0); \ + } \ + sym = temp; \ + bit_buf >>= code_len; \ + num_bits -= code_len; \ + } \ + MZ_MACRO_END + +tinfl_status tinfl_decompress(tinfl_decompressor *r, const mz_uint8 *pIn_buf_next, size_t *pIn_buf_size, mz_uint8 *pOut_buf_start, mz_uint8 *pOut_buf_next, size_t *pOut_buf_size, const mz_uint32 decomp_flags) +{ + static const int s_length_base[31] = { 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0 }; + static const int s_length_extra[31] = { 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 0, 0 }; + static const int s_dist_base[32] = { 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577, 0, 0 }; + static const int s_dist_extra[32] = { 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13 }; + static const mz_uint8 s_length_dezigzag[19] = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 }; + static const int s_min_table_sizes[3] = { 257, 1, 4 }; + + tinfl_status status = TINFL_STATUS_FAILED; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 num_bits, dist, counter, num_extra; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tinfl_bit_buf_t bit_buf; + const mz_uint8 *pIn_buf_cur = pIn_buf_next, *const pIn_buf_end = pIn_buf_next + *pIn_buf_size; + mz_uint8 *pOut_buf_cur = pOut_buf_next, *const pOut_buf_end = pOut_buf_next + *pOut_buf_size; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t out_buf_size_mask = (decomp_flags & TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF) ? (size_t)-1 : ((pOut_buf_next - pOut_buf_start) + *pOut_buf_size) - 1, dist_from_out_buf_start; + + /* Ensure the output buffer's size is a power of 2, unless the output buffer is large enough to hold the entire output file (in which case it doesn't matter). */ + if (((out_buf_size_mask + 1) & out_buf_size_mask) || (pOut_buf_next < pOut_buf_start)) + { + *pIn_buf_size = *pOut_buf_size = 0; + return TINFL_STATUS_BAD_PARAM; + } + + num_bits = r->m_num_bits; + bit_buf = r->m_bit_buf; + dist = r->m_dist; + counter = r->m_counter; + num_extra = r->m_num_extra; + dist_from_out_buf_start = r->m_dist_from_out_buf_start; + TINFL_CR_BEGIN + + bit_buf = num_bits = dist = counter = num_extra = r->m_zhdr0 = r->m_zhdr1 = 0; + r->m_z_adler32 = r->m_check_adler32 = 1; + if (decomp_flags & TINFL_FLAG_PARSE_ZLIB_HEADER) + { + TINFL_GET_BYTE(1, r->m_zhdr0); + TINFL_GET_BYTE(2, r->m_zhdr1); + counter = (((r->m_zhdr0 * 256 + r->m_zhdr1) % 31 != 0) || (r->m_zhdr1 & 32) || ((r->m_zhdr0 & 15) != 8)); + if (!(decomp_flags & TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF)) + // NOLINTNEXTLINE(bugprone-misplaced-widening-cast,cppcoreguidelines-avoid-magic-numbers) + counter |= (((1U << (8U + (r->m_zhdr0 >> 4))) > 32768U) || ((out_buf_size_mask + 1) < (size_t)(1U << (8U + (r->m_zhdr0 >> 4))))); + if (counter) + { + TINFL_CR_RETURN_FOREVER(36, TINFL_STATUS_FAILED); + } + } + + do + { + TINFL_GET_BITS(3, r->m_final, 3); + r->m_type = r->m_final >> 1; + if (r->m_type == 0) + { + TINFL_SKIP_BITS(5, num_bits & 7); + for (counter = 0; counter < 4; ++counter) + { + if (num_bits) + TINFL_GET_BITS(6, r->m_raw_header[counter], 8); + else + TINFL_GET_BYTE(7, r->m_raw_header[counter]); + } + if ((counter = (r->m_raw_header[0] | (r->m_raw_header[1] << 8))) != (mz_uint)(0xFFFF ^ (r->m_raw_header[2] | (r->m_raw_header[3] << 8)))) + { + TINFL_CR_RETURN_FOREVER(39, TINFL_STATUS_FAILED); + } + while ((counter) && (num_bits)) + { + TINFL_GET_BITS(51, dist, 8); + while (pOut_buf_cur >= pOut_buf_end) + { + TINFL_CR_RETURN(52, TINFL_STATUS_HAS_MORE_OUTPUT); + } + *pOut_buf_cur++ = (mz_uint8)dist; + counter--; + } + while (counter) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t n; + while (pOut_buf_cur >= pOut_buf_end) + { + TINFL_CR_RETURN(9, TINFL_STATUS_HAS_MORE_OUTPUT); + } + while (pIn_buf_cur >= pIn_buf_end) + { + TINFL_CR_RETURN(38, (decomp_flags & TINFL_FLAG_HAS_MORE_INPUT) ? TINFL_STATUS_NEEDS_MORE_INPUT : TINFL_STATUS_FAILED_CANNOT_MAKE_PROGRESS); + } + n = MZ_MIN(MZ_MIN((size_t)(pOut_buf_end - pOut_buf_cur), (size_t)(pIn_buf_end - pIn_buf_cur)), counter); + TINFL_MEMCPY(pOut_buf_cur, pIn_buf_cur, n); + pIn_buf_cur += n; + pOut_buf_cur += n; + counter -= (mz_uint)n; + } + } + else if (r->m_type == 3) + { + TINFL_CR_RETURN_FOREVER(10, TINFL_STATUS_FAILED); + } + else + { + if (r->m_type == 1) + { + mz_uint8 *p = r->m_tables[0].m_code_size; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i; + r->m_table_sizes[0] = 288; + r->m_table_sizes[1] = 32; + TINFL_MEMSET(r->m_tables[1].m_code_size, 5, 32); + for (i = 0; i <= 143; ++i) + *p++ = 8; + for (; i <= 255; ++i) + *p++ = 9; + for (; i <= 279; ++i) + *p++ = 7; + for (; i <= 287; ++i) + *p++ = 8; + } + else + { + for (counter = 0; counter < 3; counter++) + { + TINFL_GET_BITS(11, r->m_table_sizes[counter], "\05\05\04"[counter]); + r->m_table_sizes[counter] += s_min_table_sizes[counter]; + } + MZ_CLEAR_OBJ(r->m_tables[2].m_code_size); + for (counter = 0; counter < r->m_table_sizes[2]; counter++) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint s; + TINFL_GET_BITS(14, s, 3); + r->m_tables[2].m_code_size[s_length_dezigzag[counter]] = (mz_uint8)s; + } + r->m_table_sizes[2] = 19; + } + for (; (int)r->m_type >= 0; r->m_type--) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int tree_next, tree_cur; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tinfl_huff_table *pTable; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-init-variables) + mz_uint i, j, used_syms, total, sym_index, next_code[17], total_syms[16]; + pTable = &r->m_tables[r->m_type]; + MZ_CLEAR_OBJ(total_syms); + MZ_CLEAR_OBJ(pTable->m_look_up); + MZ_CLEAR_OBJ(pTable->m_tree); + for (i = 0; i < r->m_table_sizes[r->m_type]; ++i) + total_syms[pTable->m_code_size[i]]++; + used_syms = 0, total = 0; + next_code[0] = next_code[1] = 0; + for (i = 1; i <= 15; ++i) + { + used_syms += total_syms[i]; + next_code[i + 1] = (total = ((total + total_syms[i]) << 1)); + } + if ((65536 != total) && (used_syms > 1)) + { + TINFL_CR_RETURN_FOREVER(35, TINFL_STATUS_FAILED); + } + for (tree_next = -1, sym_index = 0; sym_index < r->m_table_sizes[r->m_type]; ++sym_index) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint rev_code = 0, l, cur_code, code_size = pTable->m_code_size[sym_index]; + if (!code_size) + continue; + cur_code = next_code[code_size]++; + for (l = code_size; l > 0; l--, cur_code >>= 1) + rev_code = (rev_code << 1) | (cur_code & 1); + if (code_size <= TINFL_FAST_LOOKUP_BITS) + { + mz_int16 k = (mz_int16)((code_size << 9) | sym_index); + while (rev_code < TINFL_FAST_LOOKUP_SIZE) + { + pTable->m_look_up[rev_code] = k; + rev_code += (1 << code_size); + } + continue; + } + if (0 == (tree_cur = pTable->m_look_up[rev_code & (TINFL_FAST_LOOKUP_SIZE - 1)])) + { + pTable->m_look_up[rev_code & (TINFL_FAST_LOOKUP_SIZE - 1)] = (mz_int16)tree_next; + tree_cur = tree_next; + tree_next -= 2; + } + rev_code >>= (TINFL_FAST_LOOKUP_BITS - 1); + for (j = code_size; j > (TINFL_FAST_LOOKUP_BITS + 1); j--) + { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + tree_cur -= ((rev_code >>= 1) & 1); + if (!pTable->m_tree[-tree_cur - 1]) + { + pTable->m_tree[-tree_cur - 1] = (mz_int16)tree_next; + tree_cur = tree_next; + tree_next -= 2; + } + else + tree_cur = pTable->m_tree[-tree_cur - 1]; + } + // NOLINTNEXTLINE(bugprone-narrowing-conversions,clang-analyzer-deadcode.DeadStores,cppcoreguidelines-narrowing-conversions) + tree_cur -= ((rev_code >>= 1) & 1); + pTable->m_tree[-tree_cur - 1] = (mz_int16)sym_index; + } + if (r->m_type == 2) + { + for (counter = 0; counter < (r->m_table_sizes[0] + r->m_table_sizes[1]);) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint s; + TINFL_HUFF_DECODE(16, dist, &r->m_tables[2]); + if (dist < 16) + { + r->m_len_codes[counter++] = (mz_uint8)dist; + continue; + } + if ((dist == 16) && (!counter)) + { + TINFL_CR_RETURN_FOREVER(17, TINFL_STATUS_FAILED); + } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-signed-char-misuse) + num_extra = "\02\03\07"[dist - 16]; + TINFL_GET_BITS(18, s, num_extra); + s += "\03\03\013"[dist - 16]; + TINFL_MEMSET(r->m_len_codes + counter, (dist == 16) ? r->m_len_codes[counter - 1] : 0, s); + counter += s; + } + if ((r->m_table_sizes[0] + r->m_table_sizes[1]) != counter) + { + TINFL_CR_RETURN_FOREVER(21, TINFL_STATUS_FAILED); + } + TINFL_MEMCPY(r->m_tables[0].m_code_size, r->m_len_codes, r->m_table_sizes[0]); + TINFL_MEMCPY(r->m_tables[1].m_code_size, r->m_len_codes + r->m_table_sizes[0], r->m_table_sizes[1]); + } + } + for (;;) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint8 *pSrc; + for (;;) + { + if (((pIn_buf_end - pIn_buf_cur) < 4) || ((pOut_buf_end - pOut_buf_cur) < 2)) + { + TINFL_HUFF_DECODE(23, counter, &r->m_tables[0]); + if (counter >= 256) + break; + while (pOut_buf_cur >= pOut_buf_end) + { + TINFL_CR_RETURN(24, TINFL_STATUS_HAS_MORE_OUTPUT); + } + *pOut_buf_cur++ = (mz_uint8)counter; + } + else + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int sym2; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint code_len; +#if TINFL_USE_64BIT_BITBUF + if (num_bits < 30) + { + bit_buf |= (((tinfl_bit_buf_t)MZ_READ_LE32(pIn_buf_cur)) << num_bits); + pIn_buf_cur += 4; + num_bits += 32; + } +#else + if (num_bits < 15) + { + bit_buf |= (((tinfl_bit_buf_t)MZ_READ_LE16(pIn_buf_cur)) << num_bits); + pIn_buf_cur += 2; + num_bits += 16; + } +#endif + if ((sym2 = r->m_tables[0].m_look_up[bit_buf & (TINFL_FAST_LOOKUP_SIZE - 1)]) >= 0) + code_len = sym2 >> 9; + else + { + code_len = TINFL_FAST_LOOKUP_BITS; + do + { + sym2 = r->m_tables[0].m_tree[~sym2 + ((bit_buf >> code_len++) & 1)]; + } while (sym2 < 0); + } + counter = sym2; + bit_buf >>= code_len; + num_bits -= code_len; + if (counter & 256) + break; + +#if !TINFL_USE_64BIT_BITBUF + if (num_bits < 15) + { + bit_buf |= (((tinfl_bit_buf_t)MZ_READ_LE16(pIn_buf_cur)) << num_bits); + pIn_buf_cur += 2; + num_bits += 16; + } +#endif + if ((sym2 = r->m_tables[0].m_look_up[bit_buf & (TINFL_FAST_LOOKUP_SIZE - 1)]) >= 0) + code_len = sym2 >> 9; + else + { + code_len = TINFL_FAST_LOOKUP_BITS; + do + { + sym2 = r->m_tables[0].m_tree[~sym2 + ((bit_buf >> code_len++) & 1)]; + } while (sym2 < 0); + } + bit_buf >>= code_len; + num_bits -= code_len; + + pOut_buf_cur[0] = (mz_uint8)counter; + if (sym2 & 256) + { + pOut_buf_cur++; + counter = sym2; + break; + } + pOut_buf_cur[1] = (mz_uint8)sym2; + pOut_buf_cur += 2; + } + } + if ((counter &= 511) == 256) + break; + + num_extra = s_length_extra[counter - 257]; + counter = s_length_base[counter - 257]; + if (num_extra) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint extra_bits; + TINFL_GET_BITS(25, extra_bits, num_extra); + counter += extra_bits; + } + + TINFL_HUFF_DECODE(26, dist, &r->m_tables[1]); + num_extra = s_dist_extra[dist]; + dist = s_dist_base[dist]; + if (num_extra) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint extra_bits; + TINFL_GET_BITS(27, extra_bits, num_extra); + dist += extra_bits; + } + + dist_from_out_buf_start = pOut_buf_cur - pOut_buf_start; + if ((dist > dist_from_out_buf_start) && (decomp_flags & TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF)) + { + TINFL_CR_RETURN_FOREVER(37, TINFL_STATUS_FAILED); + } + + pSrc = pOut_buf_start + ((dist_from_out_buf_start - dist) & out_buf_size_mask); + + if ((MZ_MAX(pOut_buf_cur, pSrc) + counter) > pOut_buf_end) + { + while (counter--) + { + while (pOut_buf_cur >= pOut_buf_end) + { + TINFL_CR_RETURN(53, TINFL_STATUS_HAS_MORE_OUTPUT); + } + *pOut_buf_cur++ = pOut_buf_start[(dist_from_out_buf_start++ - dist) & out_buf_size_mask]; + } + continue; + } +#if MINIZ_USE_UNALIGNED_LOADS_AND_STORES + else if ((counter >= 9) && (counter <= dist)) + { + const mz_uint8 *pSrc_end = pSrc + (counter & ~7); + do + { +#ifdef MINIZ_UNALIGNED_USE_MEMCPY + memcpy(pOut_buf_cur, pSrc, sizeof(mz_uint32)*2); +#else + ((mz_uint32 *)pOut_buf_cur)[0] = ((const mz_uint32 *)pSrc)[0]; + ((mz_uint32 *)pOut_buf_cur)[1] = ((const mz_uint32 *)pSrc)[1]; +#endif + pOut_buf_cur += 8; + } while ((pSrc += 8) < pSrc_end); + if ((counter &= 7) < 3) + { + if (counter) + { + pOut_buf_cur[0] = pSrc[0]; + if (counter > 1) + pOut_buf_cur[1] = pSrc[1]; + pOut_buf_cur += counter; + } + continue; + } + } +#endif + while(counter>2) + { + pOut_buf_cur[0] = pSrc[0]; + pOut_buf_cur[1] = pSrc[1]; + pOut_buf_cur[2] = pSrc[2]; + pOut_buf_cur += 3; + pSrc += 3; + counter -= 3; + } + if (counter > 0) + { + pOut_buf_cur[0] = pSrc[0]; + if (counter > 1) + pOut_buf_cur[1] = pSrc[1]; + pOut_buf_cur += counter; + } + } + } + } while (!(r->m_final & 1)); + + /* Ensure byte alignment and put back any bytes from the bitbuf if we've looked ahead too far on gzip, or other Deflate streams followed by arbitrary data. */ + /* I'm being super conservative here. A number of simplifications can be made to the byte alignment part, and the Adler32 check shouldn't ever need to worry about reading from the bitbuf now. */ + TINFL_SKIP_BITS(32, num_bits & 7); + while ((pIn_buf_cur > pIn_buf_next) && (num_bits >= 8)) + { + --pIn_buf_cur; + num_bits -= 8; + } + bit_buf &= (tinfl_bit_buf_t)((((mz_uint64)1) << num_bits) - (mz_uint64)1); + MZ_ASSERT(!num_bits); /* if this assert fires then we've read beyond the end of non-deflate/zlib streams with following data (such as gzip streams). */ + + if (decomp_flags & TINFL_FLAG_PARSE_ZLIB_HEADER) + { + for (counter = 0; counter < 4; ++counter) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint s; + if (num_bits) + TINFL_GET_BITS(41, s, 8); + else + TINFL_GET_BYTE(42, s); + r->m_z_adler32 = (r->m_z_adler32 << 8) | s; + } + } + TINFL_CR_RETURN_FOREVER(34, TINFL_STATUS_DONE); + + TINFL_CR_FINISH + +common_exit: + /* As long as we aren't telling the caller that we NEED more input to make forward progress: */ + /* Put back any bytes from the bitbuf in case we've looked ahead too far on gzip, or other Deflate streams followed by arbitrary data. */ + /* We need to be very careful here to NOT push back any bytes we definitely know we need to make forward progress, though, or we'll lock the caller up into an inf loop. */ + if ((status != TINFL_STATUS_NEEDS_MORE_INPUT) && (status != TINFL_STATUS_FAILED_CANNOT_MAKE_PROGRESS)) + { + while ((pIn_buf_cur > pIn_buf_next) && (num_bits >= 8)) + { + --pIn_buf_cur; + num_bits -= 8; + } + } + r->m_num_bits = num_bits; + r->m_bit_buf = bit_buf & (tinfl_bit_buf_t)((((mz_uint64)1) << num_bits) - (mz_uint64)1); + r->m_dist = dist; + r->m_counter = counter; + r->m_num_extra = num_extra; + r->m_dist_from_out_buf_start = dist_from_out_buf_start; + *pIn_buf_size = pIn_buf_cur - pIn_buf_next; + *pOut_buf_size = pOut_buf_cur - pOut_buf_next; + if ((decomp_flags & (TINFL_FLAG_PARSE_ZLIB_HEADER | TINFL_FLAG_COMPUTE_ADLER32)) && (status >= 0)) + { + const mz_uint8 *ptr = pOut_buf_next; + size_t buf_len = *pOut_buf_size; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-init-variables) + mz_uint32 i, s1 = r->m_check_adler32 & 0xffff, s2 = r->m_check_adler32 >> 16; + size_t block_len = buf_len % 5552; + while (buf_len) + { + for (i = 0; i + 7 < block_len; i += 8, ptr += 8) + { + s1 += ptr[0], s2 += s1; + s1 += ptr[1], s2 += s1; + s1 += ptr[2], s2 += s1; + s1 += ptr[3], s2 += s1; + s1 += ptr[4], s2 += s1; + s1 += ptr[5], s2 += s1; + s1 += ptr[6], s2 += s1; + s1 += ptr[7], s2 += s1; + } + for (; i < block_len; ++i) + s1 += *ptr++, s2 += s1; + s1 %= 65521U, s2 %= 65521U; + buf_len -= block_len; + block_len = 5552; + } + r->m_check_adler32 = (s2 << 16) + s1; + if ((status == TINFL_STATUS_DONE) && (decomp_flags & TINFL_FLAG_PARSE_ZLIB_HEADER) && (r->m_check_adler32 != r->m_z_adler32)) + status = TINFL_STATUS_ADLER32_MISMATCH; + } + return status; +} + +/* Higher level helper functions. */ +void *tinfl_decompress_mem_to_heap(const void *pSrc_buf, size_t src_buf_len, size_t *pOut_len, int flags) +{ + tinfl_decompressor decomp; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + void *pBuf = NULL, *pNew_buf; + size_t src_buf_ofs = 0, out_buf_capacity = 0; + *pOut_len = 0; + tinfl_init(&decomp); + for (;;) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t src_buf_size = src_buf_len - src_buf_ofs, dst_buf_size = out_buf_capacity - *pOut_len, new_out_buf_capacity; + tinfl_status status = tinfl_decompress(&decomp, (const mz_uint8 *)pSrc_buf + src_buf_ofs, &src_buf_size, (mz_uint8 *)pBuf, pBuf ? (mz_uint8 *)pBuf + *pOut_len : NULL, &dst_buf_size, + (flags & ~TINFL_FLAG_HAS_MORE_INPUT) | TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF); + if ((status < 0) || (status == TINFL_STATUS_NEEDS_MORE_INPUT)) + { + MZ_FREE(pBuf); + *pOut_len = 0; + return NULL; + } + src_buf_ofs += src_buf_size; + *pOut_len += dst_buf_size; + if (status == TINFL_STATUS_DONE) + break; + new_out_buf_capacity = out_buf_capacity * 2; + if (new_out_buf_capacity < 128) + new_out_buf_capacity = 128; + pNew_buf = MZ_REALLOC(pBuf, new_out_buf_capacity); + if (!pNew_buf) + { + MZ_FREE(pBuf); + *pOut_len = 0; + return NULL; + } + pBuf = pNew_buf; + out_buf_capacity = new_out_buf_capacity; + } + return pBuf; +} + +size_t tinfl_decompress_mem_to_mem(void *pOut_buf, size_t out_buf_len, const void *pSrc_buf, size_t src_buf_len, int flags) +{ + tinfl_decompressor decomp; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + tinfl_status status; + tinfl_init(&decomp); + status = tinfl_decompress(&decomp, (const mz_uint8 *)pSrc_buf, &src_buf_len, (mz_uint8 *)pOut_buf, (mz_uint8 *)pOut_buf, &out_buf_len, (flags & ~TINFL_FLAG_HAS_MORE_INPUT) | TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF); + return (status != TINFL_STATUS_DONE) ? TINFL_DECOMPRESS_MEM_TO_MEM_FAILED : out_buf_len; +} + +int tinfl_decompress_mem_to_callback(const void *pIn_buf, size_t *pIn_buf_size, tinfl_put_buf_func_ptr pPut_buf_func, void *pPut_buf_user, int flags) +{ + int result = 0; + tinfl_decompressor decomp; + mz_uint8 *pDict = (mz_uint8 *)MZ_MALLOC(TINFL_LZ_DICT_SIZE); + size_t in_buf_ofs = 0, dict_ofs = 0; + if (!pDict) + return TINFL_STATUS_FAILED; + tinfl_init(&decomp); + for (;;) + { + size_t in_buf_size = *pIn_buf_size - in_buf_ofs, dst_buf_size = TINFL_LZ_DICT_SIZE - dict_ofs; + tinfl_status status = tinfl_decompress(&decomp, (const mz_uint8 *)pIn_buf + in_buf_ofs, &in_buf_size, pDict, pDict + dict_ofs, &dst_buf_size, + (flags & ~(TINFL_FLAG_HAS_MORE_INPUT | TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF))); + in_buf_ofs += in_buf_size; + if ((dst_buf_size) && (!(*pPut_buf_func)(pDict + dict_ofs, (int)dst_buf_size, pPut_buf_user))) + break; + if (status != TINFL_STATUS_HAS_MORE_OUTPUT) + { + result = (status == TINFL_STATUS_DONE); + break; + } + dict_ofs = (dict_ofs + dst_buf_size) & (TINFL_LZ_DICT_SIZE - 1); + } + MZ_FREE(pDict); + *pIn_buf_size = in_buf_ofs; + return result; +} + +#ifndef MINIZ_NO_MALLOC +tinfl_decompressor *tinfl_decompressor_alloc() +{ + tinfl_decompressor *pDecomp = (tinfl_decompressor *)MZ_MALLOC(sizeof(tinfl_decompressor)); + if (pDecomp) + tinfl_init(pDecomp); + return pDecomp; +} + +void tinfl_decompressor_free(tinfl_decompressor *pDecomp) +{ + MZ_FREE(pDecomp); +} +#endif + +/************************************************************************** + * + * Copyright 2013-2014 RAD Game Tools and Valve Software + * Copyright 2010-2014 Rich Geldreich and Tenacious Software LLC + * Copyright 2016 Martin Raiber + * All Rights Reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + **************************************************************************/ + + +#ifndef MINIZ_NO_ARCHIVE_APIS + +/* ------------------- .ZIP archive reading */ + +#ifdef MINIZ_NO_STDIO +#define MZ_FILE void * +#else + +#if defined(_MSC_VER) || defined(__MINGW64__) +static FILE *mz_fopen(const char *pFilename, const char *pMode) +{ + FILE *pFile = NULL; + fopen_s(&pFile, pFilename, pMode); + return pFile; +} +static FILE *mz_freopen(const char *pPath, const char *pMode, FILE *pStream) +{ + FILE *pFile = NULL; + if (freopen_s(&pFile, pPath, pMode, pStream)) + return NULL; + return pFile; +} +#ifndef MINIZ_NO_TIME +#include +#endif +#define MZ_FOPEN mz_fopen +#define MZ_FCLOSE fclose +#define MZ_FREAD fread +#define MZ_FWRITE fwrite +#define MZ_FTELL64 _ftelli64 +#define MZ_FSEEK64 _fseeki64 +#define MZ_FILE_STAT_STRUCT _stat64 +#define MZ_FILE_STAT _stat64 +#define MZ_FFLUSH fflush +#define MZ_FREOPEN mz_freopen +#define MZ_DELETE_FILE remove +#elif defined(__MINGW32__) +#ifndef MINIZ_NO_TIME +#include +#endif +#define MZ_FOPEN(f, m) fopen(f, m) +#define MZ_FCLOSE fclose +#define MZ_FREAD fread +#define MZ_FWRITE fwrite +#define MZ_FTELL64 ftello64 +#define MZ_FSEEK64 fseeko64 +#define MZ_FILE_STAT_STRUCT _stat +#define MZ_FILE_STAT _stat +#define MZ_FFLUSH fflush +#define MZ_FREOPEN(f, m, s) freopen(f, m, s) +#define MZ_DELETE_FILE remove +#elif defined(__TINYC__) +#ifndef MINIZ_NO_TIME +#include +#endif +#define MZ_FOPEN(f, m) fopen(f, m) +#define MZ_FCLOSE fclose +#define MZ_FREAD fread +#define MZ_FWRITE fwrite +#define MZ_FTELL64 ftell +#define MZ_FSEEK64 fseek +#define MZ_FILE_STAT_STRUCT stat +#define MZ_FILE_STAT stat +#define MZ_FFLUSH fflush +#define MZ_FREOPEN(f, m, s) freopen(f, m, s) +#define MZ_DELETE_FILE remove +#elif defined(__GNUC__) && defined(_LARGEFILE64_SOURCE) +#ifndef MINIZ_NO_TIME +#include +#endif +#define MZ_FOPEN(f, m) fopen64(f, m) +#define MZ_FCLOSE fclose +#define MZ_FREAD fread +#define MZ_FWRITE fwrite +#define MZ_FTELL64 ftello64 +#define MZ_FSEEK64 fseeko64 +#define MZ_FILE_STAT_STRUCT stat64 +#define MZ_FILE_STAT stat64 +#define MZ_FFLUSH fflush +#define MZ_FREOPEN(p, m, s) freopen64(p, m, s) +#define MZ_DELETE_FILE remove +#elif defined(__APPLE__) +#ifndef MINIZ_NO_TIME +#include +#endif +#define MZ_FOPEN(f, m) fopen(f, m) +#define MZ_FCLOSE fclose +#define MZ_FREAD fread +#define MZ_FWRITE fwrite +#define MZ_FTELL64 ftello +#define MZ_FSEEK64 fseeko +#define MZ_FILE_STAT_STRUCT stat +#define MZ_FILE_STAT stat +#define MZ_FFLUSH fflush +#define MZ_FREOPEN(p, m, s) freopen(p, m, s) +#define MZ_DELETE_FILE remove + +#else +#pragma message("Using fopen, ftello, fseeko, stat() etc. path for file I/O - this path may not support large files.") +#ifndef MINIZ_NO_TIME +#include +#endif +#define MZ_FOPEN(f, m) fopen(f, m) +#define MZ_FCLOSE fclose +#define MZ_FREAD fread +#define MZ_FWRITE fwrite +#ifdef __STRICT_ANSI__ +#define MZ_FTELL64 ftell +#define MZ_FSEEK64 fseek +#else +#define MZ_FTELL64 ftello +#define MZ_FSEEK64 fseeko +#endif +#define MZ_FILE_STAT_STRUCT stat +#define MZ_FILE_STAT stat +#define MZ_FFLUSH fflush +#define MZ_FREOPEN(f, m, s) freopen(f, m, s) +#define MZ_DELETE_FILE remove +#endif /* #ifdef _MSC_VER */ +#endif /* #ifdef MINIZ_NO_STDIO */ + +#define MZ_TOLOWER(c) ((((c) >= 'A') && ((c) <= 'Z')) ? ((c) - 'A' + 'a') : (c)) + +/* Various ZIP archive enums. To completely avoid cross platform compiler alignment and platform endian issues, miniz.c doesn't use structs for any of this stuff. */ +enum +{ + /* ZIP archive identifiers and record sizes */ + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIG = 0x06054b50, + MZ_ZIP_CENTRAL_DIR_HEADER_SIG = 0x02014b50, + MZ_ZIP_LOCAL_DIR_HEADER_SIG = 0x04034b50, + MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30, + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE = 46, + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE = 22, + + /* ZIP64 archive identifier and record sizes */ + MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIG = 0x06064b50, + MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIG = 0x07064b50, + MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE = 56, + MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE = 20, + MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID = 0x0001, + MZ_ZIP_DATA_DESCRIPTOR_ID = 0x08074b50, + MZ_ZIP_DATA_DESCRIPTER_SIZE64 = 24, + MZ_ZIP_DATA_DESCRIPTER_SIZE32 = 16, + + /* Central directory header record offsets */ + MZ_ZIP_CDH_SIG_OFS = 0, + MZ_ZIP_CDH_VERSION_MADE_BY_OFS = 4, + MZ_ZIP_CDH_VERSION_NEEDED_OFS = 6, + MZ_ZIP_CDH_BIT_FLAG_OFS = 8, + MZ_ZIP_CDH_METHOD_OFS = 10, + MZ_ZIP_CDH_FILE_TIME_OFS = 12, + MZ_ZIP_CDH_FILE_DATE_OFS = 14, + MZ_ZIP_CDH_CRC32_OFS = 16, + MZ_ZIP_CDH_COMPRESSED_SIZE_OFS = 20, + MZ_ZIP_CDH_DECOMPRESSED_SIZE_OFS = 24, + MZ_ZIP_CDH_FILENAME_LEN_OFS = 28, + MZ_ZIP_CDH_EXTRA_LEN_OFS = 30, + MZ_ZIP_CDH_COMMENT_LEN_OFS = 32, + MZ_ZIP_CDH_DISK_START_OFS = 34, + MZ_ZIP_CDH_INTERNAL_ATTR_OFS = 36, + MZ_ZIP_CDH_EXTERNAL_ATTR_OFS = 38, + MZ_ZIP_CDH_LOCAL_HEADER_OFS = 42, + + /* Local directory header offsets */ + MZ_ZIP_LDH_SIG_OFS = 0, + MZ_ZIP_LDH_VERSION_NEEDED_OFS = 4, + MZ_ZIP_LDH_BIT_FLAG_OFS = 6, + MZ_ZIP_LDH_METHOD_OFS = 8, + MZ_ZIP_LDH_FILE_TIME_OFS = 10, + MZ_ZIP_LDH_FILE_DATE_OFS = 12, + MZ_ZIP_LDH_CRC32_OFS = 14, + MZ_ZIP_LDH_COMPRESSED_SIZE_OFS = 18, + MZ_ZIP_LDH_DECOMPRESSED_SIZE_OFS = 22, + MZ_ZIP_LDH_FILENAME_LEN_OFS = 26, + MZ_ZIP_LDH_EXTRA_LEN_OFS = 28, + MZ_ZIP_LDH_BIT_FLAG_HAS_LOCATOR = 1 << 3, + + /* End of central directory offsets */ + MZ_ZIP_ECDH_SIG_OFS = 0, + MZ_ZIP_ECDH_NUM_THIS_DISK_OFS = 4, + MZ_ZIP_ECDH_NUM_DISK_CDIR_OFS = 6, + MZ_ZIP_ECDH_CDIR_NUM_ENTRIES_ON_DISK_OFS = 8, + MZ_ZIP_ECDH_CDIR_TOTAL_ENTRIES_OFS = 10, + MZ_ZIP_ECDH_CDIR_SIZE_OFS = 12, + MZ_ZIP_ECDH_CDIR_OFS_OFS = 16, + MZ_ZIP_ECDH_COMMENT_SIZE_OFS = 20, + + /* ZIP64 End of central directory locator offsets */ + MZ_ZIP64_ECDL_SIG_OFS = 0, /* 4 bytes */ + MZ_ZIP64_ECDL_NUM_DISK_CDIR_OFS = 4, /* 4 bytes */ + MZ_ZIP64_ECDL_REL_OFS_TO_ZIP64_ECDR_OFS = 8, /* 8 bytes */ + MZ_ZIP64_ECDL_TOTAL_NUMBER_OF_DISKS_OFS = 16, /* 4 bytes */ + + /* ZIP64 End of central directory header offsets */ + MZ_ZIP64_ECDH_SIG_OFS = 0, /* 4 bytes */ + MZ_ZIP64_ECDH_SIZE_OF_RECORD_OFS = 4, /* 8 bytes */ + MZ_ZIP64_ECDH_VERSION_MADE_BY_OFS = 12, /* 2 bytes */ + MZ_ZIP64_ECDH_VERSION_NEEDED_OFS = 14, /* 2 bytes */ + MZ_ZIP64_ECDH_NUM_THIS_DISK_OFS = 16, /* 4 bytes */ + MZ_ZIP64_ECDH_NUM_DISK_CDIR_OFS = 20, /* 4 bytes */ + MZ_ZIP64_ECDH_CDIR_NUM_ENTRIES_ON_DISK_OFS = 24, /* 8 bytes */ + MZ_ZIP64_ECDH_CDIR_TOTAL_ENTRIES_OFS = 32, /* 8 bytes */ + MZ_ZIP64_ECDH_CDIR_SIZE_OFS = 40, /* 8 bytes */ + MZ_ZIP64_ECDH_CDIR_OFS_OFS = 48, /* 8 bytes */ + MZ_ZIP_VERSION_MADE_BY_DOS_FILESYSTEM_ID = 0, + MZ_ZIP_DOS_DIR_ATTRIBUTE_BITFLAG = 0x10, + MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_IS_ENCRYPTED = 1, + MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_COMPRESSED_PATCH_FLAG = 32, + MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_USES_STRONG_ENCRYPTION = 64, + MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_LOCAL_DIR_IS_MASKED = 8192, + MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_UTF8 = 1 << 11 +}; + +typedef struct +{ + void *m_p; + size_t m_size, m_capacity; + mz_uint m_element_size; +} mz_zip_array; + +struct mz_zip_internal_state_tag +{ + mz_zip_array m_central_dir; + mz_zip_array m_central_dir_offsets; + mz_zip_array m_sorted_central_dir_offsets; + + /* The flags passed in when the archive is initially opened. */ + uint32_t m_init_flags; + + /* MZ_TRUE if the archive has a zip64 end of central directory headers, etc. */ + mz_bool m_zip64; + + /* MZ_TRUE if we found zip64 extended info in the central directory (m_zip64 will also be slammed to true too, even if we didn't find a zip64 end of central dir header, etc.) */ + mz_bool m_zip64_has_extended_info_fields; + + /* These fields are used by the file, FILE, memory, and memory/heap read/write helpers. */ + MZ_FILE *m_pFile; + mz_uint64 m_file_archive_start_ofs; + + void *m_pMem; + size_t m_mem_size; + size_t m_mem_capacity; +}; + +#define MZ_ZIP_ARRAY_SET_ELEMENT_SIZE(array_ptr, element_size) (array_ptr)->m_element_size = element_size + +#if defined(DEBUG) || defined(_DEBUG) || defined(NDEBUG) +static MZ_FORCEINLINE mz_uint mz_zip_array_range_check(const mz_zip_array *pArray, mz_uint index) +{ + MZ_ASSERT(index < pArray->m_size); + return index; +} +#define MZ_ZIP_ARRAY_ELEMENT(array_ptr, element_type, index) ((element_type *)((array_ptr)->m_p))[mz_zip_array_range_check(array_ptr, index)] +#else +#define MZ_ZIP_ARRAY_ELEMENT(array_ptr, element_type, index) ((element_type *)((array_ptr)->m_p))[index] +#endif + +static MZ_FORCEINLINE void mz_zip_array_init(mz_zip_array *pArray, mz_uint32 element_size) +{ + memset(pArray, 0, sizeof(mz_zip_array)); + pArray->m_element_size = element_size; +} + +static MZ_FORCEINLINE void mz_zip_array_clear(mz_zip_archive *pZip, mz_zip_array *pArray) +{ + pZip->m_pFree(pZip->m_pAlloc_opaque, pArray->m_p); + memset(pArray, 0, sizeof(mz_zip_array)); +} + +static mz_bool mz_zip_array_ensure_capacity(mz_zip_archive *pZip, mz_zip_array *pArray, size_t min_new_capacity, mz_uint growing) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + void *pNew_p; + size_t new_capacity = min_new_capacity; + MZ_ASSERT(pArray->m_element_size); + if (pArray->m_capacity >= min_new_capacity) + return MZ_TRUE; + if (growing) + { + new_capacity = MZ_MAX(1, pArray->m_capacity); + while (new_capacity < min_new_capacity) + new_capacity *= 2; + } + if (NULL == (pNew_p = pZip->m_pRealloc(pZip->m_pAlloc_opaque, pArray->m_p, pArray->m_element_size, new_capacity))) + return MZ_FALSE; + pArray->m_p = pNew_p; + pArray->m_capacity = new_capacity; + return MZ_TRUE; +} + +static MZ_FORCEINLINE mz_bool mz_zip_array_reserve(mz_zip_archive *pZip, mz_zip_array *pArray, size_t new_capacity, mz_uint growing) +{ + if (new_capacity > pArray->m_capacity) + { + if (!mz_zip_array_ensure_capacity(pZip, pArray, new_capacity, growing)) + return MZ_FALSE; + } + return MZ_TRUE; +} + +static MZ_FORCEINLINE mz_bool mz_zip_array_resize(mz_zip_archive *pZip, mz_zip_array *pArray, size_t new_size, mz_uint growing) +{ + if (new_size > pArray->m_capacity) + { + if (!mz_zip_array_ensure_capacity(pZip, pArray, new_size, growing)) + return MZ_FALSE; + } + pArray->m_size = new_size; + return MZ_TRUE; +} + +static MZ_FORCEINLINE mz_bool mz_zip_array_ensure_room(mz_zip_archive *pZip, mz_zip_array *pArray, size_t n) +{ + return mz_zip_array_reserve(pZip, pArray, pArray->m_size + n, MZ_TRUE); +} + +static MZ_FORCEINLINE mz_bool mz_zip_array_push_back(mz_zip_archive *pZip, mz_zip_array *pArray, const void *pElements, size_t n) +{ + size_t orig_size = pArray->m_size; + if (!mz_zip_array_resize(pZip, pArray, orig_size + n, MZ_TRUE)) + return MZ_FALSE; + if (n > 0) + memcpy((mz_uint8 *)pArray->m_p + orig_size * pArray->m_element_size, pElements, n * pArray->m_element_size); + return MZ_TRUE; +} + +#ifndef MINIZ_NO_TIME +static MZ_TIME_T mz_zip_dos_to_time_t(int dos_time, int dos_date) +{ + struct tm tm; + memset(&tm, 0, sizeof(tm)); + tm.tm_isdst = -1; + tm.tm_year = ((dos_date >> 9) & 127) + 1980 - 1900; + tm.tm_mon = ((dos_date >> 5) & 15) - 1; + tm.tm_mday = dos_date & 31; + tm.tm_hour = (dos_time >> 11) & 31; + tm.tm_min = (dos_time >> 5) & 63; + tm.tm_sec = (dos_time << 1) & 62; + return mktime(&tm); +} + +#ifndef MINIZ_NO_ARCHIVE_WRITING_APIS +static void mz_zip_time_t_to_dos_time(MZ_TIME_T time, mz_uint16 *pDOS_time, mz_uint16 *pDOS_date) +{ +#ifdef _MSC_VER + struct tm tm_struct; + struct tm *tm = &tm_struct; + errno_t err = localtime_s(tm, &time); + if (err) + { + *pDOS_date = 0; + *pDOS_time = 0; + return; + } +#else + struct tm *tm = localtime(&time); +#endif /* #ifdef _MSC_VER */ + + *pDOS_time = (mz_uint16)(((tm->tm_hour) << 11) + ((tm->tm_min) << 5) + ((tm->tm_sec) >> 1)); + *pDOS_date = (mz_uint16)(((tm->tm_year + 1900 - 1980) << 9) + ((tm->tm_mon + 1) << 5) + tm->tm_mday); +} +#endif /* MINIZ_NO_ARCHIVE_WRITING_APIS */ + +#ifndef MINIZ_NO_STDIO +#ifndef MINIZ_NO_ARCHIVE_WRITING_APIS +static mz_bool mz_zip_get_file_modified_time(const char *pFilename, MZ_TIME_T *pTime) +{ + struct MZ_FILE_STAT_STRUCT file_stat; + + /* On Linux with x86 glibc, this call will fail on large files (I think >= 0x80000000 bytes) unless you compiled with _LARGEFILE64_SOURCE. Argh. */ + if (MZ_FILE_STAT(pFilename, &file_stat) != 0) + return MZ_FALSE; + + *pTime = file_stat.st_mtime; + + return MZ_TRUE; +} +#endif /* #ifndef MINIZ_NO_ARCHIVE_WRITING_APIS*/ + +static mz_bool mz_zip_set_file_times(const char *pFilename, MZ_TIME_T access_time, MZ_TIME_T modified_time) +{ + struct utimbuf t; + + memset(&t, 0, sizeof(t)); + t.actime = access_time; + t.modtime = modified_time; + + return !utime(pFilename, &t); +} +#endif /* #ifndef MINIZ_NO_STDIO */ +#endif /* #ifndef MINIZ_NO_TIME */ + +static MZ_FORCEINLINE mz_bool mz_zip_set_error(mz_zip_archive *pZip, mz_zip_error err_num) +{ + if (pZip) + pZip->m_last_error = err_num; + return MZ_FALSE; +} + +static mz_bool mz_zip_reader_init_internal(mz_zip_archive *pZip, mz_uint flags) +{ + (void)flags; + if ((!pZip) || (pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_INVALID)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (!pZip->m_pAlloc) + pZip->m_pAlloc = miniz_def_alloc_func; + if (!pZip->m_pFree) + pZip->m_pFree = miniz_def_free_func; + if (!pZip->m_pRealloc) + pZip->m_pRealloc = miniz_def_realloc_func; + + pZip->m_archive_size = 0; + pZip->m_central_directory_file_ofs = 0; + pZip->m_total_files = 0; + pZip->m_last_error = MZ_ZIP_NO_ERROR; + + if (NULL == (pZip->m_pState = (mz_zip_internal_state *)pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, sizeof(mz_zip_internal_state)))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + memset(pZip->m_pState, 0, sizeof(mz_zip_internal_state)); + MZ_ZIP_ARRAY_SET_ELEMENT_SIZE(&pZip->m_pState->m_central_dir, sizeof(mz_uint8)); + MZ_ZIP_ARRAY_SET_ELEMENT_SIZE(&pZip->m_pState->m_central_dir_offsets, sizeof(mz_uint32)); + MZ_ZIP_ARRAY_SET_ELEMENT_SIZE(&pZip->m_pState->m_sorted_central_dir_offsets, sizeof(mz_uint32)); + pZip->m_pState->m_init_flags = flags; + pZip->m_pState->m_zip64 = MZ_FALSE; + pZip->m_pState->m_zip64_has_extended_info_fields = MZ_FALSE; + + pZip->m_zip_mode = MZ_ZIP_MODE_READING; + + return MZ_TRUE; +} + +static MZ_FORCEINLINE mz_bool mz_zip_reader_filename_less(const mz_zip_array *pCentral_dir_array, const mz_zip_array *pCentral_dir_offsets, mz_uint l_index, mz_uint r_index) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const mz_uint8 *pL = &MZ_ZIP_ARRAY_ELEMENT(pCentral_dir_array, mz_uint8, MZ_ZIP_ARRAY_ELEMENT(pCentral_dir_offsets, mz_uint32, l_index)), *pE; + const mz_uint8 *pR = &MZ_ZIP_ARRAY_ELEMENT(pCentral_dir_array, mz_uint8, MZ_ZIP_ARRAY_ELEMENT(pCentral_dir_offsets, mz_uint32, r_index)); + mz_uint l_len = MZ_READ_LE16(pL + MZ_ZIP_CDH_FILENAME_LEN_OFS), r_len = MZ_READ_LE16(pR + MZ_ZIP_CDH_FILENAME_LEN_OFS); + mz_uint8 l = 0, r = 0; + pL += MZ_ZIP_CENTRAL_DIR_HEADER_SIZE; + pR += MZ_ZIP_CENTRAL_DIR_HEADER_SIZE; + pE = pL + MZ_MIN(l_len, r_len); + while (pL < pE) + { + if ((l = MZ_TOLOWER(*pL)) != (r = MZ_TOLOWER(*pR))) + break; + pL++; + pR++; + } + return (pL == pE) ? (l_len < r_len) : (l < r); +} + +#define MZ_SWAP_UINT32(a, b) \ + do \ + { \ + mz_uint32 t = a; \ + a = b; \ + b = t; \ + } \ + MZ_MACRO_END + +/* Heap sort of lowercased filenames, used to help accelerate plain central directory searches by mz_zip_reader_locate_file(). (Could also use qsort(), but it could allocate memory.) */ +static void mz_zip_reader_sort_central_dir_offsets_by_filename(mz_zip_archive *pZip) +{ + mz_zip_internal_state *pState = pZip->m_pState; + const mz_zip_array *pCentral_dir_offsets = &pState->m_central_dir_offsets; + const mz_zip_array *pCentral_dir = &pState->m_central_dir; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 *pIndices; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 start, end; + const mz_uint32 size = pZip->m_total_files; + + if (size <= 1U) + return; + + pIndices = &MZ_ZIP_ARRAY_ELEMENT(&pState->m_sorted_central_dir_offsets, mz_uint32, 0); + + start = (size - 2U) >> 1U; + for (;;) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 child, root = start; + for (;;) + { + if ((child = (root << 1U) + 1U) >= size) + break; + child += (((child + 1U) < size) && (mz_zip_reader_filename_less(pCentral_dir, pCentral_dir_offsets, pIndices[child], pIndices[child + 1U]))); + if (!mz_zip_reader_filename_less(pCentral_dir, pCentral_dir_offsets, pIndices[root], pIndices[child])) + break; + MZ_SWAP_UINT32(pIndices[root], pIndices[child]); + root = child; + } + if (!start) + break; + start--; + } + + end = size - 1; + while (end > 0) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 child, root = 0; + MZ_SWAP_UINT32(pIndices[end], pIndices[0]); + for (;;) + { + if ((child = (root << 1U) + 1U) >= end) + break; + child += (((child + 1U) < end) && mz_zip_reader_filename_less(pCentral_dir, pCentral_dir_offsets, pIndices[child], pIndices[child + 1U])); + if (!mz_zip_reader_filename_less(pCentral_dir, pCentral_dir_offsets, pIndices[root], pIndices[child])) + break; + MZ_SWAP_UINT32(pIndices[root], pIndices[child]); + root = child; + } + end--; + } +} + +static mz_bool mz_zip_reader_locate_header_sig(mz_zip_archive *pZip, mz_uint32 record_sig, mz_uint32 record_size, mz_int64 *pOfs) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_int64 cur_file_ofs; + mz_uint32 buf_u32[4096 / sizeof(mz_uint32)]; + mz_uint8 *pBuf = (mz_uint8 *)buf_u32; + + /* Basic sanity checks - reject files which are too small */ + if (pZip->m_archive_size < record_size) + return MZ_FALSE; + + /* Find the record by scanning the file from the end towards the beginning. */ + cur_file_ofs = MZ_MAX((mz_int64)pZip->m_archive_size - (mz_int64)sizeof(buf_u32), 0); + for (;;) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int i, n = (int)MZ_MIN(sizeof(buf_u32), pZip->m_archive_size - cur_file_ofs); + + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pBuf, n) != (mz_uint)n) + return MZ_FALSE; + + for (i = n - 4; i >= 0; --i) + { + mz_uint s = MZ_READ_LE32(pBuf + i); + if (s == record_sig) + { + if ((pZip->m_archive_size - (cur_file_ofs + i)) >= record_size) + break; + } + } + + if (i >= 0) + { + cur_file_ofs += i; + break; + } + + /* Give up if we've searched the entire file, or we've gone back "too far" (~64kb) */ + if ((!cur_file_ofs) || ((pZip->m_archive_size - cur_file_ofs) >= (MZ_UINT16_MAX + record_size))) + return MZ_FALSE; + + cur_file_ofs = MZ_MAX(cur_file_ofs - (sizeof(buf_u32) - 3), 0); + } + + *pOfs = cur_file_ofs; + return MZ_TRUE; +} + +static mz_bool mz_zip_reader_read_central_dir(mz_zip_archive *pZip, mz_uint flags) +{ + mz_uint cdir_size = 0, cdir_entries_on_this_disk = 0, num_this_disk = 0, cdir_disk_index = 0; + mz_uint64 cdir_ofs = 0; + mz_int64 cur_file_ofs = 0; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const mz_uint8 *p; + + mz_uint32 buf_u32[4096 / sizeof(mz_uint32)]; + mz_uint8 *pBuf = (mz_uint8 *)buf_u32; + mz_bool sort_central_dir = ((flags & MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY) == 0); + mz_uint32 zip64_end_of_central_dir_locator_u32[(MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE + sizeof(mz_uint32) - 1) / sizeof(mz_uint32)]; + mz_uint8 *pZip64_locator = (mz_uint8 *)zip64_end_of_central_dir_locator_u32; + + mz_uint32 zip64_end_of_central_dir_header_u32[(MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE + sizeof(mz_uint32) - 1) / sizeof(mz_uint32)]; + mz_uint8 *pZip64_end_of_central_dir = (mz_uint8 *)zip64_end_of_central_dir_header_u32; + + mz_uint64 zip64_end_of_central_dir_ofs = 0; + + /* Basic sanity checks - reject files which are too small, and check the first 4 bytes of the file to make sure a local header is there. */ + if (pZip->m_archive_size < MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_NOT_AN_ARCHIVE); + + if (!mz_zip_reader_locate_header_sig(pZip, MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIG, MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE, &cur_file_ofs)) + return mz_zip_set_error(pZip, MZ_ZIP_FAILED_FINDING_CENTRAL_DIR); + + /* Read and verify the end of central directory record. */ + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pBuf, MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) != MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + + if (MZ_READ_LE32(pBuf + MZ_ZIP_ECDH_SIG_OFS) != MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIG) + return mz_zip_set_error(pZip, MZ_ZIP_NOT_AN_ARCHIVE); + + if (cur_file_ofs >= (MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE + MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE)) + { + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs - MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE, pZip64_locator, MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE) == MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE) + { + if (MZ_READ_LE32(pZip64_locator + MZ_ZIP64_ECDL_SIG_OFS) == MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIG) + { + zip64_end_of_central_dir_ofs = MZ_READ_LE64(pZip64_locator + MZ_ZIP64_ECDL_REL_OFS_TO_ZIP64_ECDR_OFS); + if (zip64_end_of_central_dir_ofs > (pZip->m_archive_size - MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE)) + return mz_zip_set_error(pZip, MZ_ZIP_NOT_AN_ARCHIVE); + + if (pZip->m_pRead(pZip->m_pIO_opaque, zip64_end_of_central_dir_ofs, pZip64_end_of_central_dir, MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE) == MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE) + { + if (MZ_READ_LE32(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_SIG_OFS) == MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIG) + { + pZip->m_pState->m_zip64 = MZ_TRUE; + } + } + } + } + } + + pZip->m_total_files = MZ_READ_LE16(pBuf + MZ_ZIP_ECDH_CDIR_TOTAL_ENTRIES_OFS); + cdir_entries_on_this_disk = MZ_READ_LE16(pBuf + MZ_ZIP_ECDH_CDIR_NUM_ENTRIES_ON_DISK_OFS); + num_this_disk = MZ_READ_LE16(pBuf + MZ_ZIP_ECDH_NUM_THIS_DISK_OFS); + cdir_disk_index = MZ_READ_LE16(pBuf + MZ_ZIP_ECDH_NUM_DISK_CDIR_OFS); + cdir_size = MZ_READ_LE32(pBuf + MZ_ZIP_ECDH_CDIR_SIZE_OFS); + cdir_ofs = MZ_READ_LE32(pBuf + MZ_ZIP_ECDH_CDIR_OFS_OFS); + + if (pZip->m_pState->m_zip64) + { + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + mz_uint32 zip64_total_num_of_disks = MZ_READ_LE32(pZip64_locator + MZ_ZIP64_ECDL_TOTAL_NUMBER_OF_DISKS_OFS); + // NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult) + mz_uint64 zip64_cdir_total_entries = MZ_READ_LE64(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_CDIR_TOTAL_ENTRIES_OFS); + mz_uint64 zip64_cdir_total_entries_on_this_disk = MZ_READ_LE64(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_CDIR_NUM_ENTRIES_ON_DISK_OFS); + mz_uint64 zip64_size_of_end_of_central_dir_record = MZ_READ_LE64(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_SIZE_OF_RECORD_OFS); + mz_uint64 zip64_size_of_central_directory = MZ_READ_LE64(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_CDIR_SIZE_OFS); + + if (zip64_size_of_end_of_central_dir_record < (MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE - 12)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if (zip64_total_num_of_disks != 1U) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_MULTIDISK); + + /* Check for miniz's practical limits */ + if (zip64_cdir_total_entries > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + + pZip->m_total_files = (mz_uint32)zip64_cdir_total_entries; + + if (zip64_cdir_total_entries_on_this_disk > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + + cdir_entries_on_this_disk = (mz_uint32)zip64_cdir_total_entries_on_this_disk; + + /* Check for miniz's current practical limits (sorry, this should be enough for millions of files) */ + if (zip64_size_of_central_directory > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_CDIR_SIZE); + + cdir_size = (mz_uint32)zip64_size_of_central_directory; + + num_this_disk = MZ_READ_LE32(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_NUM_THIS_DISK_OFS); + + cdir_disk_index = MZ_READ_LE32(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_NUM_DISK_CDIR_OFS); + + cdir_ofs = MZ_READ_LE64(pZip64_end_of_central_dir + MZ_ZIP64_ECDH_CDIR_OFS_OFS); + } + + if (pZip->m_total_files != cdir_entries_on_this_disk) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_MULTIDISK); + + if (((num_this_disk | cdir_disk_index) != 0) && ((num_this_disk != 1) || (cdir_disk_index != 1))) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_MULTIDISK); + + if (cdir_size < pZip->m_total_files * MZ_ZIP_CENTRAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if ((cdir_ofs + (mz_uint64)cdir_size) > pZip->m_archive_size) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + pZip->m_central_directory_file_ofs = cdir_ofs; + + if (pZip->m_total_files) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i, n; + /* Read the entire central directory into a heap block, and allocate another heap block to hold the unsorted central dir file record offsets, and possibly another to hold the sorted indices. */ + if ((!mz_zip_array_resize(pZip, &pZip->m_pState->m_central_dir, cdir_size, MZ_FALSE)) || + (!mz_zip_array_resize(pZip, &pZip->m_pState->m_central_dir_offsets, pZip->m_total_files, MZ_FALSE))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + if (sort_central_dir) + { + if (!mz_zip_array_resize(pZip, &pZip->m_pState->m_sorted_central_dir_offsets, pZip->m_total_files, MZ_FALSE)) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if (pZip->m_pRead(pZip->m_pIO_opaque, cdir_ofs, pZip->m_pState->m_central_dir.m_p, cdir_size) != cdir_size) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + + /* Now create an index into the central directory file records, do some basic sanity checking on each record */ + p = (const mz_uint8 *)pZip->m_pState->m_central_dir.m_p; + for (n = cdir_size, i = 0; i < pZip->m_total_files; ++i) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint total_header_size, disk_index, bit_flags, filename_size, ext_data_size; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 comp_size, decomp_size, local_header_ofs; + + if ((n < MZ_ZIP_CENTRAL_DIR_HEADER_SIZE) || (MZ_READ_LE32(p) != MZ_ZIP_CENTRAL_DIR_HEADER_SIG)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + MZ_ZIP_ARRAY_ELEMENT(&pZip->m_pState->m_central_dir_offsets, mz_uint32, i) = (mz_uint32)(p - (const mz_uint8 *)pZip->m_pState->m_central_dir.m_p); + + if (sort_central_dir) + MZ_ZIP_ARRAY_ELEMENT(&pZip->m_pState->m_sorted_central_dir_offsets, mz_uint32, i) = i; + + comp_size = MZ_READ_LE32(p + MZ_ZIP_CDH_COMPRESSED_SIZE_OFS); + decomp_size = MZ_READ_LE32(p + MZ_ZIP_CDH_DECOMPRESSED_SIZE_OFS); + local_header_ofs = MZ_READ_LE32(p + MZ_ZIP_CDH_LOCAL_HEADER_OFS); + filename_size = MZ_READ_LE16(p + MZ_ZIP_CDH_FILENAME_LEN_OFS); + ext_data_size = MZ_READ_LE16(p + MZ_ZIP_CDH_EXTRA_LEN_OFS); + + if ((!pZip->m_pState->m_zip64_has_extended_info_fields) && + (ext_data_size) && + (MZ_MAX(MZ_MAX(comp_size, decomp_size), local_header_ofs) == MZ_UINT32_MAX)) + { + /* Attempt to find zip64 extended information field in the entry's extra data */ + mz_uint32 extra_size_remaining = ext_data_size; + + if (extra_size_remaining) + { + const mz_uint8 *pExtra_data; + void* buf = NULL; + + if (MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size + ext_data_size > n) + { + buf = MZ_MALLOC(ext_data_size); + if(buf==NULL) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + if (pZip->m_pRead(pZip->m_pIO_opaque, cdir_ofs + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size, buf, ext_data_size) != ext_data_size) + { + MZ_FREE(buf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + } + + pExtra_data = (mz_uint8*)buf; + } + else + { + pExtra_data = p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size; + } + + do + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 field_id; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 field_data_size; + + if (extra_size_remaining < (sizeof(mz_uint16) * 2)) + { + MZ_FREE(buf); + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } + + field_id = MZ_READ_LE16(pExtra_data); + field_data_size = MZ_READ_LE16(pExtra_data + sizeof(mz_uint16)); + + if ((field_data_size + sizeof(mz_uint16) * 2) > extra_size_remaining) + { + MZ_FREE(buf); + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } + + if (field_id == MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID) + { + /* Ok, the archive didn't have any zip64 headers but it uses a zip64 extended information field so mark it as zip64 anyway (this can occur with infozip's zip util when it reads compresses files from stdin). */ + pZip->m_pState->m_zip64 = MZ_TRUE; + pZip->m_pState->m_zip64_has_extended_info_fields = MZ_TRUE; + break; + } + + pExtra_data += sizeof(mz_uint16) * 2 + field_data_size; + extra_size_remaining = extra_size_remaining - sizeof(mz_uint16) * 2 - field_data_size; + } while (extra_size_remaining); + + MZ_FREE(buf); + } + } + + /* I've seen archives that aren't marked as zip64 that uses zip64 ext data, argh */ + if ((comp_size != MZ_UINT32_MAX) && (decomp_size != MZ_UINT32_MAX)) + { + if (((!MZ_READ_LE32(p + MZ_ZIP_CDH_METHOD_OFS)) && (decomp_size != comp_size)) || (decomp_size && !comp_size)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } + + disk_index = MZ_READ_LE16(p + MZ_ZIP_CDH_DISK_START_OFS); + if ((disk_index == MZ_UINT16_MAX) || ((disk_index != num_this_disk) && (disk_index != 1))) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_MULTIDISK); + + if (comp_size != MZ_UINT32_MAX) + { + if (((mz_uint64)MZ_READ_LE32(p + MZ_ZIP_CDH_LOCAL_HEADER_OFS) + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + comp_size) > pZip->m_archive_size) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } + + bit_flags = MZ_READ_LE16(p + MZ_ZIP_CDH_BIT_FLAG_OFS); + if (bit_flags & MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_LOCAL_DIR_IS_MASKED) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_ENCRYPTION); + + if ((total_header_size = MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + MZ_READ_LE16(p + MZ_ZIP_CDH_FILENAME_LEN_OFS) + MZ_READ_LE16(p + MZ_ZIP_CDH_EXTRA_LEN_OFS) + MZ_READ_LE16(p + MZ_ZIP_CDH_COMMENT_LEN_OFS)) > n) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + n -= total_header_size; + p += total_header_size; + } + } + + if (sort_central_dir) + mz_zip_reader_sort_central_dir_offsets_by_filename(pZip); + + return MZ_TRUE; +} + +void mz_zip_zero_struct(mz_zip_archive *pZip) +{ + if (pZip) + MZ_CLEAR_OBJ(*pZip); +} + +static mz_bool mz_zip_reader_end_internal(mz_zip_archive *pZip, mz_bool set_last_error) +{ + mz_bool status = MZ_TRUE; + + if (!pZip) + return MZ_FALSE; + + if ((!pZip->m_pState) || (!pZip->m_pAlloc) || (!pZip->m_pFree) || (pZip->m_zip_mode != MZ_ZIP_MODE_READING)) + { + if (set_last_error) + pZip->m_last_error = MZ_ZIP_INVALID_PARAMETER; + + return MZ_FALSE; + } + + if (pZip->m_pState) + { + mz_zip_internal_state *pState = pZip->m_pState; + pZip->m_pState = NULL; + + mz_zip_array_clear(pZip, &pState->m_central_dir); + mz_zip_array_clear(pZip, &pState->m_central_dir_offsets); + mz_zip_array_clear(pZip, &pState->m_sorted_central_dir_offsets); + +#ifndef MINIZ_NO_STDIO + if (pState->m_pFile) + { + if (pZip->m_zip_type == MZ_ZIP_TYPE_FILE) + { + if (MZ_FCLOSE(pState->m_pFile) == EOF) + { + if (set_last_error) + pZip->m_last_error = MZ_ZIP_FILE_CLOSE_FAILED; + status = MZ_FALSE; + } + } + pState->m_pFile = NULL; + } +#endif /* #ifndef MINIZ_NO_STDIO */ + + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + } + pZip->m_zip_mode = MZ_ZIP_MODE_INVALID; + + return status; +} + +mz_bool mz_zip_reader_end(mz_zip_archive *pZip) +{ + return mz_zip_reader_end_internal(pZip, MZ_TRUE); +} +mz_bool mz_zip_reader_init(mz_zip_archive *pZip, mz_uint64 size, mz_uint flags) +{ + if ((!pZip) || (!pZip->m_pRead)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (!mz_zip_reader_init_internal(pZip, flags)) + return MZ_FALSE; + + pZip->m_zip_type = MZ_ZIP_TYPE_USER; + pZip->m_archive_size = size; + + if (!mz_zip_reader_read_central_dir(pZip, flags)) + { + mz_zip_reader_end_internal(pZip, MZ_FALSE); + return MZ_FALSE; + } + + return MZ_TRUE; +} + +static size_t mz_zip_mem_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) +{ + mz_zip_archive *pZip = (mz_zip_archive *)pOpaque; + size_t s = (file_ofs >= pZip->m_archive_size) ? 0 : (size_t)MZ_MIN(pZip->m_archive_size - file_ofs, n); + memcpy(pBuf, (const mz_uint8 *)pZip->m_pState->m_pMem + file_ofs, s); + return s; +} + +mz_bool mz_zip_reader_init_mem(mz_zip_archive *pZip, const void *pMem, size_t size, mz_uint flags) +{ + if (!pMem) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (size < MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_NOT_AN_ARCHIVE); + + if (!mz_zip_reader_init_internal(pZip, flags)) + return MZ_FALSE; + + pZip->m_zip_type = MZ_ZIP_TYPE_MEMORY; + pZip->m_archive_size = size; + pZip->m_pRead = mz_zip_mem_read_func; + pZip->m_pIO_opaque = pZip; + pZip->m_pNeeds_keepalive = NULL; + +#ifdef __cplusplus + pZip->m_pState->m_pMem = const_cast(pMem); +#else + pZip->m_pState->m_pMem = (void *)pMem; +#endif + + pZip->m_pState->m_mem_size = size; + + if (!mz_zip_reader_read_central_dir(pZip, flags)) + { + mz_zip_reader_end_internal(pZip, MZ_FALSE); + return MZ_FALSE; + } + + return MZ_TRUE; +} + +#ifndef MINIZ_NO_STDIO +static size_t mz_zip_file_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) +{ + mz_zip_archive *pZip = (mz_zip_archive *)pOpaque; + mz_int64 cur_ofs = MZ_FTELL64(pZip->m_pState->m_pFile); + + file_ofs += pZip->m_pState->m_file_archive_start_ofs; + + if (((mz_int64)file_ofs < 0) || (((cur_ofs != (mz_int64)file_ofs)) && (MZ_FSEEK64(pZip->m_pState->m_pFile, (mz_int64)file_ofs, SEEK_SET)))) + return 0; + + return MZ_FREAD(pBuf, 1, n, pZip->m_pState->m_pFile); +} + +mz_bool mz_zip_reader_init_file(mz_zip_archive *pZip, const char *pFilename, mz_uint32 flags) +{ + return mz_zip_reader_init_file_v2(pZip, pFilename, flags, 0, 0); +} + +mz_bool mz_zip_reader_init_file_v2(mz_zip_archive *pZip, const char *pFilename, mz_uint flags, mz_uint64 file_start_ofs, mz_uint64 archive_size) +{ + mz_uint64 file_size; + MZ_FILE *pFile; + + if ((!pZip) || (!pFilename) || ((archive_size) && (archive_size < MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE))) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pFile = MZ_FOPEN(pFilename, "rb"); + if (!pFile) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_OPEN_FAILED); + + file_size = archive_size; + if (!file_size) + { + if (MZ_FSEEK64(pFile, 0, SEEK_END)) + { + MZ_FCLOSE(pFile); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_SEEK_FAILED); + } + + file_size = MZ_FTELL64(pFile); + } + + /* TODO: Better sanity check archive_size and the # of actual remaining bytes */ + + if (file_size < MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) + { + MZ_FCLOSE(pFile); + return mz_zip_set_error(pZip, MZ_ZIP_NOT_AN_ARCHIVE); + } + + if (!mz_zip_reader_init_internal(pZip, flags)) + { + MZ_FCLOSE(pFile); + return MZ_FALSE; + } + + pZip->m_zip_type = MZ_ZIP_TYPE_FILE; + pZip->m_pRead = mz_zip_file_read_func; + pZip->m_pIO_opaque = pZip; + pZip->m_pState->m_pFile = pFile; + pZip->m_archive_size = file_size; + pZip->m_pState->m_file_archive_start_ofs = file_start_ofs; + + if (!mz_zip_reader_read_central_dir(pZip, flags)) + { + mz_zip_reader_end_internal(pZip, MZ_FALSE); + return MZ_FALSE; + } + + return MZ_TRUE; +} + +mz_bool mz_zip_reader_init_cfile(mz_zip_archive *pZip, MZ_FILE *pFile, mz_uint64 archive_size, mz_uint flags) +{ + mz_uint64 cur_file_ofs; + + if ((!pZip) || (!pFile)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_OPEN_FAILED); + + cur_file_ofs = MZ_FTELL64(pFile); + + if (!archive_size) + { + if (MZ_FSEEK64(pFile, 0, SEEK_END)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_SEEK_FAILED); + + archive_size = MZ_FTELL64(pFile) - cur_file_ofs; + + if (archive_size < MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_NOT_AN_ARCHIVE); + } + + if (!mz_zip_reader_init_internal(pZip, flags)) + return MZ_FALSE; + + pZip->m_zip_type = MZ_ZIP_TYPE_CFILE; + pZip->m_pRead = mz_zip_file_read_func; + + pZip->m_pIO_opaque = pZip; + pZip->m_pState->m_pFile = pFile; + pZip->m_archive_size = archive_size; + pZip->m_pState->m_file_archive_start_ofs = cur_file_ofs; + + if (!mz_zip_reader_read_central_dir(pZip, flags)) + { + mz_zip_reader_end_internal(pZip, MZ_FALSE); + return MZ_FALSE; + } + + return MZ_TRUE; +} + +#endif /* #ifndef MINIZ_NO_STDIO */ + +static MZ_FORCEINLINE const mz_uint8 *mz_zip_get_cdh(mz_zip_archive *pZip, mz_uint file_index) +{ + if ((!pZip) || (!pZip->m_pState) || (file_index >= pZip->m_total_files)) + return NULL; + return &MZ_ZIP_ARRAY_ELEMENT(&pZip->m_pState->m_central_dir, mz_uint8, MZ_ZIP_ARRAY_ELEMENT(&pZip->m_pState->m_central_dir_offsets, mz_uint32, file_index)); +} + +mz_bool mz_zip_reader_is_file_encrypted(mz_zip_archive *pZip, mz_uint file_index) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint m_bit_flag; + const mz_uint8 *p = mz_zip_get_cdh(pZip, file_index); + if (!p) + { + mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + return MZ_FALSE; + } + + m_bit_flag = MZ_READ_LE16(p + MZ_ZIP_CDH_BIT_FLAG_OFS); + return (m_bit_flag & (MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_IS_ENCRYPTED | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_USES_STRONG_ENCRYPTION)) != 0; +} + +mz_bool mz_zip_reader_is_file_supported(mz_zip_archive *pZip, mz_uint file_index) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint bit_flag; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint method; + + const mz_uint8 *p = mz_zip_get_cdh(pZip, file_index); + if (!p) + { + mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + return MZ_FALSE; + } + + method = MZ_READ_LE16(p + MZ_ZIP_CDH_METHOD_OFS); + bit_flag = MZ_READ_LE16(p + MZ_ZIP_CDH_BIT_FLAG_OFS); + + if ((method != 0) && (method != MZ_DEFLATED)) + { + mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_METHOD); + return MZ_FALSE; + } + + if (bit_flag & (MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_IS_ENCRYPTED | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_USES_STRONG_ENCRYPTION)) + { + mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_ENCRYPTION); + return MZ_FALSE; + } + + if (bit_flag & MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_COMPRESSED_PATCH_FLAG) + { + mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_FEATURE); + return MZ_FALSE; + } + + return MZ_TRUE; +} + +mz_bool mz_zip_reader_is_file_a_directory(mz_zip_archive *pZip, mz_uint file_index) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint filename_len, attribute_mapping_id, external_attr; + const mz_uint8 *p = mz_zip_get_cdh(pZip, file_index); + if (!p) + { + mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + return MZ_FALSE; + } + + filename_len = MZ_READ_LE16(p + MZ_ZIP_CDH_FILENAME_LEN_OFS); + if (filename_len) + { + if (*(p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_len - 1) == '/') + return MZ_TRUE; + } + + /* Bugfix: This code was also checking if the internal attribute was non-zero, which wasn't correct. */ + /* Most/all zip writers (hopefully) set DOS file/directory attributes in the low 16-bits, so check for the DOS directory flag and ignore the source OS ID in the created by field. */ + /* FIXME: Remove this check? Is it necessary - we already check the filename. */ + attribute_mapping_id = MZ_READ_LE16(p + MZ_ZIP_CDH_VERSION_MADE_BY_OFS) >> 8; + (void)attribute_mapping_id; + + external_attr = MZ_READ_LE32(p + MZ_ZIP_CDH_EXTERNAL_ATTR_OFS); + if ((external_attr & MZ_ZIP_DOS_DIR_ATTRIBUTE_BITFLAG) != 0) + { + return MZ_TRUE; + } + + return MZ_FALSE; +} + +static mz_bool mz_zip_file_stat_internal(mz_zip_archive *pZip, mz_uint file_index, const mz_uint8 *pCentral_dir_header, mz_zip_archive_file_stat *pStat, mz_bool *pFound_zip64_extra_data) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint n; + const mz_uint8 *p = pCentral_dir_header; + + if (pFound_zip64_extra_data) + *pFound_zip64_extra_data = MZ_FALSE; + + if ((!p) || (!pStat)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + /* Extract fields from the central directory record. */ + pStat->m_file_index = file_index; + pStat->m_central_dir_ofs = MZ_ZIP_ARRAY_ELEMENT(&pZip->m_pState->m_central_dir_offsets, mz_uint32, file_index); + pStat->m_version_made_by = MZ_READ_LE16(p + MZ_ZIP_CDH_VERSION_MADE_BY_OFS); + pStat->m_version_needed = MZ_READ_LE16(p + MZ_ZIP_CDH_VERSION_NEEDED_OFS); + pStat->m_bit_flag = MZ_READ_LE16(p + MZ_ZIP_CDH_BIT_FLAG_OFS); + pStat->m_method = MZ_READ_LE16(p + MZ_ZIP_CDH_METHOD_OFS); +#ifndef MINIZ_NO_TIME + pStat->m_time = mz_zip_dos_to_time_t(MZ_READ_LE16(p + MZ_ZIP_CDH_FILE_TIME_OFS), MZ_READ_LE16(p + MZ_ZIP_CDH_FILE_DATE_OFS)); +#endif + pStat->m_crc32 = MZ_READ_LE32(p + MZ_ZIP_CDH_CRC32_OFS); + pStat->m_comp_size = MZ_READ_LE32(p + MZ_ZIP_CDH_COMPRESSED_SIZE_OFS); + pStat->m_uncomp_size = MZ_READ_LE32(p + MZ_ZIP_CDH_DECOMPRESSED_SIZE_OFS); + pStat->m_internal_attr = MZ_READ_LE16(p + MZ_ZIP_CDH_INTERNAL_ATTR_OFS); + pStat->m_external_attr = MZ_READ_LE32(p + MZ_ZIP_CDH_EXTERNAL_ATTR_OFS); + pStat->m_local_header_ofs = MZ_READ_LE32(p + MZ_ZIP_CDH_LOCAL_HEADER_OFS); + + /* Copy as much of the filename and comment as possible. */ + n = MZ_READ_LE16(p + MZ_ZIP_CDH_FILENAME_LEN_OFS); + n = MZ_MIN(n, MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE - 1); + memcpy(pStat->m_filename, p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE, n); + pStat->m_filename[n] = '\0'; + + n = MZ_READ_LE16(p + MZ_ZIP_CDH_COMMENT_LEN_OFS); + n = MZ_MIN(n, MZ_ZIP_MAX_ARCHIVE_FILE_COMMENT_SIZE - 1); + pStat->m_comment_size = n; + memcpy(pStat->m_comment, p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + MZ_READ_LE16(p + MZ_ZIP_CDH_FILENAME_LEN_OFS) + MZ_READ_LE16(p + MZ_ZIP_CDH_EXTRA_LEN_OFS), n); + pStat->m_comment[n] = '\0'; + + /* Set some flags for convienance */ + pStat->m_is_directory = mz_zip_reader_is_file_a_directory(pZip, file_index); + pStat->m_is_encrypted = mz_zip_reader_is_file_encrypted(pZip, file_index); + pStat->m_is_supported = mz_zip_reader_is_file_supported(pZip, file_index); + + /* See if we need to read any zip64 extended information fields. */ + /* Confusingly, these zip64 fields can be present even on non-zip64 archives (Debian zip on a huge files from stdin piped to stdout creates them). */ + if (MZ_MAX(MZ_MAX(pStat->m_comp_size, pStat->m_uncomp_size), pStat->m_local_header_ofs) == MZ_UINT32_MAX) + { + /* Attempt to find zip64 extended information field in the entry's extra data */ + mz_uint32 extra_size_remaining = MZ_READ_LE16(p + MZ_ZIP_CDH_EXTRA_LEN_OFS); + + if (extra_size_remaining) + { + const mz_uint8 *pExtra_data = p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + MZ_READ_LE16(p + MZ_ZIP_CDH_FILENAME_LEN_OFS); + + do + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 field_id; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 field_data_size; + + if (extra_size_remaining < (sizeof(mz_uint16) * 2)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + field_id = MZ_READ_LE16(pExtra_data); + field_data_size = MZ_READ_LE16(pExtra_data + sizeof(mz_uint16)); + + if ((field_data_size + sizeof(mz_uint16) * 2) > extra_size_remaining) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if (field_id == MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID) + { + const mz_uint8 *pField_data = pExtra_data + sizeof(mz_uint16) * 2; + mz_uint32 field_data_remaining = field_data_size; + + if (pFound_zip64_extra_data) + *pFound_zip64_extra_data = MZ_TRUE; + + if (pStat->m_uncomp_size == MZ_UINT32_MAX) + { + if (field_data_remaining < sizeof(mz_uint64)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + pStat->m_uncomp_size = MZ_READ_LE64(pField_data); + pField_data += sizeof(mz_uint64); + field_data_remaining -= sizeof(mz_uint64); + } + + if (pStat->m_comp_size == MZ_UINT32_MAX) + { + if (field_data_remaining < sizeof(mz_uint64)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + pStat->m_comp_size = MZ_READ_LE64(pField_data); + pField_data += sizeof(mz_uint64); + field_data_remaining -= sizeof(mz_uint64); + } + + if (pStat->m_local_header_ofs == MZ_UINT32_MAX) + { + if (field_data_remaining < sizeof(mz_uint64)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + pStat->m_local_header_ofs = MZ_READ_LE64(pField_data); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + pField_data += sizeof(mz_uint64); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + field_data_remaining -= sizeof(mz_uint64); + } + + break; + } + + pExtra_data += sizeof(mz_uint16) * 2 + field_data_size; + extra_size_remaining = extra_size_remaining - sizeof(mz_uint16) * 2 - field_data_size; + } while (extra_size_remaining); + } + } + + return MZ_TRUE; +} + +static MZ_FORCEINLINE mz_bool mz_zip_string_equal(const char *pA, const char *pB, mz_uint len, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint i; + if (flags & MZ_ZIP_FLAG_CASE_SENSITIVE) + return 0 == memcmp(pA, pB, len); + for (i = 0; i < len; ++i) + if (MZ_TOLOWER(pA[i]) != MZ_TOLOWER(pB[i])) + return MZ_FALSE; + return MZ_TRUE; +} + +static MZ_FORCEINLINE int mz_zip_filename_compare(const mz_zip_array *pCentral_dir_array, const mz_zip_array *pCentral_dir_offsets, mz_uint l_index, const char *pR, mz_uint r_len) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const mz_uint8 *pL = &MZ_ZIP_ARRAY_ELEMENT(pCentral_dir_array, mz_uint8, MZ_ZIP_ARRAY_ELEMENT(pCentral_dir_offsets, mz_uint32, l_index)), *pE; + mz_uint l_len = MZ_READ_LE16(pL + MZ_ZIP_CDH_FILENAME_LEN_OFS); + mz_uint8 l = 0, r = 0; + pL += MZ_ZIP_CENTRAL_DIR_HEADER_SIZE; + pE = pL + MZ_MIN(l_len, r_len); + while (pL < pE) + { + if ((l = MZ_TOLOWER(*pL)) != (r = MZ_TOLOWER(*pR))) + break; + pL++; + pR++; + } + return (pL == pE) ? (int)(l_len - r_len) : (l - r); +} + +static mz_bool mz_zip_locate_file_binary_search(mz_zip_archive *pZip, const char *pFilename, mz_uint32 *pIndex) +{ + mz_zip_internal_state *pState = pZip->m_pState; + const mz_zip_array *pCentral_dir_offsets = &pState->m_central_dir_offsets; + const mz_zip_array *pCentral_dir = &pState->m_central_dir; + mz_uint32 *pIndices = &MZ_ZIP_ARRAY_ELEMENT(&pState->m_sorted_central_dir_offsets, mz_uint32, 0); + const uint32_t size = pZip->m_total_files; + const mz_uint filename_len = (mz_uint)strlen(pFilename); + + if (pIndex) + *pIndex = 0; + + if (size) + { + /* yes I could use uint32_t's, but then we would have to add some special case checks in the loop, argh, and */ + /* honestly the major expense here on 32-bit CPU's will still be the filename compare */ + mz_int64 l = 0, h = (mz_int64)size - 1; + + while (l <= h) + { + mz_int64 m = l + ((h - l) >> 1); + uint32_t file_index = pIndices[(uint32_t)m]; + + int comp = mz_zip_filename_compare(pCentral_dir, pCentral_dir_offsets, file_index, pFilename, filename_len); + if (!comp) + { + if (pIndex) + *pIndex = file_index; + return MZ_TRUE; + } + else if (comp < 0) + l = m + 1; + else + h = m - 1; + } + } + + return mz_zip_set_error(pZip, MZ_ZIP_FILE_NOT_FOUND); +} + +int mz_zip_reader_locate_file(mz_zip_archive *pZip, const char *pName, const char *pComment, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 index; + if (!mz_zip_reader_locate_file_v2(pZip, pName, pComment, flags, &index)) + return -1; + else + return (int)index; +} + +mz_bool mz_zip_reader_locate_file_v2(mz_zip_archive *pZip, const char *pName, const char *pComment, mz_uint flags, mz_uint32 *pIndex) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint file_index; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t name_len, comment_len; + + if (pIndex) + *pIndex = 0; + + if ((!pZip) || (!pZip->m_pState) || (!pName)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + /* See if we can use a binary search */ + if (((pZip->m_pState->m_init_flags & MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY) == 0) && + (pZip->m_zip_mode == MZ_ZIP_MODE_READING) && + ((flags & (MZ_ZIP_FLAG_IGNORE_PATH | MZ_ZIP_FLAG_CASE_SENSITIVE)) == 0) && (!pComment) && (pZip->m_pState->m_sorted_central_dir_offsets.m_size)) + { + return mz_zip_locate_file_binary_search(pZip, pName, pIndex); + } + + /* Locate the entry by scanning the entire central directory */ + name_len = strlen(pName); + if (name_len > MZ_UINT16_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + comment_len = pComment ? strlen(pComment) : 0; + if (comment_len > MZ_UINT16_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + for (file_index = 0; file_index < pZip->m_total_files; file_index++) + { + const mz_uint8 *pHeader = &MZ_ZIP_ARRAY_ELEMENT(&pZip->m_pState->m_central_dir, mz_uint8, MZ_ZIP_ARRAY_ELEMENT(&pZip->m_pState->m_central_dir_offsets, mz_uint32, file_index)); + mz_uint filename_len = MZ_READ_LE16(pHeader + MZ_ZIP_CDH_FILENAME_LEN_OFS); + const char *pFilename = (const char *)pHeader + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE; + if (filename_len < name_len) + continue; + if (comment_len) + { + mz_uint file_extra_len = MZ_READ_LE16(pHeader + MZ_ZIP_CDH_EXTRA_LEN_OFS), file_comment_len = MZ_READ_LE16(pHeader + MZ_ZIP_CDH_COMMENT_LEN_OFS); + const char *pFile_comment = pFilename + filename_len + file_extra_len; + if ((file_comment_len != comment_len) || (!mz_zip_string_equal(pComment, pFile_comment, file_comment_len, flags))) + continue; + } + if ((flags & MZ_ZIP_FLAG_IGNORE_PATH) && (filename_len)) + { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) + int ofs = filename_len - 1; + do + { + if ((pFilename[ofs] == '/') || (pFilename[ofs] == '\\') || (pFilename[ofs] == ':')) + break; + } while (--ofs >= 0); + ofs++; + pFilename += ofs; + filename_len -= ofs; + } + if ((filename_len == name_len) && (mz_zip_string_equal(pName, pFilename, filename_len, flags))) + { + if (pIndex) + *pIndex = file_index; + return MZ_TRUE; + } + } + + return mz_zip_set_error(pZip, MZ_ZIP_FILE_NOT_FOUND); +} + +mz_bool mz_zip_reader_extract_to_mem_no_alloc(mz_zip_archive *pZip, mz_uint file_index, void *pBuf, size_t buf_size, mz_uint flags, void *pUser_read_buf, size_t user_read_buf_size) +{ + int status = TINFL_STATUS_DONE; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 needed_size, cur_file_ofs, comp_remaining, out_buf_ofs = 0, read_buf_size, read_buf_ofs = 0, read_buf_avail; + mz_zip_archive_file_stat file_stat; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + void *pRead_buf; + mz_uint32 local_header_u32[(MZ_ZIP_LOCAL_DIR_HEADER_SIZE + sizeof(mz_uint32) - 1) / sizeof(mz_uint32)]; + mz_uint8 *pLocal_header = (mz_uint8 *)local_header_u32; + tinfl_decompressor inflator; + + if ((!pZip) || (!pZip->m_pState) || ((buf_size) && (!pBuf)) || ((user_read_buf_size) && (!pUser_read_buf)) || (!pZip->m_pRead)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (!mz_zip_reader_file_stat(pZip, file_index, &file_stat)) + return MZ_FALSE; + + /* A directory or zero length file */ + if ((file_stat.m_is_directory) || (!file_stat.m_comp_size)) + return MZ_TRUE; + + /* Encryption and patch files are not supported. */ + if (file_stat.m_bit_flag & (MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_IS_ENCRYPTED | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_USES_STRONG_ENCRYPTION | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_COMPRESSED_PATCH_FLAG)) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_ENCRYPTION); + + /* This function only supports decompressing stored and deflate. */ + if ((!(flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) && (file_stat.m_method != 0) && (file_stat.m_method != MZ_DEFLATED)) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_METHOD); + + /* Ensure supplied output buffer is large enough. */ + needed_size = (flags & MZ_ZIP_FLAG_COMPRESSED_DATA) ? file_stat.m_comp_size : file_stat.m_uncomp_size; + if (buf_size < needed_size) + return mz_zip_set_error(pZip, MZ_ZIP_BUF_TOO_SMALL); + + /* Read and parse the local directory entry. */ + cur_file_ofs = file_stat.m_local_header_ofs; + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pLocal_header, MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + + if (MZ_READ_LE32(pLocal_header) != MZ_ZIP_LOCAL_DIR_HEADER_SIG) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + cur_file_ofs += MZ_ZIP_LOCAL_DIR_HEADER_SIZE + MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_FILENAME_LEN_OFS) + MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); + if ((cur_file_ofs + file_stat.m_comp_size) > pZip->m_archive_size) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if ((flags & MZ_ZIP_FLAG_COMPRESSED_DATA) || (!file_stat.m_method)) + { + /* The file is stored or the caller has requested the compressed data. */ + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pBuf, (size_t)needed_size) != needed_size) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + if ((flags & MZ_ZIP_FLAG_COMPRESSED_DATA) == 0) + { + if (mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, (size_t)file_stat.m_uncomp_size) != file_stat.m_crc32) + return mz_zip_set_error(pZip, MZ_ZIP_CRC_CHECK_FAILED); + } +#endif + + return MZ_TRUE; + } + + /* Decompress the file either directly from memory or from a file input buffer. */ + tinfl_init(&inflator); + + if (pZip->m_pState->m_pMem) + { + /* Read directly from the archive in memory. */ + pRead_buf = (mz_uint8 *)pZip->m_pState->m_pMem + cur_file_ofs; + read_buf_size = read_buf_avail = file_stat.m_comp_size; + comp_remaining = 0; + } + else if (pUser_read_buf) + { + /* Use a user provided read buffer. */ + if (!user_read_buf_size) + return MZ_FALSE; + pRead_buf = (mz_uint8 *)pUser_read_buf; + read_buf_size = user_read_buf_size; + read_buf_avail = 0; + comp_remaining = file_stat.m_comp_size; + } + else + { + /* Temporarily allocate a read buffer. */ + read_buf_size = MZ_MIN(file_stat.m_comp_size, (mz_uint64)MZ_ZIP_MAX_IO_BUF_SIZE); + if (((sizeof(size_t) == sizeof(mz_uint32))) && (read_buf_size > 0x7FFFFFFF)) + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + + if (NULL == (pRead_buf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, (size_t)read_buf_size))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + read_buf_avail = 0; + comp_remaining = file_stat.m_comp_size; + } + + do + { + /* The size_t cast here should be OK because we've verified that the output buffer is >= file_stat.m_uncomp_size above */ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t in_buf_size, out_buf_size = (size_t)(file_stat.m_uncomp_size - out_buf_ofs); + if ((!read_buf_avail) && (!pZip->m_pState->m_pMem)) + { + read_buf_avail = MZ_MIN(read_buf_size, comp_remaining); + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pRead_buf, (size_t)read_buf_avail) != read_buf_avail) + { + status = TINFL_STATUS_FAILED; + mz_zip_set_error(pZip, MZ_ZIP_DECOMPRESSION_FAILED); + break; + } + cur_file_ofs += read_buf_avail; + comp_remaining -= read_buf_avail; + read_buf_ofs = 0; + } + in_buf_size = (size_t)read_buf_avail; + status = tinfl_decompress(&inflator, (mz_uint8 *)pRead_buf + read_buf_ofs, &in_buf_size, (mz_uint8 *)pBuf, (mz_uint8 *)pBuf + out_buf_ofs, &out_buf_size, TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF | (comp_remaining ? TINFL_FLAG_HAS_MORE_INPUT : 0)); + read_buf_avail -= in_buf_size; + read_buf_ofs += in_buf_size; + out_buf_ofs += out_buf_size; + } while (status == TINFL_STATUS_NEEDS_MORE_INPUT); + + if (status == TINFL_STATUS_DONE) + { + /* Make sure the entire file was decompressed, and check its CRC. */ + if (out_buf_ofs != file_stat.m_uncomp_size) + { + mz_zip_set_error(pZip, MZ_ZIP_UNEXPECTED_DECOMPRESSED_SIZE); + status = TINFL_STATUS_FAILED; + } +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + else if (mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, (size_t)file_stat.m_uncomp_size) != file_stat.m_crc32) + { + mz_zip_set_error(pZip, MZ_ZIP_CRC_CHECK_FAILED); + status = TINFL_STATUS_FAILED; + } +#endif + } + + if ((!pZip->m_pState->m_pMem) && (!pUser_read_buf)) + pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); + + return status == TINFL_STATUS_DONE; +} + +mz_bool mz_zip_reader_extract_file_to_mem_no_alloc(mz_zip_archive *pZip, const char *pFilename, void *pBuf, size_t buf_size, mz_uint flags, void *pUser_read_buf, size_t user_read_buf_size) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 file_index; + if (!mz_zip_reader_locate_file_v2(pZip, pFilename, NULL, flags, &file_index)) + return MZ_FALSE; + return mz_zip_reader_extract_to_mem_no_alloc(pZip, file_index, pBuf, buf_size, flags, pUser_read_buf, user_read_buf_size); +} + +mz_bool mz_zip_reader_extract_to_mem(mz_zip_archive *pZip, mz_uint file_index, void *pBuf, size_t buf_size, mz_uint flags) +{ + return mz_zip_reader_extract_to_mem_no_alloc(pZip, file_index, pBuf, buf_size, flags, NULL, 0); +} + +mz_bool mz_zip_reader_extract_file_to_mem(mz_zip_archive *pZip, const char *pFilename, void *pBuf, size_t buf_size, mz_uint flags) +{ + return mz_zip_reader_extract_file_to_mem_no_alloc(pZip, pFilename, pBuf, buf_size, flags, NULL, 0); +} + +void *mz_zip_reader_extract_to_heap(mz_zip_archive *pZip, mz_uint file_index, size_t *pSize, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 comp_size, uncomp_size, alloc_size; + const mz_uint8 *p = mz_zip_get_cdh(pZip, file_index); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + void *pBuf; + + if (pSize) + *pSize = 0; + + if (!p) + { + mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + return NULL; + } + + comp_size = MZ_READ_LE32(p + MZ_ZIP_CDH_COMPRESSED_SIZE_OFS); + uncomp_size = MZ_READ_LE32(p + MZ_ZIP_CDH_DECOMPRESSED_SIZE_OFS); + + alloc_size = (flags & MZ_ZIP_FLAG_COMPRESSED_DATA) ? comp_size : uncomp_size; + if (((sizeof(size_t) == sizeof(mz_uint32))) && (alloc_size > 0x7FFFFFFF)) + { + mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + return NULL; + } + + if (NULL == (pBuf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, (size_t)alloc_size))) + { + mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + return NULL; + } + + if (!mz_zip_reader_extract_to_mem(pZip, file_index, pBuf, (size_t)alloc_size, flags)) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pBuf); + return NULL; + } + + if (pSize) + *pSize = (size_t)alloc_size; + return pBuf; +} + +void *mz_zip_reader_extract_file_to_heap(mz_zip_archive *pZip, const char *pFilename, size_t *pSize, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 file_index; + if (!mz_zip_reader_locate_file_v2(pZip, pFilename, NULL, flags, &file_index)) + { + if (pSize) + *pSize = 0; + return MZ_FALSE; + } + return mz_zip_reader_extract_to_heap(pZip, file_index, pSize, flags); +} + +mz_bool mz_zip_reader_extract_to_callback(mz_zip_archive *pZip, mz_uint file_index, mz_file_write_func pCallback, void *pOpaque, mz_uint flags) +{ + int status = TINFL_STATUS_DONE; +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + mz_uint file_crc32 = MZ_CRC32_INIT; +#endif + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 read_buf_size, read_buf_ofs = 0, read_buf_avail, comp_remaining, out_buf_ofs = 0, cur_file_ofs; + mz_zip_archive_file_stat file_stat; + void *pRead_buf = NULL; + void *pWrite_buf = NULL; + mz_uint32 local_header_u32[(MZ_ZIP_LOCAL_DIR_HEADER_SIZE + sizeof(mz_uint32) - 1) / sizeof(mz_uint32)]; + mz_uint8 *pLocal_header = (mz_uint8 *)local_header_u32; + + if ((!pZip) || (!pZip->m_pState) || (!pCallback) || (!pZip->m_pRead)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (!mz_zip_reader_file_stat(pZip, file_index, &file_stat)) + return MZ_FALSE; + + /* A directory or zero length file */ + if ((file_stat.m_is_directory) || (!file_stat.m_comp_size)) + return MZ_TRUE; + + /* Encryption and patch files are not supported. */ + if (file_stat.m_bit_flag & (MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_IS_ENCRYPTED | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_USES_STRONG_ENCRYPTION | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_COMPRESSED_PATCH_FLAG)) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_ENCRYPTION); + + /* This function only supports decompressing stored and deflate. */ + if ((!(flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) && (file_stat.m_method != 0) && (file_stat.m_method != MZ_DEFLATED)) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_METHOD); + + /* Read and do some minimal validation of the local directory entry (this doesn't crack the zip64 stuff, which we already have from the central dir) */ + cur_file_ofs = file_stat.m_local_header_ofs; + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pLocal_header, MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + + if (MZ_READ_LE32(pLocal_header) != MZ_ZIP_LOCAL_DIR_HEADER_SIG) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + cur_file_ofs += MZ_ZIP_LOCAL_DIR_HEADER_SIZE + MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_FILENAME_LEN_OFS) + MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); + if ((cur_file_ofs + file_stat.m_comp_size) > pZip->m_archive_size) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + /* Decompress the file either directly from memory or from a file input buffer. */ + if (pZip->m_pState->m_pMem) + { + pRead_buf = (mz_uint8 *)pZip->m_pState->m_pMem + cur_file_ofs; + read_buf_size = read_buf_avail = file_stat.m_comp_size; + comp_remaining = 0; + } + else + { + read_buf_size = MZ_MIN(file_stat.m_comp_size, (mz_uint64)MZ_ZIP_MAX_IO_BUF_SIZE); + if (NULL == (pRead_buf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, (size_t)read_buf_size))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + read_buf_avail = 0; + comp_remaining = file_stat.m_comp_size; + } + + if ((flags & MZ_ZIP_FLAG_COMPRESSED_DATA) || (!file_stat.m_method)) + { + /* The file is stored or the caller has requested the compressed data. */ + if (pZip->m_pState->m_pMem) + { + if (((sizeof(size_t) == sizeof(mz_uint32))) && (file_stat.m_comp_size > MZ_UINT32_MAX)) + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + + if (pCallback(pOpaque, out_buf_ofs, pRead_buf, (size_t)file_stat.m_comp_size) != file_stat.m_comp_size) + { + mz_zip_set_error(pZip, MZ_ZIP_WRITE_CALLBACK_FAILED); + status = TINFL_STATUS_FAILED; + } + else if (!(flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) + { +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + file_crc32 = (mz_uint32)mz_crc32(file_crc32, (const mz_uint8 *)pRead_buf, (size_t)file_stat.m_comp_size); +#endif + } + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + cur_file_ofs += file_stat.m_comp_size; + out_buf_ofs += file_stat.m_comp_size; + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + comp_remaining = 0; + } + else + { + while (comp_remaining) + { + read_buf_avail = MZ_MIN(read_buf_size, comp_remaining); + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pRead_buf, (size_t)read_buf_avail) != read_buf_avail) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + status = TINFL_STATUS_FAILED; + break; + } + +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + if (!(flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) + { + file_crc32 = (mz_uint32)mz_crc32(file_crc32, (const mz_uint8 *)pRead_buf, (size_t)read_buf_avail); + } +#endif + + if (pCallback(pOpaque, out_buf_ofs, pRead_buf, (size_t)read_buf_avail) != read_buf_avail) + { + mz_zip_set_error(pZip, MZ_ZIP_WRITE_CALLBACK_FAILED); + status = TINFL_STATUS_FAILED; + break; + } + + cur_file_ofs += read_buf_avail; + out_buf_ofs += read_buf_avail; + comp_remaining -= read_buf_avail; + } + } + } + else + { + tinfl_decompressor inflator; + tinfl_init(&inflator); + + if (NULL == (pWrite_buf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, TINFL_LZ_DICT_SIZE))) + { + mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + status = TINFL_STATUS_FAILED; + } + else + { + do + { + mz_uint8 *pWrite_buf_cur = (mz_uint8 *)pWrite_buf + (out_buf_ofs & (TINFL_LZ_DICT_SIZE - 1)); + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t in_buf_size, out_buf_size = TINFL_LZ_DICT_SIZE - (out_buf_ofs & (TINFL_LZ_DICT_SIZE - 1)); + if ((!read_buf_avail) && (!pZip->m_pState->m_pMem)) + { + read_buf_avail = MZ_MIN(read_buf_size, comp_remaining); + if (pZip->m_pRead(pZip->m_pIO_opaque, cur_file_ofs, pRead_buf, (size_t)read_buf_avail) != read_buf_avail) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + status = TINFL_STATUS_FAILED; + break; + } + cur_file_ofs += read_buf_avail; + comp_remaining -= read_buf_avail; + read_buf_ofs = 0; + } + + in_buf_size = (size_t)read_buf_avail; + status = tinfl_decompress(&inflator, (const mz_uint8 *)pRead_buf + read_buf_ofs, &in_buf_size, (mz_uint8 *)pWrite_buf, pWrite_buf_cur, &out_buf_size, comp_remaining ? TINFL_FLAG_HAS_MORE_INPUT : 0); + read_buf_avail -= in_buf_size; + read_buf_ofs += in_buf_size; + + if (out_buf_size) + { + if (pCallback(pOpaque, out_buf_ofs, pWrite_buf_cur, out_buf_size) != out_buf_size) + { + mz_zip_set_error(pZip, MZ_ZIP_WRITE_CALLBACK_FAILED); + status = TINFL_STATUS_FAILED; + break; + } + +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + file_crc32 = (mz_uint32)mz_crc32(file_crc32, pWrite_buf_cur, out_buf_size); +#endif + if ((out_buf_ofs += out_buf_size) > file_stat.m_uncomp_size) + { + mz_zip_set_error(pZip, MZ_ZIP_DECOMPRESSION_FAILED); + status = TINFL_STATUS_FAILED; + break; + } + } + } while ((status == TINFL_STATUS_NEEDS_MORE_INPUT) || (status == TINFL_STATUS_HAS_MORE_OUTPUT)); + } + } + + if ((status == TINFL_STATUS_DONE) && (!(flags & MZ_ZIP_FLAG_COMPRESSED_DATA))) + { + /* Make sure the entire file was decompressed, and check its CRC. */ + if (out_buf_ofs != file_stat.m_uncomp_size) + { + mz_zip_set_error(pZip, MZ_ZIP_UNEXPECTED_DECOMPRESSED_SIZE); + status = TINFL_STATUS_FAILED; + } +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + else if (file_crc32 != file_stat.m_crc32) + { + mz_zip_set_error(pZip, MZ_ZIP_DECOMPRESSION_FAILED); + status = TINFL_STATUS_FAILED; + } +#endif + } + + if (!pZip->m_pState->m_pMem) + pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); + + if (pWrite_buf) + pZip->m_pFree(pZip->m_pAlloc_opaque, pWrite_buf); + + return status == TINFL_STATUS_DONE; +} + +mz_bool mz_zip_reader_extract_file_to_callback(mz_zip_archive *pZip, const char *pFilename, mz_file_write_func pCallback, void *pOpaque, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 file_index; + if (!mz_zip_reader_locate_file_v2(pZip, pFilename, NULL, flags, &file_index)) + return MZ_FALSE; + + return mz_zip_reader_extract_to_callback(pZip, file_index, pCallback, pOpaque, flags); +} + +mz_zip_reader_extract_iter_state* mz_zip_reader_extract_iter_new(mz_zip_archive *pZip, mz_uint file_index, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_reader_extract_iter_state *pState; + mz_uint32 local_header_u32[(MZ_ZIP_LOCAL_DIR_HEADER_SIZE + sizeof(mz_uint32) - 1) / sizeof(mz_uint32)]; + mz_uint8 *pLocal_header = (mz_uint8 *)local_header_u32; + + /* Argument sanity check */ + if ((!pZip) || (!pZip->m_pState)) + return NULL; + + /* Allocate an iterator status structure */ + pState = (mz_zip_reader_extract_iter_state*)pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, sizeof(mz_zip_reader_extract_iter_state)); + if (!pState) + { + mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + return NULL; + } + + /* Fetch file details */ + if (!mz_zip_reader_file_stat(pZip, file_index, &pState->file_stat)) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + + /* Encryption and patch files are not supported. */ + if (pState->file_stat.m_bit_flag & (MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_IS_ENCRYPTED | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_USES_STRONG_ENCRYPTION | MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_COMPRESSED_PATCH_FLAG)) + { + mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_ENCRYPTION); + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + + /* This function only supports decompressing stored and deflate. */ + if ((!(flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) && (pState->file_stat.m_method != 0) && (pState->file_stat.m_method != MZ_DEFLATED)) + { + mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_METHOD); + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + + /* Init state - save args */ + pState->pZip = pZip; + pState->flags = flags; + + /* Init state - reset variables to defaults */ + pState->status = TINFL_STATUS_DONE; +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + pState->file_crc32 = MZ_CRC32_INIT; +#endif + pState->read_buf_ofs = 0; + pState->out_buf_ofs = 0; + pState->pRead_buf = NULL; + pState->pWrite_buf = NULL; + pState->out_blk_remain = 0; + + /* Read and parse the local directory entry. */ + pState->cur_file_ofs = pState->file_stat.m_local_header_ofs; + if (pZip->m_pRead(pZip->m_pIO_opaque, pState->cur_file_ofs, pLocal_header, MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + + if (MZ_READ_LE32(pLocal_header) != MZ_ZIP_LOCAL_DIR_HEADER_SIG) + { + mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + + pState->cur_file_ofs += MZ_ZIP_LOCAL_DIR_HEADER_SIZE + MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_FILENAME_LEN_OFS) + MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); + if ((pState->cur_file_ofs + pState->file_stat.m_comp_size) > pZip->m_archive_size) + { + mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + + /* Decompress the file either directly from memory or from a file input buffer. */ + if (pZip->m_pState->m_pMem) + { + pState->pRead_buf = (mz_uint8 *)pZip->m_pState->m_pMem + pState->cur_file_ofs; + pState->read_buf_size = pState->read_buf_avail = pState->file_stat.m_comp_size; + pState->comp_remaining = pState->file_stat.m_comp_size; + } + else + { + if (!((flags & MZ_ZIP_FLAG_COMPRESSED_DATA) || (!pState->file_stat.m_method))) + { + /* Decompression required, therefore intermediate read buffer required */ + pState->read_buf_size = MZ_MIN(pState->file_stat.m_comp_size, MZ_ZIP_MAX_IO_BUF_SIZE); + if (NULL == (pState->pRead_buf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, (size_t)pState->read_buf_size))) + { + mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + } + else + { + /* Decompression not required - we will be reading directly into user buffer, no temp buf required */ + pState->read_buf_size = 0; + } + pState->read_buf_avail = 0; + pState->comp_remaining = pState->file_stat.m_comp_size; + } + + if (!((flags & MZ_ZIP_FLAG_COMPRESSED_DATA) || (!pState->file_stat.m_method))) + { + /* Decompression required, init decompressor */ + tinfl_init( &pState->inflator ); + + /* Allocate write buffer */ + if (NULL == (pState->pWrite_buf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, TINFL_LZ_DICT_SIZE))) + { + mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + if (pState->pRead_buf) + pZip->m_pFree(pZip->m_pAlloc_opaque, pState->pRead_buf); + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + return NULL; + } + } + + return pState; +} + +mz_zip_reader_extract_iter_state* mz_zip_reader_extract_file_iter_new(mz_zip_archive *pZip, const char *pFilename, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 file_index; + + /* Locate file index by name */ + if (!mz_zip_reader_locate_file_v2(pZip, pFilename, NULL, flags, &file_index)) + return NULL; + + /* Construct iterator */ + return mz_zip_reader_extract_iter_new(pZip, file_index, flags); +} + +size_t mz_zip_reader_extract_iter_read(mz_zip_reader_extract_iter_state* pState, void* pvBuf, size_t buf_size) +{ + size_t copied_to_caller = 0; + + /* Argument sanity check */ + if ((!pState) || (!pState->pZip) || (!pState->pZip->m_pState) || (!pvBuf)) + return 0; + + if ((pState->flags & MZ_ZIP_FLAG_COMPRESSED_DATA) || (!pState->file_stat.m_method)) + { + /* The file is stored or the caller has requested the compressed data, calc amount to return. */ + copied_to_caller = (size_t)MZ_MIN( buf_size, pState->comp_remaining ); + + /* Zip is in memory....or requires reading from a file? */ + if (pState->pZip->m_pState->m_pMem) + { + /* Copy data to caller's buffer */ + memcpy( pvBuf, pState->pRead_buf, copied_to_caller ); + pState->pRead_buf = ((mz_uint8*)pState->pRead_buf) + copied_to_caller; + } + else + { + /* Read directly into caller's buffer */ + if (pState->pZip->m_pRead(pState->pZip->m_pIO_opaque, pState->cur_file_ofs, pvBuf, copied_to_caller) != copied_to_caller) + { + /* Failed to read all that was asked for, flag failure and alert user */ + mz_zip_set_error(pState->pZip, MZ_ZIP_FILE_READ_FAILED); + pState->status = TINFL_STATUS_FAILED; + copied_to_caller = 0; + } + } + +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + /* Compute CRC if not returning compressed data only */ + if (!(pState->flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) + pState->file_crc32 = (mz_uint32)mz_crc32(pState->file_crc32, (const mz_uint8 *)pvBuf, copied_to_caller); +#endif + + /* Advance offsets, dec counters */ + pState->cur_file_ofs += copied_to_caller; + pState->out_buf_ofs += copied_to_caller; + pState->comp_remaining -= copied_to_caller; + } + else + { + do + { + /* Calc ptr to write buffer - given current output pos and block size */ + mz_uint8 *pWrite_buf_cur = (mz_uint8 *)pState->pWrite_buf + (pState->out_buf_ofs & (TINFL_LZ_DICT_SIZE - 1)); + + /* Calc max output size - given current output pos and block size */ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t in_buf_size, out_buf_size = TINFL_LZ_DICT_SIZE - (pState->out_buf_ofs & (TINFL_LZ_DICT_SIZE - 1)); + + if (!pState->out_blk_remain) + { + /* Read more data from file if none available (and reading from file) */ + if ((!pState->read_buf_avail) && (!pState->pZip->m_pState->m_pMem)) + { + /* Calc read size */ + pState->read_buf_avail = MZ_MIN(pState->read_buf_size, pState->comp_remaining); + if (pState->pZip->m_pRead(pState->pZip->m_pIO_opaque, pState->cur_file_ofs, pState->pRead_buf, (size_t)pState->read_buf_avail) != pState->read_buf_avail) + { + mz_zip_set_error(pState->pZip, MZ_ZIP_FILE_READ_FAILED); + pState->status = TINFL_STATUS_FAILED; + break; + } + + /* Advance offsets, dec counters */ + pState->cur_file_ofs += pState->read_buf_avail; + pState->comp_remaining -= pState->read_buf_avail; + pState->read_buf_ofs = 0; + } + + /* Perform decompression */ + in_buf_size = (size_t)pState->read_buf_avail; + pState->status = tinfl_decompress(&pState->inflator, (const mz_uint8 *)pState->pRead_buf + pState->read_buf_ofs, &in_buf_size, (mz_uint8 *)pState->pWrite_buf, pWrite_buf_cur, &out_buf_size, pState->comp_remaining ? TINFL_FLAG_HAS_MORE_INPUT : 0); + pState->read_buf_avail -= in_buf_size; + pState->read_buf_ofs += in_buf_size; + + /* Update current output block size remaining */ + pState->out_blk_remain = out_buf_size; + } + + if (pState->out_blk_remain) + { + /* Calc amount to return. */ + size_t to_copy = MZ_MIN( (buf_size - copied_to_caller), pState->out_blk_remain ); + + /* Copy data to caller's buffer */ + memcpy( (uint8_t*)pvBuf + copied_to_caller, pWrite_buf_cur, to_copy ); + +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + /* Perform CRC */ + pState->file_crc32 = (mz_uint32)mz_crc32(pState->file_crc32, pWrite_buf_cur, to_copy); +#endif + + /* Decrement data consumed from block */ + pState->out_blk_remain -= to_copy; + + /* Inc output offset, while performing sanity check */ + if ((pState->out_buf_ofs += to_copy) > pState->file_stat.m_uncomp_size) + { + mz_zip_set_error(pState->pZip, MZ_ZIP_DECOMPRESSION_FAILED); + pState->status = TINFL_STATUS_FAILED; + break; + } + + /* Increment counter of data copied to caller */ + copied_to_caller += to_copy; + } + } while ( (copied_to_caller < buf_size) && ((pState->status == TINFL_STATUS_NEEDS_MORE_INPUT) || (pState->status == TINFL_STATUS_HAS_MORE_OUTPUT)) ); + } + + /* Return how many bytes were copied into user buffer */ + return copied_to_caller; +} + +mz_bool mz_zip_reader_extract_iter_free(mz_zip_reader_extract_iter_state* pState) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + int status; + + /* Argument sanity check */ + if ((!pState) || (!pState->pZip) || (!pState->pZip->m_pState)) + return MZ_FALSE; + + /* Was decompression completed and requested? */ + if ((pState->status == TINFL_STATUS_DONE) && (!(pState->flags & MZ_ZIP_FLAG_COMPRESSED_DATA))) + { + /* Make sure the entire file was decompressed, and check its CRC. */ + if (pState->out_buf_ofs != pState->file_stat.m_uncomp_size) + { + mz_zip_set_error(pState->pZip, MZ_ZIP_UNEXPECTED_DECOMPRESSED_SIZE); + pState->status = TINFL_STATUS_FAILED; + } +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + else if (pState->file_crc32 != pState->file_stat.m_crc32) + { + mz_zip_set_error(pState->pZip, MZ_ZIP_DECOMPRESSION_FAILED); + pState->status = TINFL_STATUS_FAILED; + } +#endif + } + + /* Free buffers */ + if (!pState->pZip->m_pState->m_pMem) + pState->pZip->m_pFree(pState->pZip->m_pAlloc_opaque, pState->pRead_buf); + if (pState->pWrite_buf) + pState->pZip->m_pFree(pState->pZip->m_pAlloc_opaque, pState->pWrite_buf); + + /* Save status */ + status = pState->status; + + /* Free context */ + pState->pZip->m_pFree(pState->pZip->m_pAlloc_opaque, pState); + + return status == TINFL_STATUS_DONE; +} + +#ifndef MINIZ_NO_STDIO +static size_t mz_zip_file_write_callback(void *pOpaque, mz_uint64 ofs, const void *pBuf, size_t n) +{ + (void)ofs; + + return MZ_FWRITE(pBuf, 1, n, (MZ_FILE *)pOpaque); +} + +mz_bool mz_zip_reader_extract_to_file(mz_zip_archive *pZip, mz_uint file_index, const char *pDst_filename, mz_uint flags) +{ + mz_bool status; + mz_zip_archive_file_stat file_stat; + MZ_FILE *pFile; + + if (!mz_zip_reader_file_stat(pZip, file_index, &file_stat)) + return MZ_FALSE; + + if ((file_stat.m_is_directory) || (!file_stat.m_is_supported)) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_FEATURE); + + pFile = MZ_FOPEN(pDst_filename, "wb"); + if (!pFile) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_OPEN_FAILED); + + status = mz_zip_reader_extract_to_callback(pZip, file_index, mz_zip_file_write_callback, pFile, flags); + + if (MZ_FCLOSE(pFile) == EOF) + { + if (status) + mz_zip_set_error(pZip, MZ_ZIP_FILE_CLOSE_FAILED); + + status = MZ_FALSE; + } + +#if !defined(MINIZ_NO_TIME) && !defined(MINIZ_NO_STDIO) + if (status) + mz_zip_set_file_times(pDst_filename, file_stat.m_time, file_stat.m_time); +#endif + + return status; +} + +mz_bool mz_zip_reader_extract_file_to_file(mz_zip_archive *pZip, const char *pArchive_filename, const char *pDst_filename, mz_uint flags) +{ + mz_uint32 file_index; + if (!mz_zip_reader_locate_file_v2(pZip, pArchive_filename, NULL, flags, &file_index)) + return MZ_FALSE; + + return mz_zip_reader_extract_to_file(pZip, file_index, pDst_filename, flags); +} + +mz_bool mz_zip_reader_extract_to_cfile(mz_zip_archive *pZip, mz_uint file_index, MZ_FILE *pFile, mz_uint flags) +{ + mz_zip_archive_file_stat file_stat; + + if (!mz_zip_reader_file_stat(pZip, file_index, &file_stat)) + return MZ_FALSE; + + if ((file_stat.m_is_directory) || (!file_stat.m_is_supported)) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_FEATURE); + + return mz_zip_reader_extract_to_callback(pZip, file_index, mz_zip_file_write_callback, pFile, flags); +} + +mz_bool mz_zip_reader_extract_file_to_cfile(mz_zip_archive *pZip, const char *pArchive_filename, MZ_FILE *pFile, mz_uint flags) +{ + mz_uint32 file_index; + if (!mz_zip_reader_locate_file_v2(pZip, pArchive_filename, NULL, flags, &file_index)) + return MZ_FALSE; + + return mz_zip_reader_extract_to_cfile(pZip, file_index, pFile, flags); +} +#endif /* #ifndef MINIZ_NO_STDIO */ + +static size_t mz_zip_compute_crc32_callback(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) +{ + mz_uint32 *p = (mz_uint32 *)pOpaque; + (void)file_ofs; + *p = (mz_uint32)mz_crc32(*p, (const mz_uint8 *)pBuf, n); + return n; +} + +mz_bool mz_zip_validate_file(mz_zip_archive *pZip, mz_uint file_index, mz_uint flags) +{ + mz_zip_archive_file_stat file_stat; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_internal_state *pState; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const mz_uint8 *pCentral_dir_header; + mz_bool found_zip64_ext_data_in_cdir = MZ_FALSE; + mz_bool found_zip64_ext_data_in_ldir = MZ_FALSE; + mz_uint32 local_header_u32[(MZ_ZIP_LOCAL_DIR_HEADER_SIZE + sizeof(mz_uint32) - 1) / sizeof(mz_uint32)]; + mz_uint8 *pLocal_header = (mz_uint8 *)local_header_u32; + mz_uint64 local_header_ofs = 0; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 local_header_filename_len, local_header_extra_len, local_header_crc32; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 local_header_comp_size, local_header_uncomp_size; + mz_uint32 uncomp_crc32 = MZ_CRC32_INIT; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_bool has_data_descriptor; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 local_header_bit_flags; + + mz_zip_array file_data_array; + mz_zip_array_init(&file_data_array, 1); + + if ((!pZip) || (!pZip->m_pState) || (!pZip->m_pAlloc) || (!pZip->m_pFree) || (!pZip->m_pRead)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (file_index > pZip->m_total_files) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pState = pZip->m_pState; + + pCentral_dir_header = mz_zip_get_cdh(pZip, file_index); + + if (!mz_zip_file_stat_internal(pZip, file_index, pCentral_dir_header, &file_stat, &found_zip64_ext_data_in_cdir)) + return MZ_FALSE; + + /* A directory or zero length file */ + if ((file_stat.m_is_directory) || (!file_stat.m_uncomp_size)) + return MZ_TRUE; + + /* Encryption and patch files are not supported. */ + if (file_stat.m_is_encrypted) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_ENCRYPTION); + + /* This function only supports stored and deflate. */ + if ((file_stat.m_method != 0) && (file_stat.m_method != MZ_DEFLATED)) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_METHOD); + + if (!file_stat.m_is_supported) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_FEATURE); + + /* Read and parse the local directory entry. */ + local_header_ofs = file_stat.m_local_header_ofs; + if (pZip->m_pRead(pZip->m_pIO_opaque, local_header_ofs, pLocal_header, MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + + if (MZ_READ_LE32(pLocal_header) != MZ_ZIP_LOCAL_DIR_HEADER_SIG) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + local_header_filename_len = MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_FILENAME_LEN_OFS); + local_header_extra_len = MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); + local_header_comp_size = MZ_READ_LE32(pLocal_header + MZ_ZIP_LDH_COMPRESSED_SIZE_OFS); + local_header_uncomp_size = MZ_READ_LE32(pLocal_header + MZ_ZIP_LDH_DECOMPRESSED_SIZE_OFS); + local_header_crc32 = MZ_READ_LE32(pLocal_header + MZ_ZIP_LDH_CRC32_OFS); + local_header_bit_flags = MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_BIT_FLAG_OFS); + has_data_descriptor = (local_header_bit_flags & 8) != 0; + + if (local_header_filename_len != strlen(file_stat.m_filename)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if ((local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + local_header_filename_len + local_header_extra_len + file_stat.m_comp_size) > pZip->m_archive_size) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if (!mz_zip_array_resize(pZip, &file_data_array, MZ_MAX(local_header_filename_len, local_header_extra_len), MZ_FALSE)) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + if (local_header_filename_len) + { + if (pZip->m_pRead(pZip->m_pIO_opaque, local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE, file_data_array.m_p, local_header_filename_len) != local_header_filename_len) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + goto handle_failure; + } + + /* I've seen 1 archive that had the same pathname, but used backslashes in the local dir and forward slashes in the central dir. Do we care about this? For now, this case will fail validation. */ + // NOLINTNEXTLINE(clang-analyzer-unix.cstring.NullArg) + if (memcmp(file_stat.m_filename, file_data_array.m_p, local_header_filename_len) != 0) + { + mz_zip_set_error(pZip, MZ_ZIP_VALIDATION_FAILED); + goto handle_failure; + } + } + + if ((local_header_extra_len) && ((local_header_comp_size == MZ_UINT32_MAX) || (local_header_uncomp_size == MZ_UINT32_MAX))) + { + mz_uint32 extra_size_remaining = local_header_extra_len; + const mz_uint8 *pExtra_data = (const mz_uint8 *)file_data_array.m_p; + + if (pZip->m_pRead(pZip->m_pIO_opaque, local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + local_header_filename_len, file_data_array.m_p, local_header_extra_len) != local_header_extra_len) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + goto handle_failure; + } + + do + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 field_id, field_data_size, field_total_size; + + if (extra_size_remaining < (sizeof(mz_uint16) * 2)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + // NOLINTNEXTLINE(clang-analyzer-core.NullDereference) + field_id = MZ_READ_LE16(pExtra_data); + field_data_size = MZ_READ_LE16(pExtra_data + sizeof(mz_uint16)); + field_total_size = field_data_size + sizeof(mz_uint16) * 2; + + if (field_total_size > extra_size_remaining) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if (field_id == MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID) + { + const mz_uint8 *pSrc_field_data = pExtra_data + sizeof(mz_uint32); + + if (field_data_size < sizeof(mz_uint64) * 2) + { + mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + goto handle_failure; + } + + local_header_uncomp_size = MZ_READ_LE64(pSrc_field_data); + local_header_comp_size = MZ_READ_LE64(pSrc_field_data + sizeof(mz_uint64)); + + found_zip64_ext_data_in_ldir = MZ_TRUE; + break; + } + + pExtra_data += field_total_size; + extra_size_remaining -= field_total_size; + } while (extra_size_remaining); + } + + /* TODO: parse local header extra data when local_header_comp_size is 0xFFFFFFFF! (big_descriptor.zip) */ + /* I've seen zips in the wild with the data descriptor bit set, but proper local header values and bogus data descriptors */ + if ((has_data_descriptor) && (!local_header_comp_size) && (!local_header_crc32)) + { + mz_uint8 descriptor_buf[32]; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_bool has_id; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const mz_uint8 *pSrc; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 file_crc32; + mz_uint64 comp_size = 0, uncomp_size = 0; + + mz_uint32 num_descriptor_uint32s = ((pState->m_zip64) || (found_zip64_ext_data_in_ldir)) ? 6 : 4; + + if (pZip->m_pRead(pZip->m_pIO_opaque, local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + local_header_filename_len + local_header_extra_len + file_stat.m_comp_size, descriptor_buf, sizeof(mz_uint32) * num_descriptor_uint32s) != (sizeof(mz_uint32) * num_descriptor_uint32s)) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + goto handle_failure; + } + + has_id = (MZ_READ_LE32(descriptor_buf) == MZ_ZIP_DATA_DESCRIPTOR_ID); + pSrc = has_id ? (descriptor_buf + sizeof(mz_uint32)) : descriptor_buf; + + file_crc32 = MZ_READ_LE32(pSrc); + + if ((pState->m_zip64) || (found_zip64_ext_data_in_ldir)) + { + comp_size = MZ_READ_LE64(pSrc + sizeof(mz_uint32)); + uncomp_size = MZ_READ_LE64(pSrc + sizeof(mz_uint32) + sizeof(mz_uint64)); + } + else + { + comp_size = MZ_READ_LE32(pSrc + sizeof(mz_uint32)); + uncomp_size = MZ_READ_LE32(pSrc + sizeof(mz_uint32) + sizeof(mz_uint32)); + } + + if ((file_crc32 != file_stat.m_crc32) || (comp_size != file_stat.m_comp_size) || (uncomp_size != file_stat.m_uncomp_size)) + { + mz_zip_set_error(pZip, MZ_ZIP_VALIDATION_FAILED); + goto handle_failure; + } + } + else + { + if ((local_header_crc32 != file_stat.m_crc32) || (local_header_comp_size != file_stat.m_comp_size) || (local_header_uncomp_size != file_stat.m_uncomp_size)) + { + mz_zip_set_error(pZip, MZ_ZIP_VALIDATION_FAILED); + goto handle_failure; + } + } + + mz_zip_array_clear(pZip, &file_data_array); + + if ((flags & MZ_ZIP_FLAG_VALIDATE_HEADERS_ONLY) == 0) + { + if (!mz_zip_reader_extract_to_callback(pZip, file_index, mz_zip_compute_crc32_callback, &uncomp_crc32, 0)) + return MZ_FALSE; + + /* 1 more check to be sure, although the extract checks too. */ + if (uncomp_crc32 != file_stat.m_crc32) + { + mz_zip_set_error(pZip, MZ_ZIP_VALIDATION_FAILED); + return MZ_FALSE; + } + } + + return MZ_TRUE; + +handle_failure: + mz_zip_array_clear(pZip, &file_data_array); + return MZ_FALSE; +} + +mz_bool mz_zip_validate_archive(mz_zip_archive *pZip, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_internal_state *pState; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint32_t i; + + if ((!pZip) || (!pZip->m_pState) || (!pZip->m_pAlloc) || (!pZip->m_pFree) || (!pZip->m_pRead)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pState = pZip->m_pState; + + /* Basic sanity checks */ + if (!pState->m_zip64) + { + if (pZip->m_total_files > MZ_UINT16_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + + if (pZip->m_archive_size > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + } + else + { + if (pZip->m_total_files >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + + if (pState->m_central_dir.m_size >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + } + + for (i = 0; i < pZip->m_total_files; i++) + { + if (MZ_ZIP_FLAG_VALIDATE_LOCATE_FILE_FLAG & flags) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 found_index; + mz_zip_archive_file_stat stat; + + if (!mz_zip_reader_file_stat(pZip, i, &stat)) + return MZ_FALSE; + + if (!mz_zip_reader_locate_file_v2(pZip, stat.m_filename, NULL, 0, &found_index)) + return MZ_FALSE; + + /* This check can fail if there are duplicate filenames in the archive (which we don't check for when writing - that's up to the user) */ + if (found_index != i) + return mz_zip_set_error(pZip, MZ_ZIP_VALIDATION_FAILED); + } + + if (!mz_zip_validate_file(pZip, i, flags)) + return MZ_FALSE; + } + + return MZ_TRUE; +} + +mz_bool mz_zip_validate_mem_archive(const void *pMem, size_t size, mz_uint flags, mz_zip_error *pErr) +{ + mz_bool success = MZ_TRUE; + mz_zip_archive zip; + mz_zip_error actual_err = MZ_ZIP_NO_ERROR; + + if ((!pMem) || (!size)) + { + if (pErr) + *pErr = MZ_ZIP_INVALID_PARAMETER; + return MZ_FALSE; + } + + mz_zip_zero_struct(&zip); + + if (!mz_zip_reader_init_mem(&zip, pMem, size, flags)) + { + if (pErr) + *pErr = zip.m_last_error; + return MZ_FALSE; + } + + if (!mz_zip_validate_archive(&zip, flags)) + { + actual_err = zip.m_last_error; + success = MZ_FALSE; + } + + if (!mz_zip_reader_end_internal(&zip, success)) + { + if (!actual_err) + actual_err = zip.m_last_error; + success = MZ_FALSE; + } + + if (pErr) + *pErr = actual_err; + + return success; +} + +#ifndef MINIZ_NO_STDIO +mz_bool mz_zip_validate_file_archive(const char *pFilename, mz_uint flags, mz_zip_error *pErr) +{ + mz_bool success = MZ_TRUE; + mz_zip_archive zip; + mz_zip_error actual_err = MZ_ZIP_NO_ERROR; + + if (!pFilename) + { + if (pErr) + *pErr = MZ_ZIP_INVALID_PARAMETER; + return MZ_FALSE; + } + + mz_zip_zero_struct(&zip); + + if (!mz_zip_reader_init_file_v2(&zip, pFilename, flags, 0, 0)) + { + if (pErr) + *pErr = zip.m_last_error; + return MZ_FALSE; + } + + if (!mz_zip_validate_archive(&zip, flags)) + { + actual_err = zip.m_last_error; + success = MZ_FALSE; + } + + if (!mz_zip_reader_end_internal(&zip, success)) + { + if (!actual_err) + actual_err = zip.m_last_error; + success = MZ_FALSE; + } + + if (pErr) + *pErr = actual_err; + + return success; +} +#endif /* #ifndef MINIZ_NO_STDIO */ + +/* ------------------- .ZIP archive writing */ + +#ifndef MINIZ_NO_ARCHIVE_WRITING_APIS + +static MZ_FORCEINLINE void mz_write_le16(mz_uint8 *p, mz_uint16 v) +{ + p[0] = (mz_uint8)v; + p[1] = (mz_uint8)(v >> 8); +} +static MZ_FORCEINLINE void mz_write_le32(mz_uint8 *p, mz_uint32 v) +{ + p[0] = (mz_uint8)v; + p[1] = (mz_uint8)(v >> 8); + p[2] = (mz_uint8)(v >> 16); + p[3] = (mz_uint8)(v >> 24); +} +static MZ_FORCEINLINE void mz_write_le64(mz_uint8 *p, mz_uint64 v) +{ + mz_write_le32(p, (mz_uint32)v); + mz_write_le32(p + sizeof(mz_uint32), (mz_uint32)(v >> 32)); +} + +#define MZ_WRITE_LE16(p, v) mz_write_le16((mz_uint8 *)(p), (mz_uint16)(v)) +#define MZ_WRITE_LE32(p, v) mz_write_le32((mz_uint8 *)(p), (mz_uint32)(v)) +#define MZ_WRITE_LE64(p, v) mz_write_le64((mz_uint8 *)(p), (mz_uint64)(v)) + +static size_t mz_zip_heap_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) +{ + mz_zip_archive *pZip = (mz_zip_archive *)pOpaque; + mz_zip_internal_state *pState = pZip->m_pState; + mz_uint64 new_size = MZ_MAX(file_ofs + n, pState->m_mem_size); + + if (!n) + return 0; + + /* An allocation this big is likely to just fail on 32-bit systems, so don't even go there. */ + if ((sizeof(size_t) == sizeof(mz_uint32)) && (new_size > 0x7FFFFFFF)) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_TOO_LARGE); + return 0; + } + + if (new_size > pState->m_mem_capacity) + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + void *pNew_block; + size_t new_capacity = MZ_MAX(64, pState->m_mem_capacity); + + while (new_capacity < new_size) + new_capacity *= 2; + + if (NULL == (pNew_block = pZip->m_pRealloc(pZip->m_pAlloc_opaque, pState->m_pMem, 1, new_capacity))) + { + mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + return 0; + } + + pState->m_pMem = pNew_block; + pState->m_mem_capacity = new_capacity; + } + memcpy((mz_uint8 *)pState->m_pMem + file_ofs, pBuf, n); + pState->m_mem_size = (size_t)new_size; + return n; +} + +static mz_bool mz_zip_writer_end_internal(mz_zip_archive *pZip, mz_bool set_last_error) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_internal_state *pState; + mz_bool status = MZ_TRUE; + + if ((!pZip) || (!pZip->m_pState) || (!pZip->m_pAlloc) || (!pZip->m_pFree) || ((pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) && (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING_HAS_BEEN_FINALIZED))) + { + if (set_last_error) + mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + return MZ_FALSE; + } + + pState = pZip->m_pState; + pZip->m_pState = NULL; + mz_zip_array_clear(pZip, &pState->m_central_dir); + mz_zip_array_clear(pZip, &pState->m_central_dir_offsets); + mz_zip_array_clear(pZip, &pState->m_sorted_central_dir_offsets); + +#ifndef MINIZ_NO_STDIO + if (pState->m_pFile) + { + if (pZip->m_zip_type == MZ_ZIP_TYPE_FILE) + { + if (MZ_FCLOSE(pState->m_pFile) == EOF) + { + if (set_last_error) + mz_zip_set_error(pZip, MZ_ZIP_FILE_CLOSE_FAILED); + status = MZ_FALSE; + } + } + + pState->m_pFile = NULL; + } +#endif /* #ifndef MINIZ_NO_STDIO */ + + if ((pZip->m_pWrite == mz_zip_heap_write_func) && (pState->m_pMem)) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pState->m_pMem); + pState->m_pMem = NULL; + } + + pZip->m_pFree(pZip->m_pAlloc_opaque, pState); + pZip->m_zip_mode = MZ_ZIP_MODE_INVALID; + return status; +} + +mz_bool mz_zip_writer_init_v2(mz_zip_archive *pZip, mz_uint64 existing_size, mz_uint flags) +{ + mz_bool zip64 = (flags & MZ_ZIP_FLAG_WRITE_ZIP64) != 0; + + if ((!pZip) || (pZip->m_pState) || (!pZip->m_pWrite) || (pZip->m_zip_mode != MZ_ZIP_MODE_INVALID)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (flags & MZ_ZIP_FLAG_WRITE_ALLOW_READING) + { + if (!pZip->m_pRead) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + } + + if (pZip->m_file_offset_alignment) + { + /* Ensure user specified file offset alignment is a power of 2. */ + if (pZip->m_file_offset_alignment & (pZip->m_file_offset_alignment - 1)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + } + + if (!pZip->m_pAlloc) + pZip->m_pAlloc = miniz_def_alloc_func; + if (!pZip->m_pFree) + pZip->m_pFree = miniz_def_free_func; + if (!pZip->m_pRealloc) + pZip->m_pRealloc = miniz_def_realloc_func; + + pZip->m_archive_size = existing_size; + pZip->m_central_directory_file_ofs = 0; + pZip->m_total_files = 0; + + if (NULL == (pZip->m_pState = (mz_zip_internal_state *)pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, sizeof(mz_zip_internal_state)))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + memset(pZip->m_pState, 0, sizeof(mz_zip_internal_state)); + + MZ_ZIP_ARRAY_SET_ELEMENT_SIZE(&pZip->m_pState->m_central_dir, sizeof(mz_uint8)); + MZ_ZIP_ARRAY_SET_ELEMENT_SIZE(&pZip->m_pState->m_central_dir_offsets, sizeof(mz_uint32)); + MZ_ZIP_ARRAY_SET_ELEMENT_SIZE(&pZip->m_pState->m_sorted_central_dir_offsets, sizeof(mz_uint32)); + + pZip->m_pState->m_zip64 = zip64; + pZip->m_pState->m_zip64_has_extended_info_fields = zip64; + + pZip->m_zip_type = MZ_ZIP_TYPE_USER; + pZip->m_zip_mode = MZ_ZIP_MODE_WRITING; + + return MZ_TRUE; +} + +mz_bool mz_zip_writer_init(mz_zip_archive *pZip, mz_uint64 existing_size) +{ + return mz_zip_writer_init_v2(pZip, existing_size, 0); +} + +mz_bool mz_zip_writer_init_heap_v2(mz_zip_archive *pZip, size_t size_to_reserve_at_beginning, size_t initial_allocation_size, mz_uint flags) +{ + pZip->m_pWrite = mz_zip_heap_write_func; + pZip->m_pNeeds_keepalive = NULL; + + if (flags & MZ_ZIP_FLAG_WRITE_ALLOW_READING) + pZip->m_pRead = mz_zip_mem_read_func; + + pZip->m_pIO_opaque = pZip; + + if (!mz_zip_writer_init_v2(pZip, size_to_reserve_at_beginning, flags)) + return MZ_FALSE; + + pZip->m_zip_type = MZ_ZIP_TYPE_HEAP; + + if (0 != (initial_allocation_size = MZ_MAX(initial_allocation_size, size_to_reserve_at_beginning))) + { + if (NULL == (pZip->m_pState->m_pMem = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, initial_allocation_size))) + { + mz_zip_writer_end_internal(pZip, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + pZip->m_pState->m_mem_capacity = initial_allocation_size; + } + + return MZ_TRUE; +} + +mz_bool mz_zip_writer_init_heap(mz_zip_archive *pZip, size_t size_to_reserve_at_beginning, size_t initial_allocation_size) +{ + return mz_zip_writer_init_heap_v2(pZip, size_to_reserve_at_beginning, initial_allocation_size, 0); +} + +#ifndef MINIZ_NO_STDIO +static size_t mz_zip_file_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n) +{ + mz_zip_archive *pZip = (mz_zip_archive *)pOpaque; + mz_int64 cur_ofs = MZ_FTELL64(pZip->m_pState->m_pFile); + + file_ofs += pZip->m_pState->m_file_archive_start_ofs; + + if (((mz_int64)file_ofs < 0) || (((cur_ofs != (mz_int64)file_ofs)) && (MZ_FSEEK64(pZip->m_pState->m_pFile, (mz_int64)file_ofs, SEEK_SET)))) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_SEEK_FAILED); + return 0; + } + + return MZ_FWRITE(pBuf, 1, n, pZip->m_pState->m_pFile); +} + +mz_bool mz_zip_writer_init_file(mz_zip_archive *pZip, const char *pFilename, mz_uint64 size_to_reserve_at_beginning) +{ + return mz_zip_writer_init_file_v2(pZip, pFilename, size_to_reserve_at_beginning, 0); +} + +mz_bool mz_zip_writer_init_file_v2(mz_zip_archive *pZip, const char *pFilename, mz_uint64 size_to_reserve_at_beginning, mz_uint flags) +{ + MZ_FILE *pFile; + + pZip->m_pWrite = mz_zip_file_write_func; + pZip->m_pNeeds_keepalive = NULL; + + if (flags & MZ_ZIP_FLAG_WRITE_ALLOW_READING) + pZip->m_pRead = mz_zip_file_read_func; + + pZip->m_pIO_opaque = pZip; + + if (!mz_zip_writer_init_v2(pZip, size_to_reserve_at_beginning, flags)) + return MZ_FALSE; + + if (NULL == (pFile = MZ_FOPEN(pFilename, (flags & MZ_ZIP_FLAG_WRITE_ALLOW_READING) ? "w+b" : "wb"))) + { + mz_zip_writer_end(pZip); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_OPEN_FAILED); + } + + pZip->m_pState->m_pFile = pFile; + pZip->m_zip_type = MZ_ZIP_TYPE_FILE; + + if (size_to_reserve_at_beginning) + { + mz_uint64 cur_ofs = 0; + char buf[4096]; + + MZ_CLEAR_OBJ(buf); + + do + { + size_t n = (size_t)MZ_MIN(sizeof(buf), size_to_reserve_at_beginning); + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_ofs, buf, n) != n) + { + mz_zip_writer_end(pZip); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + cur_ofs += n; + size_to_reserve_at_beginning -= n; + } while (size_to_reserve_at_beginning); + } + + return MZ_TRUE; +} + +mz_bool mz_zip_writer_init_cfile(mz_zip_archive *pZip, MZ_FILE *pFile, mz_uint flags) +{ + pZip->m_pWrite = mz_zip_file_write_func; + pZip->m_pNeeds_keepalive = NULL; + + if (flags & MZ_ZIP_FLAG_WRITE_ALLOW_READING) + pZip->m_pRead = mz_zip_file_read_func; + + pZip->m_pIO_opaque = pZip; + + if (!mz_zip_writer_init_v2(pZip, 0, flags)) + return MZ_FALSE; + + pZip->m_pState->m_pFile = pFile; + pZip->m_pState->m_file_archive_start_ofs = MZ_FTELL64(pZip->m_pState->m_pFile); + pZip->m_zip_type = MZ_ZIP_TYPE_CFILE; + + return MZ_TRUE; +} +#endif /* #ifndef MINIZ_NO_STDIO */ + +mz_bool mz_zip_writer_init_from_reader_v2(mz_zip_archive *pZip, const char *pFilename, mz_uint flags) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_internal_state *pState; + + if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_READING)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (flags & MZ_ZIP_FLAG_WRITE_ZIP64) + { + /* We don't support converting a non-zip64 file to zip64 - this seems like more trouble than it's worth. (What about the existing 32-bit data descriptors that could follow the compressed data?) */ + if (!pZip->m_pState->m_zip64) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + } + + /* No sense in trying to write to an archive that's already at the support max size */ + if (pZip->m_pState->m_zip64) + { + if (pZip->m_total_files == MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + } + else + { + if (pZip->m_total_files == MZ_UINT16_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + + if ((pZip->m_archive_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + MZ_ZIP_LOCAL_DIR_HEADER_SIZE) > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_TOO_LARGE); + } + + pState = pZip->m_pState; + + if (pState->m_pFile) + { +#ifdef MINIZ_NO_STDIO + (void)pFilename; + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); +#else + if (pZip->m_pIO_opaque != pZip) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (pZip->m_zip_type == MZ_ZIP_TYPE_FILE) + { + if (!pFilename) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + /* Archive is being read from stdio and was originally opened only for reading. Try to reopen as writable. */ + if (NULL == (pState->m_pFile = MZ_FREOPEN(pFilename, "r+b", pState->m_pFile))) + { + /* The mz_zip_archive is now in a bogus state because pState->m_pFile is NULL, so just close it. */ + mz_zip_reader_end_internal(pZip, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_OPEN_FAILED); + } + } + + pZip->m_pWrite = mz_zip_file_write_func; + pZip->m_pNeeds_keepalive = NULL; +#endif /* #ifdef MINIZ_NO_STDIO */ + } + else if (pState->m_pMem) + { + /* Archive lives in a memory block. Assume it's from the heap that we can resize using the realloc callback. */ + if (pZip->m_pIO_opaque != pZip) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pState->m_mem_capacity = pState->m_mem_size; + pZip->m_pWrite = mz_zip_heap_write_func; + pZip->m_pNeeds_keepalive = NULL; + } + /* Archive is being read via a user provided read function - make sure the user has specified a write function too. */ + else if (!pZip->m_pWrite) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + /* Start writing new files at the archive's current central directory location. */ + /* TODO: We could add a flag that lets the user start writing immediately AFTER the existing central dir - this would be safer. */ + pZip->m_archive_size = pZip->m_central_directory_file_ofs; + pZip->m_central_directory_file_ofs = 0; + + /* Clear the sorted central dir offsets, they aren't useful or maintained now. */ + /* Even though we're now in write mode, files can still be extracted and verified, but file locates will be slow. */ + /* TODO: We could easily maintain the sorted central directory offsets. */ + mz_zip_array_clear(pZip, &pZip->m_pState->m_sorted_central_dir_offsets); + + pZip->m_zip_mode = MZ_ZIP_MODE_WRITING; + + return MZ_TRUE; +} + +mz_bool mz_zip_writer_init_from_reader(mz_zip_archive *pZip, const char *pFilename) +{ + return mz_zip_writer_init_from_reader_v2(pZip, pFilename, 0); +} + +/* TODO: pArchive_name is a terrible name here! */ +mz_bool mz_zip_writer_add_mem(mz_zip_archive *pZip, const char *pArchive_name, const void *pBuf, size_t buf_size, mz_uint level_and_flags) +{ + return mz_zip_writer_add_mem_ex(pZip, pArchive_name, pBuf, buf_size, NULL, 0, level_and_flags, 0, 0); +} + +typedef struct +{ + mz_zip_archive *m_pZip; + mz_uint64 m_cur_archive_file_ofs; + mz_uint64 m_comp_size; +} mz_zip_writer_add_state; + +static mz_bool mz_zip_writer_add_put_buf_callback(const void *pBuf, int len, void *pUser) +{ + mz_zip_writer_add_state *pState = (mz_zip_writer_add_state *)pUser; + if ((int)pState->m_pZip->m_pWrite(pState->m_pZip->m_pIO_opaque, pState->m_cur_archive_file_ofs, pBuf, len) != len) + return MZ_FALSE; + + pState->m_cur_archive_file_ofs += len; + pState->m_comp_size += len; + return MZ_TRUE; +} + +#define MZ_ZIP64_MAX_LOCAL_EXTRA_FIELD_SIZE (sizeof(mz_uint16) * 2 + sizeof(mz_uint64) * 2) +#define MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE (sizeof(mz_uint16) * 2 + sizeof(mz_uint64) * 3) +static mz_uint32 mz_zip_writer_create_zip64_extra_data(mz_uint8 *pBuf, mz_uint64 *pUncomp_size, mz_uint64 *pComp_size, mz_uint64 *pLocal_header_ofs) +{ + mz_uint8 *pDst = pBuf; + mz_uint32 field_size = 0; + + MZ_WRITE_LE16(pDst + 0, MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID); + MZ_WRITE_LE16(pDst + 2, 0); + pDst += sizeof(mz_uint16) * 2; + + if (pUncomp_size) + { + MZ_WRITE_LE64(pDst, *pUncomp_size); + pDst += sizeof(mz_uint64); + field_size += sizeof(mz_uint64); + } + + if (pComp_size) + { + MZ_WRITE_LE64(pDst, *pComp_size); + pDst += sizeof(mz_uint64); + field_size += sizeof(mz_uint64); + } + + if (pLocal_header_ofs) + { + MZ_WRITE_LE64(pDst, *pLocal_header_ofs); + pDst += sizeof(mz_uint64); + field_size += sizeof(mz_uint64); + } + + MZ_WRITE_LE16(pBuf + 2, field_size); + + return (mz_uint32)(pDst - pBuf); +} + +static mz_bool mz_zip_writer_create_local_dir_header(mz_zip_archive *pZip, mz_uint8 *pDst, mz_uint16 filename_size, mz_uint16 extra_size, mz_uint64 uncomp_size, mz_uint64 comp_size, mz_uint32 uncomp_crc32, mz_uint16 method, mz_uint16 bit_flags, mz_uint16 dos_time, mz_uint16 dos_date) +{ + (void)pZip; + memset(pDst, 0, MZ_ZIP_LOCAL_DIR_HEADER_SIZE); + MZ_WRITE_LE32(pDst + MZ_ZIP_LDH_SIG_OFS, MZ_ZIP_LOCAL_DIR_HEADER_SIG); + MZ_WRITE_LE16(pDst + MZ_ZIP_LDH_VERSION_NEEDED_OFS, method ? 20 : 0); + MZ_WRITE_LE16(pDst + MZ_ZIP_LDH_BIT_FLAG_OFS, bit_flags); + MZ_WRITE_LE16(pDst + MZ_ZIP_LDH_METHOD_OFS, method); + MZ_WRITE_LE16(pDst + MZ_ZIP_LDH_FILE_TIME_OFS, dos_time); + MZ_WRITE_LE16(pDst + MZ_ZIP_LDH_FILE_DATE_OFS, dos_date); + MZ_WRITE_LE32(pDst + MZ_ZIP_LDH_CRC32_OFS, uncomp_crc32); + MZ_WRITE_LE32(pDst + MZ_ZIP_LDH_COMPRESSED_SIZE_OFS, MZ_MIN(comp_size, MZ_UINT32_MAX)); + MZ_WRITE_LE32(pDst + MZ_ZIP_LDH_DECOMPRESSED_SIZE_OFS, MZ_MIN(uncomp_size, MZ_UINT32_MAX)); + MZ_WRITE_LE16(pDst + MZ_ZIP_LDH_FILENAME_LEN_OFS, filename_size); + MZ_WRITE_LE16(pDst + MZ_ZIP_LDH_EXTRA_LEN_OFS, extra_size); + return MZ_TRUE; +} + +static mz_bool mz_zip_writer_create_central_dir_header(mz_zip_archive *pZip, mz_uint8 *pDst, + mz_uint16 filename_size, mz_uint16 extra_size, mz_uint16 comment_size, + mz_uint64 uncomp_size, mz_uint64 comp_size, mz_uint32 uncomp_crc32, + mz_uint16 method, mz_uint16 bit_flags, mz_uint16 dos_time, mz_uint16 dos_date, + mz_uint64 local_header_ofs, mz_uint32 ext_attributes) +{ + (void)pZip; + memset(pDst, 0, MZ_ZIP_CENTRAL_DIR_HEADER_SIZE); + MZ_WRITE_LE32(pDst + MZ_ZIP_CDH_SIG_OFS, MZ_ZIP_CENTRAL_DIR_HEADER_SIG); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_VERSION_NEEDED_OFS, method ? 20 : 0); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_BIT_FLAG_OFS, bit_flags); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_METHOD_OFS, method); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_FILE_TIME_OFS, dos_time); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_FILE_DATE_OFS, dos_date); + MZ_WRITE_LE32(pDst + MZ_ZIP_CDH_CRC32_OFS, uncomp_crc32); + MZ_WRITE_LE32(pDst + MZ_ZIP_CDH_COMPRESSED_SIZE_OFS, MZ_MIN(comp_size, MZ_UINT32_MAX)); + MZ_WRITE_LE32(pDst + MZ_ZIP_CDH_DECOMPRESSED_SIZE_OFS, MZ_MIN(uncomp_size, MZ_UINT32_MAX)); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_FILENAME_LEN_OFS, filename_size); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_EXTRA_LEN_OFS, extra_size); + MZ_WRITE_LE16(pDst + MZ_ZIP_CDH_COMMENT_LEN_OFS, comment_size); + MZ_WRITE_LE32(pDst + MZ_ZIP_CDH_EXTERNAL_ATTR_OFS, ext_attributes); + MZ_WRITE_LE32(pDst + MZ_ZIP_CDH_LOCAL_HEADER_OFS, MZ_MIN(local_header_ofs, MZ_UINT32_MAX)); + return MZ_TRUE; +} + +static mz_bool mz_zip_writer_add_to_central_dir(mz_zip_archive *pZip, const char *pFilename, mz_uint16 filename_size, + const void *pExtra, mz_uint16 extra_size, const void *pComment, mz_uint16 comment_size, + mz_uint64 uncomp_size, mz_uint64 comp_size, mz_uint32 uncomp_crc32, + mz_uint16 method, mz_uint16 bit_flags, mz_uint16 dos_time, mz_uint16 dos_date, + mz_uint64 local_header_ofs, mz_uint32 ext_attributes, + const char *user_extra_data, mz_uint user_extra_data_len) +{ + mz_zip_internal_state *pState = pZip->m_pState; + mz_uint32 central_dir_ofs = (mz_uint32)pState->m_central_dir.m_size; + size_t orig_central_dir_size = pState->m_central_dir.m_size; + mz_uint8 central_dir_header[MZ_ZIP_CENTRAL_DIR_HEADER_SIZE]; + + if (!pZip->m_pState->m_zip64) + { + if (local_header_ofs > 0xFFFFFFFF) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_TOO_LARGE); + } + + /* miniz doesn't support central dirs >= MZ_UINT32_MAX bytes yet */ + if (((mz_uint64)pState->m_central_dir.m_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + filename_size + extra_size + user_extra_data_len + comment_size) >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_CDIR_SIZE); + + if (!mz_zip_writer_create_central_dir_header(pZip, central_dir_header, filename_size, (mz_uint16)(extra_size + user_extra_data_len), comment_size, uncomp_size, comp_size, uncomp_crc32, method, bit_flags, dos_time, dos_date, local_header_ofs, ext_attributes)) + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + + if ((!mz_zip_array_push_back(pZip, &pState->m_central_dir, central_dir_header, MZ_ZIP_CENTRAL_DIR_HEADER_SIZE)) || + (!mz_zip_array_push_back(pZip, &pState->m_central_dir, pFilename, filename_size)) || + (!mz_zip_array_push_back(pZip, &pState->m_central_dir, pExtra, extra_size)) || + (!mz_zip_array_push_back(pZip, &pState->m_central_dir, user_extra_data, user_extra_data_len)) || + (!mz_zip_array_push_back(pZip, &pState->m_central_dir, pComment, comment_size)) || + (!mz_zip_array_push_back(pZip, &pState->m_central_dir_offsets, ¢ral_dir_ofs, 1))) + { + /* Try to resize the central directory array back into its original state. */ + mz_zip_array_resize(pZip, &pState->m_central_dir, orig_central_dir_size, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + return MZ_TRUE; +} + +static mz_bool mz_zip_writer_validate_archive_name(const char *pArchive_name) +{ + /* Basic ZIP archive filename validity checks: Valid filenames cannot start with a forward slash, cannot contain a drive letter, and cannot use DOS-style backward slashes. */ + if (*pArchive_name == '/') + return MZ_FALSE; + + /* Making sure the name does not contain drive letters or DOS style backward slashes is the responsibility of the program using miniz*/ + + return MZ_TRUE; +} + +static mz_uint mz_zip_writer_compute_padding_needed_for_file_alignment(mz_zip_archive *pZip) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 n; + if (!pZip->m_file_offset_alignment) + return 0; + n = (mz_uint32)(pZip->m_archive_size & (pZip->m_file_offset_alignment - 1)); + return (mz_uint)((pZip->m_file_offset_alignment - n) & (pZip->m_file_offset_alignment - 1)); +} + +static mz_bool mz_zip_writer_write_zeros(mz_zip_archive *pZip, mz_uint64 cur_file_ofs, mz_uint32 n) +{ + char buf[4096]; + memset(buf, 0, MZ_MIN(sizeof(buf), n)); + while (n) + { + mz_uint32 s = MZ_MIN(sizeof(buf), n); + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_file_ofs, buf, s) != s) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_file_ofs += s; + n -= s; + } + return MZ_TRUE; +} + +mz_bool mz_zip_writer_add_mem_ex(mz_zip_archive *pZip, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, + mz_uint64 uncomp_size, mz_uint32 uncomp_crc32) +{ + return mz_zip_writer_add_mem_ex_v2(pZip, pArchive_name, pBuf, buf_size, pComment, comment_size, level_and_flags, uncomp_size, uncomp_crc32, NULL, NULL, 0, NULL, 0); +} + +mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, + mz_uint level_and_flags, mz_uint64 uncomp_size, mz_uint32 uncomp_crc32, MZ_TIME_T *last_modified, + const char *user_extra_data, mz_uint user_extra_data_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len) +{ + mz_uint16 method = 0, dos_time = 0, dos_date = 0; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint level, ext_attributes = 0, num_alignment_padding_bytes; + mz_uint64 local_dir_header_ofs = pZip->m_archive_size, cur_archive_file_ofs = pZip->m_archive_size, comp_size = 0; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t archive_name_size; + mz_uint8 local_dir_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE]; + tdefl_compressor *pComp = NULL; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_bool store_data_uncompressed; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_internal_state *pState; + mz_uint8 *pExtra_data = NULL; + mz_uint32 extra_size = 0; + mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE]; + mz_uint16 bit_flags = 0; + + if ((int)level_and_flags < 0) + level_and_flags = MZ_DEFAULT_LEVEL; + + if (uncomp_size || (buf_size && !(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA))) + bit_flags |= MZ_ZIP_LDH_BIT_FLAG_HAS_LOCATOR; + + if (!(level_and_flags & MZ_ZIP_FLAG_ASCII_FILENAME)) + bit_flags |= MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_UTF8; + + level = level_and_flags & 0xF; + store_data_uncompressed = ((!level) || (level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA)); + + if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || ((buf_size) && (!pBuf)) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pState = pZip->m_pState; + + if (pState->m_zip64) + { + if (pZip->m_total_files == MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + } + else + { + if (pZip->m_total_files == MZ_UINT16_MAX) + { + pState->m_zip64 = MZ_TRUE; + /*return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); */ + } + if ((buf_size > 0xFFFFFFFF) || (uncomp_size > 0xFFFFFFFF)) + { + pState->m_zip64 = MZ_TRUE; + /*return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); */ + } + } + + if ((!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) && (uncomp_size)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (!mz_zip_writer_validate_archive_name(pArchive_name)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_FILENAME); + +#ifndef MINIZ_NO_TIME + if (last_modified != NULL) + { + mz_zip_time_t_to_dos_time(*last_modified, &dos_time, &dos_date); + } + else + { + MZ_TIME_T cur_time; + time(&cur_time); + mz_zip_time_t_to_dos_time(cur_time, &dos_time, &dos_date); + } +#endif /* #ifndef MINIZ_NO_TIME */ + + if (!(level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) + { + // uncomp_crc32 = (mz_uint32)mz_crc32(MZ_CRC32_INIT, (const mz_uint8 *)pBuf, buf_size); + uncomp_size = buf_size; + if (uncomp_size <= 3) + { + level = 0; + store_data_uncompressed = MZ_TRUE; + } + } + + archive_name_size = strlen(pArchive_name); + if (archive_name_size > MZ_UINT16_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_FILENAME); + + num_alignment_padding_bytes = mz_zip_writer_compute_padding_needed_for_file_alignment(pZip); + + /* miniz doesn't support central dirs >= MZ_UINT32_MAX bytes yet */ + if (((mz_uint64)pState->m_central_dir.m_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE + comment_size) >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_CDIR_SIZE); + + if (!pState->m_zip64) + { + /* Bail early if the archive would obviously become too large */ + if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size + + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + user_extra_data_len + + pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + user_extra_data_central_len + + MZ_ZIP_DATA_DESCRIPTER_SIZE32) > 0xFFFFFFFF) + { + pState->m_zip64 = MZ_TRUE; + /*return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); */ + } + } + + if ((archive_name_size) && (pArchive_name[archive_name_size - 1] == '/')) + { + /* Set DOS Subdirectory attribute bit. */ + ext_attributes |= MZ_ZIP_DOS_DIR_ATTRIBUTE_BITFLAG; + + /* Subdirectories cannot contain data. */ + if ((buf_size) || (uncomp_size)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + } + + /* Try to do any allocations before writing to the archive, so if an allocation fails the file remains unmodified. (A good idea if we're doing an in-place modification.) */ + if ((!mz_zip_array_ensure_room(pZip, &pState->m_central_dir, MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + comment_size + (pState->m_zip64 ? MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE : 0))) || (!mz_zip_array_ensure_room(pZip, &pState->m_central_dir_offsets, 1))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + if ((!store_data_uncompressed) && (buf_size)) + { + if (NULL == (pComp = (tdefl_compressor *)pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, sizeof(tdefl_compressor)))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if (!mz_zip_writer_write_zeros(pZip, cur_archive_file_ofs, num_alignment_padding_bytes)) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + return MZ_FALSE; + } + + local_dir_header_ofs += num_alignment_padding_bytes; + if (pZip->m_file_offset_alignment) + { + MZ_ASSERT((local_dir_header_ofs & (pZip->m_file_offset_alignment - 1)) == 0); + } + cur_archive_file_ofs += num_alignment_padding_bytes; + + MZ_CLEAR_OBJ(local_dir_header); + + if (!store_data_uncompressed || (level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA)) + { + method = MZ_DEFLATED; + } + + if (pState->m_zip64) + { + if (uncomp_size >= MZ_UINT32_MAX || local_dir_header_ofs >= MZ_UINT32_MAX) + { + pExtra_data = extra_data; + extra_size = mz_zip_writer_create_zip64_extra_data(extra_data, (uncomp_size >= MZ_UINT32_MAX) ? &uncomp_size : NULL, + (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); + } + + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)(extra_size + user_extra_data_len), 0, 0, 0, method, bit_flags, dos_time, dos_date)) + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, local_dir_header_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += sizeof(local_dir_header); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pArchive_name, archive_name_size) != archive_name_size) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + cur_archive_file_ofs += archive_name_size; + + if (pExtra_data != NULL) + { + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, extra_data, extra_size) != extra_size) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += extra_size; + } + } + else + { + if ((comp_size > MZ_UINT32_MAX) || (cur_archive_file_ofs > MZ_UINT32_MAX)) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)user_extra_data_len, 0, 0, 0, method, bit_flags, dos_time, dos_date)) + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, local_dir_header_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += sizeof(local_dir_header); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pArchive_name, archive_name_size) != archive_name_size) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + cur_archive_file_ofs += archive_name_size; + } + + if (user_extra_data_len > 0) + { + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, user_extra_data, user_extra_data_len) != user_extra_data_len) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += user_extra_data_len; + } + + if (store_data_uncompressed) + { + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pBuf, buf_size) != buf_size) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + + cur_archive_file_ofs += buf_size; + comp_size = buf_size; + } + else if (buf_size) + { + mz_zip_writer_add_state state; + + state.m_pZip = pZip; + state.m_cur_archive_file_ofs = cur_archive_file_ofs; + state.m_comp_size = 0; + + if ((tdefl_init(pComp, mz_zip_writer_add_put_buf_callback, &state, tdefl_create_comp_flags_from_zip_params(level, -15, MZ_DEFAULT_STRATEGY)) != TDEFL_STATUS_OKAY) || + (tdefl_compress_buffer(pComp, pBuf, buf_size, TDEFL_FINISH) != TDEFL_STATUS_DONE)) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + return mz_zip_set_error(pZip, MZ_ZIP_COMPRESSION_FAILED); + } + + comp_size = state.m_comp_size; + cur_archive_file_ofs = state.m_cur_archive_file_ofs; + } + + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + pComp = NULL; + + if (uncomp_size) + { + mz_uint8 local_dir_footer[MZ_ZIP_DATA_DESCRIPTER_SIZE64]; + mz_uint32 local_dir_footer_size = MZ_ZIP_DATA_DESCRIPTER_SIZE32; + + MZ_ASSERT(bit_flags & MZ_ZIP_LDH_BIT_FLAG_HAS_LOCATOR); + + MZ_WRITE_LE32(local_dir_footer + 0, MZ_ZIP_DATA_DESCRIPTOR_ID); + MZ_WRITE_LE32(local_dir_footer + 4, uncomp_crc32); + if (pExtra_data == NULL) + { + if (comp_size > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + + MZ_WRITE_LE32(local_dir_footer + 8, comp_size); + MZ_WRITE_LE32(local_dir_footer + 12, uncomp_size); + } + else + { + MZ_WRITE_LE64(local_dir_footer + 8, comp_size); + MZ_WRITE_LE64(local_dir_footer + 16, uncomp_size); + local_dir_footer_size = MZ_ZIP_DATA_DESCRIPTER_SIZE64; + } + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, local_dir_footer, local_dir_footer_size) != local_dir_footer_size) + return MZ_FALSE; + + cur_archive_file_ofs += local_dir_footer_size; + } + + if (pExtra_data != NULL) + { + extra_size = mz_zip_writer_create_zip64_extra_data(extra_data, (uncomp_size >= MZ_UINT32_MAX) ? &uncomp_size : NULL, + (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); + } + + if (!mz_zip_writer_add_to_central_dir(pZip, pArchive_name, (mz_uint16)archive_name_size, pExtra_data, (mz_uint16)extra_size, pComment, + comment_size, uncomp_size, comp_size, uncomp_crc32, method, bit_flags, dos_time, dos_date, local_dir_header_ofs, ext_attributes, + user_extra_data_central, user_extra_data_central_len)) + return MZ_FALSE; + + pZip->m_total_files++; + pZip->m_archive_size = cur_archive_file_ofs; + + return MZ_TRUE; +} + +mz_bool mz_zip_writer_add_read_buf_callback(mz_zip_archive *pZip, const char *pArchive_name, mz_file_read_func read_callback, void* callback_opaque, mz_uint64 size_to_add, const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, + const char *user_extra_data, mz_uint user_extra_data_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len) +{ + mz_uint16 gen_flags = MZ_ZIP_LDH_BIT_FLAG_HAS_LOCATOR; + mz_uint uncomp_crc32 = MZ_CRC32_INIT, level, num_alignment_padding_bytes; + mz_uint16 method = 0, dos_time = 0, dos_date = 0, ext_attributes = 0; + mz_uint64 local_dir_header_ofs, cur_archive_file_ofs = pZip->m_archive_size, uncomp_size = size_to_add, comp_size = 0; + size_t archive_name_size; + mz_uint8 local_dir_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE]; + mz_uint8 *pExtra_data = NULL; + mz_uint32 extra_size = 0; + mz_uint8 extra_data[MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE]; + mz_zip_internal_state *pState; + mz_uint64 file_ofs = 0; + + if (!(level_and_flags & MZ_ZIP_FLAG_ASCII_FILENAME)) + gen_flags |= MZ_ZIP_GENERAL_PURPOSE_BIT_FLAG_UTF8; + + if ((int)level_and_flags < 0) + level_and_flags = MZ_DEFAULT_LEVEL; + level = level_and_flags & 0xF; + + /* Sanity checks */ + if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || (!pArchive_name) || ((comment_size) && (!pComment)) || (level > MZ_UBER_COMPRESSION)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pState = pZip->m_pState; + + if ((!pState->m_zip64) && (uncomp_size > MZ_UINT32_MAX)) + { + /* Source file is too large for non-zip64 */ + /*return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); */ + pState->m_zip64 = MZ_TRUE; + } + + /* We could support this, but why? */ + if (level_and_flags & MZ_ZIP_FLAG_COMPRESSED_DATA) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (!mz_zip_writer_validate_archive_name(pArchive_name)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_FILENAME); + + if (pState->m_zip64) + { + if (pZip->m_total_files == MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + } + else + { + if (pZip->m_total_files == MZ_UINT16_MAX) + { + pState->m_zip64 = MZ_TRUE; + /*return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); */ + } + } + + archive_name_size = strlen(pArchive_name); + if (archive_name_size > MZ_UINT16_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_FILENAME); + + num_alignment_padding_bytes = mz_zip_writer_compute_padding_needed_for_file_alignment(pZip); + + /* miniz doesn't support central dirs >= MZ_UINT32_MAX bytes yet */ + if (((mz_uint64)pState->m_central_dir.m_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + archive_name_size + MZ_ZIP64_MAX_CENTRAL_EXTRA_FIELD_SIZE + comment_size) >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_CDIR_SIZE); + + if (!pState->m_zip64) + { + /* Bail early if the archive would obviously become too large */ + if ((pZip->m_archive_size + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + archive_name_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + + archive_name_size + comment_size + user_extra_data_len + pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + 1024 + + MZ_ZIP_DATA_DESCRIPTER_SIZE32 + user_extra_data_central_len) > 0xFFFFFFFF) + { + pState->m_zip64 = MZ_TRUE; + /*return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); */ + } + } + +#ifndef MINIZ_NO_TIME + if (pFile_time) + { + mz_zip_time_t_to_dos_time(*pFile_time, &dos_time, &dos_date); + } +#endif + + if (uncomp_size <= 3) + level = 0; + + if (!mz_zip_writer_write_zeros(pZip, cur_archive_file_ofs, num_alignment_padding_bytes)) + { + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + + cur_archive_file_ofs += num_alignment_padding_bytes; + local_dir_header_ofs = cur_archive_file_ofs; + + if (pZip->m_file_offset_alignment) + { + MZ_ASSERT((cur_archive_file_ofs & (pZip->m_file_offset_alignment - 1)) == 0); + } + + if (uncomp_size && level) + { + method = MZ_DEFLATED; + } + + MZ_CLEAR_OBJ(local_dir_header); + if (pState->m_zip64) + { + if (uncomp_size >= MZ_UINT32_MAX || local_dir_header_ofs >= MZ_UINT32_MAX) + { + pExtra_data = extra_data; + extra_size = mz_zip_writer_create_zip64_extra_data(extra_data, (uncomp_size >= MZ_UINT32_MAX) ? &uncomp_size : NULL, + (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); + } + + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)(extra_size + user_extra_data_len), 0, 0, 0, method, gen_flags, dos_time, dos_date)) + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += sizeof(local_dir_header); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pArchive_name, archive_name_size) != archive_name_size) + { + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + + cur_archive_file_ofs += archive_name_size; + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, extra_data, extra_size) != extra_size) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += extra_size; + } + else + { + if ((comp_size > MZ_UINT32_MAX) || (cur_archive_file_ofs > MZ_UINT32_MAX)) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + if (!mz_zip_writer_create_local_dir_header(pZip, local_dir_header, (mz_uint16)archive_name_size, (mz_uint16)user_extra_data_len, 0, 0, 0, method, gen_flags, dos_time, dos_date)) + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, local_dir_header, sizeof(local_dir_header)) != sizeof(local_dir_header)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += sizeof(local_dir_header); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pArchive_name, archive_name_size) != archive_name_size) + { + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + + cur_archive_file_ofs += archive_name_size; + } + + if (user_extra_data_len > 0) + { + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, user_extra_data, user_extra_data_len) != user_extra_data_len) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_archive_file_ofs += user_extra_data_len; + } + + if (uncomp_size) + { + mz_uint64 uncomp_remaining = uncomp_size; + void *pRead_buf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, MZ_ZIP_MAX_IO_BUF_SIZE); + if (!pRead_buf) + { + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if (!level) + { + while (uncomp_remaining) + { + mz_uint n = (mz_uint)MZ_MIN((mz_uint64)MZ_ZIP_MAX_IO_BUF_SIZE, uncomp_remaining); + if ((read_callback(callback_opaque, file_ofs, pRead_buf, n) != n) || (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, pRead_buf, n) != n)) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + } + file_ofs += n; + uncomp_crc32 = (mz_uint32)mz_crc32(uncomp_crc32, (const mz_uint8 *)pRead_buf, n); + uncomp_remaining -= n; + cur_archive_file_ofs += n; + } + comp_size = uncomp_size; + } + else + { + mz_bool result = MZ_FALSE; + mz_zip_writer_add_state state; + tdefl_compressor *pComp = (tdefl_compressor *)pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, sizeof(tdefl_compressor)); + if (!pComp) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + state.m_pZip = pZip; + state.m_cur_archive_file_ofs = cur_archive_file_ofs; + state.m_comp_size = 0; + + if (tdefl_init(pComp, mz_zip_writer_add_put_buf_callback, &state, tdefl_create_comp_flags_from_zip_params(level, -15, MZ_DEFAULT_STRATEGY)) != TDEFL_STATUS_OKAY) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); + return mz_zip_set_error(pZip, MZ_ZIP_INTERNAL_ERROR); + } + + for (;;) + { + size_t in_buf_size = (mz_uint32)MZ_MIN(uncomp_remaining, (mz_uint64)MZ_ZIP_MAX_IO_BUF_SIZE); + tdefl_status status; + tdefl_flush flush = TDEFL_NO_FLUSH; + + if (read_callback(callback_opaque, file_ofs, pRead_buf, in_buf_size)!= in_buf_size) + { + mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + break; + } + + file_ofs += in_buf_size; + uncomp_crc32 = (mz_uint32)mz_crc32(uncomp_crc32, (const mz_uint8 *)pRead_buf, in_buf_size); + uncomp_remaining -= in_buf_size; + + if (pZip->m_pNeeds_keepalive != NULL && pZip->m_pNeeds_keepalive(pZip->m_pIO_opaque)) + flush = TDEFL_FULL_FLUSH; + + status = tdefl_compress_buffer(pComp, pRead_buf, in_buf_size, uncomp_remaining ? flush : TDEFL_FINISH); + if (status == TDEFL_STATUS_DONE) + { + result = MZ_TRUE; + break; + } + else if (status != TDEFL_STATUS_OKAY) + { + mz_zip_set_error(pZip, MZ_ZIP_COMPRESSION_FAILED); + break; + } + } + + pZip->m_pFree(pZip->m_pAlloc_opaque, pComp); + + if (!result) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); + return MZ_FALSE; + } + + comp_size = state.m_comp_size; + cur_archive_file_ofs = state.m_cur_archive_file_ofs; + } + + pZip->m_pFree(pZip->m_pAlloc_opaque, pRead_buf); + } + + { + mz_uint8 local_dir_footer[MZ_ZIP_DATA_DESCRIPTER_SIZE64]; + mz_uint32 local_dir_footer_size = MZ_ZIP_DATA_DESCRIPTER_SIZE32; + + MZ_WRITE_LE32(local_dir_footer + 0, MZ_ZIP_DATA_DESCRIPTOR_ID); + MZ_WRITE_LE32(local_dir_footer + 4, uncomp_crc32); + if (pExtra_data == NULL) + { + if (comp_size > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + + MZ_WRITE_LE32(local_dir_footer + 8, comp_size); + MZ_WRITE_LE32(local_dir_footer + 12, uncomp_size); + } + else + { + MZ_WRITE_LE64(local_dir_footer + 8, comp_size); + MZ_WRITE_LE64(local_dir_footer + 16, uncomp_size); + local_dir_footer_size = MZ_ZIP_DATA_DESCRIPTER_SIZE64; + } + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_archive_file_ofs, local_dir_footer, local_dir_footer_size) != local_dir_footer_size) + return MZ_FALSE; + + cur_archive_file_ofs += local_dir_footer_size; + } + + if (pExtra_data != NULL) + { + extra_size = mz_zip_writer_create_zip64_extra_data(extra_data, (uncomp_size >= MZ_UINT32_MAX) ? &uncomp_size : NULL, + (uncomp_size >= MZ_UINT32_MAX) ? &comp_size : NULL, (local_dir_header_ofs >= MZ_UINT32_MAX) ? &local_dir_header_ofs : NULL); + } + + if (!mz_zip_writer_add_to_central_dir(pZip, pArchive_name, (mz_uint16)archive_name_size, pExtra_data, (mz_uint16)extra_size, pComment, comment_size, + uncomp_size, comp_size, uncomp_crc32, method, gen_flags, dos_time, dos_date, local_dir_header_ofs, ext_attributes, + user_extra_data_central, user_extra_data_central_len)) + return MZ_FALSE; + + pZip->m_total_files++; + pZip->m_archive_size = cur_archive_file_ofs; + + return MZ_TRUE; +} + +#ifndef MINIZ_NO_STDIO + +static size_t mz_file_read_func_stdio(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) +{ + MZ_FILE *pSrc_file = (MZ_FILE *)pOpaque; + mz_int64 cur_ofs = MZ_FTELL64(pSrc_file); + + if (((mz_int64)file_ofs < 0) || (((cur_ofs != (mz_int64)file_ofs)) && (MZ_FSEEK64(pSrc_file, (mz_int64)file_ofs, SEEK_SET)))) + return 0; + + return MZ_FREAD(pBuf, 1, n, pSrc_file); +} + +mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, MZ_FILE *pSrc_file, mz_uint64 size_to_add, const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, + const char *user_extra_data, mz_uint user_extra_data_len, const char *user_extra_data_central, mz_uint user_extra_data_central_len) +{ + return mz_zip_writer_add_read_buf_callback(pZip, pArchive_name, mz_file_read_func_stdio, pSrc_file, size_to_add, pFile_time, pComment, comment_size, level_and_flags, + user_extra_data, user_extra_data_len, user_extra_data_central, user_extra_data_central_len); +} + +mz_bool mz_zip_writer_add_file(mz_zip_archive *pZip, const char *pArchive_name, const char *pSrc_filename, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags) +{ + MZ_FILE *pSrc_file = NULL; + mz_uint64 uncomp_size = 0; + MZ_TIME_T file_modified_time; + MZ_TIME_T *pFile_time = NULL; + mz_bool status; + + memset(&file_modified_time, 0, sizeof(file_modified_time)); + +#if !defined(MINIZ_NO_TIME) && !defined(MINIZ_NO_STDIO) + pFile_time = &file_modified_time; + if (!mz_zip_get_file_modified_time(pSrc_filename, &file_modified_time)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_STAT_FAILED); +#endif + + pSrc_file = MZ_FOPEN(pSrc_filename, "rb"); + if (!pSrc_file) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_OPEN_FAILED); + + MZ_FSEEK64(pSrc_file, 0, SEEK_END); + uncomp_size = MZ_FTELL64(pSrc_file); + MZ_FSEEK64(pSrc_file, 0, SEEK_SET); + + status = mz_zip_writer_add_cfile(pZip, pArchive_name, pSrc_file, uncomp_size, pFile_time, pComment, comment_size, level_and_flags, NULL, 0, NULL, 0); + + MZ_FCLOSE(pSrc_file); + + return status; +} +#endif /* #ifndef MINIZ_NO_STDIO */ + +static mz_bool mz_zip_writer_update_zip64_extension_block(mz_zip_array *pNew_ext, mz_zip_archive *pZip, const mz_uint8 *pExt, uint32_t ext_len, mz_uint64 *pComp_size, mz_uint64 *pUncomp_size, mz_uint64 *pLocal_header_ofs, mz_uint32 *pDisk_start) +{ + /* + 64 should be enough for any new zip64 data */ + if (!mz_zip_array_reserve(pZip, pNew_ext, ext_len + 64, MZ_FALSE)) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + mz_zip_array_resize(pZip, pNew_ext, 0, MZ_FALSE); + + if ((pUncomp_size) || (pComp_size) || (pLocal_header_ofs) || (pDisk_start)) + { + mz_uint8 new_ext_block[64]; + mz_uint8 *pDst = new_ext_block; + mz_write_le16(pDst, MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID); + mz_write_le16(pDst + sizeof(mz_uint16), 0); + pDst += sizeof(mz_uint16) * 2; + + if (pUncomp_size) + { + mz_write_le64(pDst, *pUncomp_size); + pDst += sizeof(mz_uint64); + } + + if (pComp_size) + { + mz_write_le64(pDst, *pComp_size); + pDst += sizeof(mz_uint64); + } + + if (pLocal_header_ofs) + { + mz_write_le64(pDst, *pLocal_header_ofs); + pDst += sizeof(mz_uint64); + } + + if (pDisk_start) + { + mz_write_le32(pDst, *pDisk_start); + pDst += sizeof(mz_uint32); + } + + mz_write_le16(new_ext_block + sizeof(mz_uint16), (mz_uint16)((pDst - new_ext_block) - sizeof(mz_uint16) * 2)); + + if (!mz_zip_array_push_back(pZip, pNew_ext, new_ext_block, pDst - new_ext_block)) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if ((pExt) && (ext_len)) + { + mz_uint32 extra_size_remaining = ext_len; + const mz_uint8 *pExtra_data = pExt; + + do + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 field_id, field_data_size, field_total_size; + + if (extra_size_remaining < (sizeof(mz_uint16) * 2)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + field_id = MZ_READ_LE16(pExtra_data); + field_data_size = MZ_READ_LE16(pExtra_data + sizeof(mz_uint16)); + field_total_size = field_data_size + sizeof(mz_uint16) * 2; + + if (field_total_size > extra_size_remaining) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + if (field_id != MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID) + { + if (!mz_zip_array_push_back(pZip, pNew_ext, pExtra_data, field_total_size)) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + pExtra_data += field_total_size; + extra_size_remaining -= field_total_size; + } while (extra_size_remaining); + } + + return MZ_TRUE; +} + +/* TODO: This func is now pretty freakin complex due to zip64, split it up? */ +mz_bool mz_zip_writer_add_from_zip_reader(mz_zip_archive *pZip, mz_zip_archive *pSource_zip, mz_uint src_file_index) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint n, bit_flags, num_alignment_padding_bytes, src_central_dir_following_data_size; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 src_archive_bytes_remaining, local_dir_header_ofs; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 cur_src_file_ofs, cur_dst_file_ofs; + mz_uint32 local_header_u32[(MZ_ZIP_LOCAL_DIR_HEADER_SIZE + sizeof(mz_uint32) - 1) / sizeof(mz_uint32)]; + mz_uint8 *pLocal_header = (mz_uint8 *)local_header_u32; + mz_uint8 new_central_header[MZ_ZIP_CENTRAL_DIR_HEADER_SIZE]; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + size_t orig_central_dir_size; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_internal_state *pState; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + void *pBuf; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + const mz_uint8 *pSrc_central_header; + mz_zip_archive_file_stat src_file_stat; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 src_filename_len, src_comment_len, src_ext_len; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 local_header_filename_size, local_header_extra_len; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 local_header_comp_size, local_header_uncomp_size; + mz_bool found_zip64_ext_data_in_ldir = MZ_FALSE; + + /* Sanity checks */ + if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING) || (!pSource_zip->m_pRead)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pState = pZip->m_pState; + + /* Don't support copying files from zip64 archives to non-zip64, even though in some cases this is possible */ + if ((pSource_zip->m_pState->m_zip64) && (!pZip->m_pState->m_zip64)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + /* Get pointer to the source central dir header and crack it */ + if (NULL == (pSrc_central_header = mz_zip_get_cdh(pSource_zip, src_file_index))) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (MZ_READ_LE32(pSrc_central_header + MZ_ZIP_CDH_SIG_OFS) != MZ_ZIP_CENTRAL_DIR_HEADER_SIG) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + src_filename_len = MZ_READ_LE16(pSrc_central_header + MZ_ZIP_CDH_FILENAME_LEN_OFS); + src_comment_len = MZ_READ_LE16(pSrc_central_header + MZ_ZIP_CDH_COMMENT_LEN_OFS); + src_ext_len = MZ_READ_LE16(pSrc_central_header + MZ_ZIP_CDH_EXTRA_LEN_OFS); + src_central_dir_following_data_size = src_filename_len + src_ext_len + src_comment_len; + + /* TODO: We don't support central dir's >= MZ_UINT32_MAX bytes right now (+32 fudge factor in case we need to add more extra data) */ + if ((pState->m_central_dir.m_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + src_central_dir_following_data_size + 32) >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_CDIR_SIZE); + + num_alignment_padding_bytes = mz_zip_writer_compute_padding_needed_for_file_alignment(pZip); + + if (!pState->m_zip64) + { + if (pZip->m_total_files == MZ_UINT16_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + } + else + { + /* TODO: Our zip64 support still has some 32-bit limits that may not be worth fixing. */ + if (pZip->m_total_files == MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + } + + if (!mz_zip_file_stat_internal(pSource_zip, src_file_index, pSrc_central_header, &src_file_stat, NULL)) + return MZ_FALSE; + + cur_src_file_ofs = src_file_stat.m_local_header_ofs; + cur_dst_file_ofs = pZip->m_archive_size; + + /* Read the source archive's local dir header */ + if (pSource_zip->m_pRead(pSource_zip->m_pIO_opaque, cur_src_file_ofs, pLocal_header, MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + + if (MZ_READ_LE32(pLocal_header) != MZ_ZIP_LOCAL_DIR_HEADER_SIG) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + + cur_src_file_ofs += MZ_ZIP_LOCAL_DIR_HEADER_SIZE; + + /* Compute the total size we need to copy (filename+extra data+compressed data) */ + local_header_filename_size = MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_FILENAME_LEN_OFS); + local_header_extra_len = MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_EXTRA_LEN_OFS); + local_header_comp_size = MZ_READ_LE32(pLocal_header + MZ_ZIP_LDH_COMPRESSED_SIZE_OFS); + local_header_uncomp_size = MZ_READ_LE32(pLocal_header + MZ_ZIP_LDH_DECOMPRESSED_SIZE_OFS); + src_archive_bytes_remaining = local_header_filename_size + local_header_extra_len + src_file_stat.m_comp_size; + + /* Try to find a zip64 extended information field */ + if ((local_header_extra_len) && ((local_header_comp_size == MZ_UINT32_MAX) || (local_header_uncomp_size == MZ_UINT32_MAX))) + { + mz_zip_array file_data_array; + const mz_uint8 *pExtra_data; + mz_uint32 extra_size_remaining = local_header_extra_len; + + mz_zip_array_init(&file_data_array, 1); + if (!mz_zip_array_resize(pZip, &file_data_array, local_header_extra_len, MZ_FALSE)) + { + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if (pSource_zip->m_pRead(pSource_zip->m_pIO_opaque, src_file_stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + local_header_filename_size, file_data_array.m_p, local_header_extra_len) != local_header_extra_len) + { + mz_zip_array_clear(pZip, &file_data_array); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + } + + pExtra_data = (const mz_uint8 *)file_data_array.m_p; + + do + { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint32 field_id, field_data_size, field_total_size; + + if (extra_size_remaining < (sizeof(mz_uint16) * 2)) + { + mz_zip_array_clear(pZip, &file_data_array); + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } + + field_id = MZ_READ_LE16(pExtra_data); + field_data_size = MZ_READ_LE16(pExtra_data + sizeof(mz_uint16)); + field_total_size = field_data_size + sizeof(mz_uint16) * 2; + + if (field_total_size > extra_size_remaining) + { + mz_zip_array_clear(pZip, &file_data_array); + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } + + if (field_id == MZ_ZIP64_EXTENDED_INFORMATION_FIELD_HEADER_ID) + { + const mz_uint8 *pSrc_field_data = pExtra_data + sizeof(mz_uint32); + + if (field_data_size < sizeof(mz_uint64) * 2) + { + mz_zip_array_clear(pZip, &file_data_array); + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_HEADER_OR_CORRUPTED); + } + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + local_header_uncomp_size = MZ_READ_LE64(pSrc_field_data); + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + local_header_comp_size = MZ_READ_LE64(pSrc_field_data + sizeof(mz_uint64)); /* may be 0 if there's a descriptor */ + + found_zip64_ext_data_in_ldir = MZ_TRUE; + break; + } + + pExtra_data += field_total_size; + extra_size_remaining -= field_total_size; + } while (extra_size_remaining); + + mz_zip_array_clear(pZip, &file_data_array); + } + + if (!pState->m_zip64) + { + /* Try to detect if the new archive will most likely wind up too big and bail early (+(sizeof(mz_uint32) * 4) is for the optional descriptor which could be present, +64 is a fudge factor). */ + /* We also check when the archive is finalized so this doesn't need to be perfect. */ + mz_uint64 approx_new_archive_size = cur_dst_file_ofs + num_alignment_padding_bytes + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + src_archive_bytes_remaining + (sizeof(mz_uint32) * 4) + + pState->m_central_dir.m_size + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + src_central_dir_following_data_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE + 64; + + if (approx_new_archive_size >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + } + + /* Write dest archive padding */ + if (!mz_zip_writer_write_zeros(pZip, cur_dst_file_ofs, num_alignment_padding_bytes)) + return MZ_FALSE; + + cur_dst_file_ofs += num_alignment_padding_bytes; + + local_dir_header_ofs = cur_dst_file_ofs; + if (pZip->m_file_offset_alignment) + { + MZ_ASSERT((local_dir_header_ofs & (pZip->m_file_offset_alignment - 1)) == 0); + } + + /* The original zip's local header+ext block doesn't change, even with zip64, so we can just copy it over to the dest zip */ + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_dst_file_ofs, pLocal_header, MZ_ZIP_LOCAL_DIR_HEADER_SIZE) != MZ_ZIP_LOCAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + cur_dst_file_ofs += MZ_ZIP_LOCAL_DIR_HEADER_SIZE; + + /* Copy over the source archive bytes to the dest archive, also ensure we have enough buf space to handle optional data descriptor */ + if (NULL == (pBuf = pZip->m_pAlloc(pZip->m_pAlloc_opaque, 1, (size_t)MZ_MAX(32U, MZ_MIN((mz_uint64)MZ_ZIP_MAX_IO_BUF_SIZE, src_archive_bytes_remaining))))) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + while (src_archive_bytes_remaining) + { + n = (mz_uint)MZ_MIN((mz_uint64)MZ_ZIP_MAX_IO_BUF_SIZE, src_archive_bytes_remaining); + if (pSource_zip->m_pRead(pSource_zip->m_pIO_opaque, cur_src_file_ofs, pBuf, n) != n) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pBuf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + } + cur_src_file_ofs += n; + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_dst_file_ofs, pBuf, n) != n) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pBuf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + cur_dst_file_ofs += n; + + src_archive_bytes_remaining -= n; + } + + /* Now deal with the optional data descriptor */ + bit_flags = MZ_READ_LE16(pLocal_header + MZ_ZIP_LDH_BIT_FLAG_OFS); + if (bit_flags & 8) + { + /* Copy data descriptor */ + if ((pSource_zip->m_pState->m_zip64) || (found_zip64_ext_data_in_ldir)) + { + /* src is zip64, dest must be zip64 */ + + /* name uint32_t's */ + /* id 1 (optional in zip64?) */ + /* crc 1 */ + /* comp_size 2 */ + /* uncomp_size 2 */ + if (pSource_zip->m_pRead(pSource_zip->m_pIO_opaque, cur_src_file_ofs, pBuf, (sizeof(mz_uint32) * 6)) != (sizeof(mz_uint32) * 6)) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pBuf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + } + + n = sizeof(mz_uint32) * ((MZ_READ_LE32(pBuf) == MZ_ZIP_DATA_DESCRIPTOR_ID) ? 6 : 5); + } + else + { + /* src is NOT zip64 */ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_bool has_id; + + if (pSource_zip->m_pRead(pSource_zip->m_pIO_opaque, cur_src_file_ofs, pBuf, sizeof(mz_uint32) * 4) != sizeof(mz_uint32) * 4) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pBuf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_READ_FAILED); + } + + has_id = (MZ_READ_LE32(pBuf) == MZ_ZIP_DATA_DESCRIPTOR_ID); + + if (pZip->m_pState->m_zip64) + { + /* dest is zip64, so upgrade the data descriptor */ + const mz_uint32 *pSrc_descriptor = (const mz_uint32 *)((const mz_uint8 *)pBuf + (has_id ? sizeof(mz_uint32) : 0)); + const mz_uint32 src_crc32 = pSrc_descriptor[0]; + const mz_uint64 src_comp_size = pSrc_descriptor[1]; + const mz_uint64 src_uncomp_size = pSrc_descriptor[2]; + + mz_write_le32((mz_uint8 *)pBuf, MZ_ZIP_DATA_DESCRIPTOR_ID); + mz_write_le32((mz_uint8 *)pBuf + sizeof(mz_uint32) * 1, src_crc32); + mz_write_le64((mz_uint8 *)pBuf + sizeof(mz_uint32) * 2, src_comp_size); + mz_write_le64((mz_uint8 *)pBuf + sizeof(mz_uint32) * 4, src_uncomp_size); + + n = sizeof(mz_uint32) * 6; + } + else + { + /* dest is NOT zip64, just copy it as-is */ + n = sizeof(mz_uint32) * (has_id ? 4 : 3); + } + } + + if (pZip->m_pWrite(pZip->m_pIO_opaque, cur_dst_file_ofs, pBuf, n) != n) + { + pZip->m_pFree(pZip->m_pAlloc_opaque, pBuf); + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + } + + // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) + cur_src_file_ofs += n; + cur_dst_file_ofs += n; + } + pZip->m_pFree(pZip->m_pAlloc_opaque, pBuf); + + /* Finally, add the new central dir header */ + orig_central_dir_size = pState->m_central_dir.m_size; + + memcpy(new_central_header, pSrc_central_header, MZ_ZIP_CENTRAL_DIR_HEADER_SIZE); + + if (pState->m_zip64) + { + /* This is the painful part: We need to write a new central dir header + ext block with updated zip64 fields, and ensure the old fields (if any) are not included. */ + const mz_uint8 *pSrc_ext = pSrc_central_header + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + src_filename_len; + mz_zip_array new_ext_block; + + mz_zip_array_init(&new_ext_block, sizeof(mz_uint8)); + + MZ_WRITE_LE32(new_central_header + MZ_ZIP_CDH_COMPRESSED_SIZE_OFS, MZ_UINT32_MAX); + MZ_WRITE_LE32(new_central_header + MZ_ZIP_CDH_DECOMPRESSED_SIZE_OFS, MZ_UINT32_MAX); + MZ_WRITE_LE32(new_central_header + MZ_ZIP_CDH_LOCAL_HEADER_OFS, MZ_UINT32_MAX); + + if (!mz_zip_writer_update_zip64_extension_block(&new_ext_block, pZip, pSrc_ext, src_ext_len, &src_file_stat.m_comp_size, &src_file_stat.m_uncomp_size, &local_dir_header_ofs, NULL)) + { + mz_zip_array_clear(pZip, &new_ext_block); + return MZ_FALSE; + } + + MZ_WRITE_LE16(new_central_header + MZ_ZIP_CDH_EXTRA_LEN_OFS, new_ext_block.m_size); + + if (!mz_zip_array_push_back(pZip, &pState->m_central_dir, new_central_header, MZ_ZIP_CENTRAL_DIR_HEADER_SIZE)) + { + mz_zip_array_clear(pZip, &new_ext_block); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if (!mz_zip_array_push_back(pZip, &pState->m_central_dir, pSrc_central_header + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE, src_filename_len)) + { + mz_zip_array_clear(pZip, &new_ext_block); + mz_zip_array_resize(pZip, &pState->m_central_dir, orig_central_dir_size, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if (!mz_zip_array_push_back(pZip, &pState->m_central_dir, new_ext_block.m_p, new_ext_block.m_size)) + { + mz_zip_array_clear(pZip, &new_ext_block); + mz_zip_array_resize(pZip, &pState->m_central_dir, orig_central_dir_size, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + if (!mz_zip_array_push_back(pZip, &pState->m_central_dir, pSrc_central_header + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE + src_filename_len + src_ext_len, src_comment_len)) + { + mz_zip_array_clear(pZip, &new_ext_block); + mz_zip_array_resize(pZip, &pState->m_central_dir, orig_central_dir_size, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + mz_zip_array_clear(pZip, &new_ext_block); + } + else + { + /* sanity checks */ + if (cur_dst_file_ofs > MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + + if (local_dir_header_ofs >= MZ_UINT32_MAX) + return mz_zip_set_error(pZip, MZ_ZIP_ARCHIVE_TOO_LARGE); + + MZ_WRITE_LE32(new_central_header + MZ_ZIP_CDH_LOCAL_HEADER_OFS, local_dir_header_ofs); + + if (!mz_zip_array_push_back(pZip, &pState->m_central_dir, new_central_header, MZ_ZIP_CENTRAL_DIR_HEADER_SIZE)) + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + + if (!mz_zip_array_push_back(pZip, &pState->m_central_dir, pSrc_central_header + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE, src_central_dir_following_data_size)) + { + mz_zip_array_resize(pZip, &pState->m_central_dir, orig_central_dir_size, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + } + + /* This shouldn't trigger unless we screwed up during the initial sanity checks */ + if (pState->m_central_dir.m_size >= MZ_UINT32_MAX) + { + /* TODO: Support central dirs >= 32-bits in size */ + mz_zip_array_resize(pZip, &pState->m_central_dir, orig_central_dir_size, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_UNSUPPORTED_CDIR_SIZE); + } + + n = (mz_uint32)orig_central_dir_size; + if (!mz_zip_array_push_back(pZip, &pState->m_central_dir_offsets, &n, 1)) + { + mz_zip_array_resize(pZip, &pState->m_central_dir, orig_central_dir_size, MZ_FALSE); + return mz_zip_set_error(pZip, MZ_ZIP_ALLOC_FAILED); + } + + pZip->m_total_files++; + pZip->m_archive_size = cur_dst_file_ofs; + + return MZ_TRUE; +} + +mz_bool mz_zip_writer_finalize_archive(mz_zip_archive *pZip) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_internal_state *pState; + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint64 central_dir_ofs, central_dir_size; + mz_uint8 hdr[256]; + + if ((!pZip) || (!pZip->m_pState) || (pZip->m_zip_mode != MZ_ZIP_MODE_WRITING)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + pState = pZip->m_pState; + + if (pState->m_zip64) + { + if ((pZip->m_total_files > MZ_UINT32_MAX) || (pState->m_central_dir.m_size >= MZ_UINT32_MAX)) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + } + else + { + if ((pZip->m_total_files > MZ_UINT16_MAX) || ((pZip->m_archive_size + pState->m_central_dir.m_size + MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) > MZ_UINT32_MAX)) + return mz_zip_set_error(pZip, MZ_ZIP_TOO_MANY_FILES); + } + + central_dir_ofs = 0; + central_dir_size = 0; + if (pZip->m_total_files) + { + /* Write central directory */ + central_dir_ofs = pZip->m_archive_size; + central_dir_size = pState->m_central_dir.m_size; + pZip->m_central_directory_file_ofs = central_dir_ofs; + if (pZip->m_pWrite(pZip->m_pIO_opaque, central_dir_ofs, pState->m_central_dir.m_p, (size_t)central_dir_size) != central_dir_size) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + pZip->m_archive_size += central_dir_size; + } + + if (pState->m_zip64) + { + /* Write zip64 end of central directory header */ + mz_uint64 rel_ofs_to_zip64_ecdr = pZip->m_archive_size; + + MZ_CLEAR_OBJ(hdr); + MZ_WRITE_LE32(hdr + MZ_ZIP64_ECDH_SIG_OFS, MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIG); + MZ_WRITE_LE64(hdr + MZ_ZIP64_ECDH_SIZE_OF_RECORD_OFS, MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE - sizeof(mz_uint32) - sizeof(mz_uint64)); + MZ_WRITE_LE16(hdr + MZ_ZIP64_ECDH_VERSION_MADE_BY_OFS, 0x031E); /* TODO: always Unix */ + MZ_WRITE_LE16(hdr + MZ_ZIP64_ECDH_VERSION_NEEDED_OFS, 0x002D); + MZ_WRITE_LE64(hdr + MZ_ZIP64_ECDH_CDIR_NUM_ENTRIES_ON_DISK_OFS, pZip->m_total_files); + MZ_WRITE_LE64(hdr + MZ_ZIP64_ECDH_CDIR_TOTAL_ENTRIES_OFS, pZip->m_total_files); + MZ_WRITE_LE64(hdr + MZ_ZIP64_ECDH_CDIR_SIZE_OFS, central_dir_size); + MZ_WRITE_LE64(hdr + MZ_ZIP64_ECDH_CDIR_OFS_OFS, central_dir_ofs); + if (pZip->m_pWrite(pZip->m_pIO_opaque, pZip->m_archive_size, hdr, MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE) != MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + pZip->m_archive_size += MZ_ZIP64_END_OF_CENTRAL_DIR_HEADER_SIZE; + + /* Write zip64 end of central directory locator */ + MZ_CLEAR_OBJ(hdr); + MZ_WRITE_LE32(hdr + MZ_ZIP64_ECDL_SIG_OFS, MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIG); + MZ_WRITE_LE64(hdr + MZ_ZIP64_ECDL_REL_OFS_TO_ZIP64_ECDR_OFS, rel_ofs_to_zip64_ecdr); + MZ_WRITE_LE32(hdr + MZ_ZIP64_ECDL_TOTAL_NUMBER_OF_DISKS_OFS, 1); + if (pZip->m_pWrite(pZip->m_pIO_opaque, pZip->m_archive_size, hdr, MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE) != MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + + pZip->m_archive_size += MZ_ZIP64_END_OF_CENTRAL_DIR_LOCATOR_SIZE; + } + + /* Write end of central directory record */ + MZ_CLEAR_OBJ(hdr); + MZ_WRITE_LE32(hdr + MZ_ZIP_ECDH_SIG_OFS, MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIG); + MZ_WRITE_LE16(hdr + MZ_ZIP_ECDH_CDIR_NUM_ENTRIES_ON_DISK_OFS, MZ_MIN(MZ_UINT16_MAX, pZip->m_total_files)); + MZ_WRITE_LE16(hdr + MZ_ZIP_ECDH_CDIR_TOTAL_ENTRIES_OFS, MZ_MIN(MZ_UINT16_MAX, pZip->m_total_files)); + MZ_WRITE_LE32(hdr + MZ_ZIP_ECDH_CDIR_SIZE_OFS, MZ_MIN(MZ_UINT32_MAX, central_dir_size)); + MZ_WRITE_LE32(hdr + MZ_ZIP_ECDH_CDIR_OFS_OFS, MZ_MIN(MZ_UINT32_MAX, central_dir_ofs)); + + if (pZip->m_pWrite(pZip->m_pIO_opaque, pZip->m_archive_size, hdr, MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) != MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_WRITE_FAILED); + +#ifndef MINIZ_NO_STDIO + if ((pState->m_pFile) && (MZ_FFLUSH(pState->m_pFile) == EOF)) + return mz_zip_set_error(pZip, MZ_ZIP_FILE_CLOSE_FAILED); +#endif /* #ifndef MINIZ_NO_STDIO */ + + pZip->m_archive_size += MZ_ZIP_END_OF_CENTRAL_DIR_HEADER_SIZE; + + pZip->m_zip_mode = MZ_ZIP_MODE_WRITING_HAS_BEEN_FINALIZED; + return MZ_TRUE; +} + +mz_bool mz_zip_writer_finalize_heap_archive(mz_zip_archive *pZip, void **ppBuf, size_t *pSize) +{ + if ((!ppBuf) || (!pSize)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + *ppBuf = NULL; + *pSize = 0; + + if ((!pZip) || (!pZip->m_pState)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (pZip->m_pWrite != mz_zip_heap_write_func) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + if (!mz_zip_writer_finalize_archive(pZip)) + return MZ_FALSE; + + *ppBuf = pZip->m_pState->m_pMem; + *pSize = pZip->m_pState->m_mem_size; + pZip->m_pState->m_pMem = NULL; + pZip->m_pState->m_mem_size = pZip->m_pState->m_mem_capacity = 0; + + return MZ_TRUE; +} + +mz_bool mz_zip_writer_end(mz_zip_archive *pZip) +{ + return mz_zip_writer_end_internal(pZip, MZ_TRUE); +} + +#ifndef MINIZ_NO_STDIO +mz_bool mz_zip_add_mem_to_archive_file_in_place(const char *pZip_filename, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags) +{ + return mz_zip_add_mem_to_archive_file_in_place_v2(pZip_filename, pArchive_name, pBuf, buf_size, pComment, comment_size, level_and_flags, NULL); +} + +mz_bool mz_zip_add_mem_to_archive_file_in_place_v2(const char *pZip_filename, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, mz_zip_error *pErr) +{ + mz_bool status, created_new_archive = MZ_FALSE; + mz_zip_archive zip_archive; + struct MZ_FILE_STAT_STRUCT file_stat; + mz_zip_error actual_err = MZ_ZIP_NO_ERROR; + + mz_zip_zero_struct(&zip_archive); + if ((int)level_and_flags < 0) + level_and_flags = MZ_DEFAULT_LEVEL; + + if ((!pZip_filename) || (!pArchive_name) || ((buf_size) && (!pBuf)) || ((comment_size) && (!pComment)) || ((level_and_flags & 0xF) > MZ_UBER_COMPRESSION)) + { + if (pErr) + *pErr = MZ_ZIP_INVALID_PARAMETER; + return MZ_FALSE; + } + + if (!mz_zip_writer_validate_archive_name(pArchive_name)) + { + if (pErr) + *pErr = MZ_ZIP_INVALID_FILENAME; + return MZ_FALSE; + } + + /* Important: The regular non-64 bit version of stat() can fail here if the file is very large, which could cause the archive to be overwritten. */ + /* So be sure to compile with _LARGEFILE64_SOURCE 1 */ + if (MZ_FILE_STAT(pZip_filename, &file_stat) != 0) + { + /* Create a new archive. */ + if (!mz_zip_writer_init_file_v2(&zip_archive, pZip_filename, 0, level_and_flags)) + { + if (pErr) + *pErr = zip_archive.m_last_error; + return MZ_FALSE; + } + + created_new_archive = MZ_TRUE; + } + else + { + /* Append to an existing archive. */ + if (!mz_zip_reader_init_file_v2(&zip_archive, pZip_filename, level_and_flags | MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY, 0, 0)) + { + if (pErr) + *pErr = zip_archive.m_last_error; + return MZ_FALSE; + } + + if (!mz_zip_writer_init_from_reader_v2(&zip_archive, pZip_filename, level_and_flags)) + { + if (pErr) + *pErr = zip_archive.m_last_error; + + mz_zip_reader_end_internal(&zip_archive, MZ_FALSE); + + return MZ_FALSE; + } + } + + status = mz_zip_writer_add_mem_ex(&zip_archive, pArchive_name, pBuf, buf_size, pComment, comment_size, level_and_flags, 0, 0); + actual_err = zip_archive.m_last_error; + + /* Always finalize, even if adding failed for some reason, so we have a valid central directory. (This may not always succeed, but we can try.) */ + if (!mz_zip_writer_finalize_archive(&zip_archive)) + { + if (!actual_err) + actual_err = zip_archive.m_last_error; + + status = MZ_FALSE; + } + + if (!mz_zip_writer_end_internal(&zip_archive, status)) + { + if (!actual_err) + actual_err = zip_archive.m_last_error; + + status = MZ_FALSE; + } + + if ((!status) && (created_new_archive)) + { + /* It's a new archive and something went wrong, so just delete it. */ + int ignoredStatus = MZ_DELETE_FILE(pZip_filename); + (void)ignoredStatus; + } + + if (pErr) + *pErr = actual_err; + + return status; +} + +void *mz_zip_extract_archive_file_to_heap_v2(const char *pZip_filename, const char *pArchive_name, const char *pComment, size_t *pSize, mz_uint flags, mz_zip_error *pErr) +{ + mz_uint32 file_index; + mz_zip_archive zip_archive; + void *p = NULL; + + if (pSize) + *pSize = 0; + + if ((!pZip_filename) || (!pArchive_name)) + { + if (pErr) + *pErr = MZ_ZIP_INVALID_PARAMETER; + + return NULL; + } + + mz_zip_zero_struct(&zip_archive); + if (!mz_zip_reader_init_file_v2(&zip_archive, pZip_filename, flags | MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY, 0, 0)) + { + if (pErr) + *pErr = zip_archive.m_last_error; + + return NULL; + } + + if (mz_zip_reader_locate_file_v2(&zip_archive, pArchive_name, pComment, flags, &file_index)) + { + p = mz_zip_reader_extract_to_heap(&zip_archive, file_index, pSize, flags); + } + + mz_zip_reader_end_internal(&zip_archive, p != NULL); + + if (pErr) + *pErr = zip_archive.m_last_error; + + return p; +} + +void *mz_zip_extract_archive_file_to_heap(const char *pZip_filename, const char *pArchive_name, size_t *pSize, mz_uint flags) +{ + return mz_zip_extract_archive_file_to_heap_v2(pZip_filename, pArchive_name, NULL, pSize, flags, NULL); +} + +#endif /* #ifndef MINIZ_NO_STDIO */ + +#endif /* #ifndef MINIZ_NO_ARCHIVE_WRITING_APIS */ + +/* ------------------- Misc utils */ + +mz_zip_mode mz_zip_get_mode(mz_zip_archive *pZip) +{ + return pZip ? pZip->m_zip_mode : MZ_ZIP_MODE_INVALID; +} + +mz_zip_type mz_zip_get_type(mz_zip_archive *pZip) +{ + return pZip ? pZip->m_zip_type : MZ_ZIP_TYPE_INVALID; +} + +mz_zip_error mz_zip_set_last_error(mz_zip_archive *pZip, mz_zip_error err_num) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_error prev_err; + + if (!pZip) + return MZ_ZIP_INVALID_PARAMETER; + + prev_err = pZip->m_last_error; + + pZip->m_last_error = err_num; + return prev_err; +} + +mz_zip_error mz_zip_peek_last_error(mz_zip_archive *pZip) +{ + if (!pZip) + return MZ_ZIP_INVALID_PARAMETER; + + return pZip->m_last_error; +} + +mz_zip_error mz_zip_clear_last_error(mz_zip_archive *pZip) +{ + return mz_zip_set_last_error(pZip, MZ_ZIP_NO_ERROR); +} + +mz_zip_error mz_zip_get_last_error(mz_zip_archive *pZip) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_zip_error prev_err; + + if (!pZip) + return MZ_ZIP_INVALID_PARAMETER; + + prev_err = pZip->m_last_error; + + pZip->m_last_error = MZ_ZIP_NO_ERROR; + return prev_err; +} + +const char *mz_zip_get_error_string(mz_zip_error mz_err) +{ + switch (mz_err) + { + case MZ_ZIP_NO_ERROR: + return "no error"; + case MZ_ZIP_UNDEFINED_ERROR: + return "undefined error"; + case MZ_ZIP_TOO_MANY_FILES: + return "too many files"; + case MZ_ZIP_FILE_TOO_LARGE: + return "file too large"; + case MZ_ZIP_UNSUPPORTED_METHOD: + return "unsupported method"; + case MZ_ZIP_UNSUPPORTED_ENCRYPTION: + return "unsupported encryption"; + case MZ_ZIP_UNSUPPORTED_FEATURE: + return "unsupported feature"; + case MZ_ZIP_FAILED_FINDING_CENTRAL_DIR: + return "failed finding central directory"; + case MZ_ZIP_NOT_AN_ARCHIVE: + return "not a ZIP archive"; + case MZ_ZIP_INVALID_HEADER_OR_CORRUPTED: + return "invalid header or archive is corrupted"; + case MZ_ZIP_UNSUPPORTED_MULTIDISK: + return "unsupported multidisk archive"; + case MZ_ZIP_DECOMPRESSION_FAILED: + return "decompression failed or archive is corrupted"; + case MZ_ZIP_COMPRESSION_FAILED: + return "compression failed"; + case MZ_ZIP_UNEXPECTED_DECOMPRESSED_SIZE: + return "unexpected decompressed size"; + case MZ_ZIP_CRC_CHECK_FAILED: + return "CRC-32 check failed"; + case MZ_ZIP_UNSUPPORTED_CDIR_SIZE: + return "unsupported central directory size"; + case MZ_ZIP_ALLOC_FAILED: + return "allocation failed"; + case MZ_ZIP_FILE_OPEN_FAILED: + return "file open failed"; + case MZ_ZIP_FILE_CREATE_FAILED: + return "file create failed"; + case MZ_ZIP_FILE_WRITE_FAILED: + return "file write failed"; + case MZ_ZIP_FILE_READ_FAILED: + return "file read failed"; + case MZ_ZIP_FILE_CLOSE_FAILED: + return "file close failed"; + case MZ_ZIP_FILE_SEEK_FAILED: + return "file seek failed"; + case MZ_ZIP_FILE_STAT_FAILED: + return "file stat failed"; + case MZ_ZIP_INVALID_PARAMETER: + return "invalid parameter"; + case MZ_ZIP_INVALID_FILENAME: + return "invalid filename"; + case MZ_ZIP_BUF_TOO_SMALL: + return "buffer too small"; + case MZ_ZIP_INTERNAL_ERROR: + return "internal error"; + case MZ_ZIP_FILE_NOT_FOUND: + return "file not found"; + case MZ_ZIP_ARCHIVE_TOO_LARGE: + return "archive is too large"; + case MZ_ZIP_VALIDATION_FAILED: + return "validation failed"; + case MZ_ZIP_WRITE_CALLBACK_FAILED: + return "write calledback failed"; + default: + break; + } + + return "unknown error"; +} + +/* Note: Just because the archive is not zip64 doesn't necessarily mean it doesn't have Zip64 extended information extra field, argh. */ +mz_bool mz_zip_is_zip64(mz_zip_archive *pZip) +{ + if ((!pZip) || (!pZip->m_pState)) + return MZ_FALSE; + + return pZip->m_pState->m_zip64; +} + +size_t mz_zip_get_central_dir_size(mz_zip_archive *pZip) +{ + if ((!pZip) || (!pZip->m_pState)) + return 0; + + return pZip->m_pState->m_central_dir.m_size; +} + +mz_uint mz_zip_reader_get_num_files(mz_zip_archive *pZip) +{ + return pZip ? pZip->m_total_files : 0; +} + +mz_uint64 mz_zip_get_archive_size(mz_zip_archive *pZip) +{ + if (!pZip) + return 0; + return pZip->m_archive_size; +} + +mz_uint64 mz_zip_get_archive_file_start_offset(mz_zip_archive *pZip) +{ + if ((!pZip) || (!pZip->m_pState)) + return 0; + return pZip->m_pState->m_file_archive_start_ofs; +} + +MZ_FILE *mz_zip_get_cfile(mz_zip_archive *pZip) +{ + if ((!pZip) || (!pZip->m_pState)) + return 0; + return pZip->m_pState->m_pFile; +} + +size_t mz_zip_read_archive_data(mz_zip_archive *pZip, mz_uint64 file_ofs, void *pBuf, size_t n) +{ + if ((!pZip) || (!pZip->m_pState) || (!pBuf) || (!pZip->m_pRead)) + return mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + + return pZip->m_pRead(pZip->m_pIO_opaque, file_ofs, pBuf, n); +} + +mz_uint mz_zip_reader_get_filename(mz_zip_archive *pZip, mz_uint file_index, char *pFilename, mz_uint filename_buf_size) +{ + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + mz_uint n; + const mz_uint8 *p = mz_zip_get_cdh(pZip, file_index); + if (!p) + { + if (filename_buf_size) + pFilename[0] = '\0'; + mz_zip_set_error(pZip, MZ_ZIP_INVALID_PARAMETER); + return 0; + } + n = MZ_READ_LE16(p + MZ_ZIP_CDH_FILENAME_LEN_OFS); + if (filename_buf_size) + { + n = MZ_MIN(n, filename_buf_size - 1); + memcpy(pFilename, p + MZ_ZIP_CENTRAL_DIR_HEADER_SIZE, n); + pFilename[n] = '\0'; + } + return n + 1; +} + +mz_bool mz_zip_reader_file_stat(mz_zip_archive *pZip, mz_uint file_index, mz_zip_archive_file_stat *pStat) +{ + return mz_zip_file_stat_internal(pZip, file_index, mz_zip_get_cdh(pZip, file_index), pStat, NULL); +} + +mz_bool mz_zip_end(mz_zip_archive *pZip) +{ + if (!pZip) + return MZ_FALSE; + + if (pZip->m_zip_mode == MZ_ZIP_MODE_READING) + return mz_zip_reader_end(pZip); +#ifndef MINIZ_NO_ARCHIVE_WRITING_APIS + else if ((pZip->m_zip_mode == MZ_ZIP_MODE_WRITING) || (pZip->m_zip_mode == MZ_ZIP_MODE_WRITING_HAS_BEEN_FINALIZED)) + return mz_zip_writer_end(pZip); +#endif + + return MZ_FALSE; +} + +#endif /*#ifndef MINIZ_NO_ARCHIVE_APIS*/ + +} \ No newline at end of file diff --git a/python/jittor/src/misc/miniz.h b/python/jittor/src/misc/miniz.h new file mode 100755 index 00000000..9f8fd724 --- /dev/null +++ b/python/jittor/src/misc/miniz.h @@ -0,0 +1,1415 @@ +/* miniz.c 2.1.0 - public domain deflate/inflate, zlib-subset, ZIP reading/writing/appending, PNG writing + See "unlicense" statement at the end of this file. + Rich Geldreich , last updated Oct. 13, 2013 + Implements RFC 1950: http://www.ietf.org/rfc/rfc1950.txt and RFC 1951: http://www.ietf.org/rfc/rfc1951.txt + + Most API's defined in miniz.c are optional. For example, to disable the archive related functions just define + MINIZ_NO_ARCHIVE_APIS, or to get rid of all stdio usage define MINIZ_NO_STDIO (see the list below for more macros). + + * Low-level Deflate/Inflate implementation notes: + + Compression: Use the "tdefl" API's. The compressor supports raw, static, and dynamic blocks, lazy or + greedy parsing, match length filtering, RLE-only, and Huffman-only streams. It performs and compresses + approximately as well as zlib. + + Decompression: Use the "tinfl" API's. The entire decompressor is implemented as a single function + coroutine: see tinfl_decompress(). It supports decompression into a 32KB (or larger power of 2) wrapping buffer, or into a memory + block large enough to hold the entire file. + + The low-level tdefl/tinfl API's do not make any use of dynamic memory allocation. + + * zlib-style API notes: + + miniz.c implements a fairly large subset of zlib. There's enough functionality present for it to be a drop-in + zlib replacement in many apps: + The z_stream struct, optional memory allocation callbacks + deflateInit/deflateInit2/deflate/deflateReset/deflateEnd/deflateBound + inflateInit/inflateInit2/inflate/inflateReset/inflateEnd + compress, compress2, compressBound, uncompress + CRC-32, Adler-32 - Using modern, minimal code size, CPU cache friendly routines. + Supports raw deflate streams or standard zlib streams with adler-32 checking. + + Limitations: + The callback API's are not implemented yet. No support for gzip headers or zlib static dictionaries. + I've tried to closely emulate zlib's various flavors of stream flushing and return status codes, but + there are no guarantees that miniz.c pulls this off perfectly. + + * PNG writing: See the tdefl_write_image_to_png_file_in_memory() function, originally written by + Alex Evans. Supports 1-4 bytes/pixel images. + + * ZIP archive API notes: + + The ZIP archive API's where designed with simplicity and efficiency in mind, with just enough abstraction to + get the job done with minimal fuss. There are simple API's to retrieve file information, read files from + existing archives, create new archives, append new files to existing archives, or clone archive data from + one archive to another. It supports archives located in memory or the heap, on disk (using stdio.h), + or you can specify custom file read/write callbacks. + + - Archive reading: Just call this function to read a single file from a disk archive: + + void *mz_zip_extract_archive_file_to_heap(const char *pZip_filename, const char *pArchive_name, + size_t *pSize, mz_uint zip_flags); + + For more complex cases, use the "mz_zip_reader" functions. Upon opening an archive, the entire central + directory is located and read as-is into memory, and subsequent file access only occurs when reading individual files. + + - Archives file scanning: The simple way is to use this function to scan a loaded archive for a specific file: + + int mz_zip_reader_locate_file(mz_zip_archive *pZip, const char *pName, const char *pComment, mz_uint flags); + + The locate operation can optionally check file comments too, which (as one example) can be used to identify + multiple versions of the same file in an archive. This function uses a simple linear search through the central + directory, so it's not very fast. + + Alternately, you can iterate through all the files in an archive (using mz_zip_reader_get_num_files()) and + retrieve detailed info on each file by calling mz_zip_reader_file_stat(). + + - Archive creation: Use the "mz_zip_writer" functions. The ZIP writer immediately writes compressed file data + to disk and builds an exact image of the central directory in memory. The central directory image is written + all at once at the end of the archive file when the archive is finalized. + + The archive writer can optionally align each file's local header and file data to any power of 2 alignment, + which can be useful when the archive will be read from optical media. Also, the writer supports placing + arbitrary data blobs at the very beginning of ZIP archives. Archives written using either feature are still + readable by any ZIP tool. + + - Archive appending: The simple way to add a single file to an archive is to call this function: + + mz_bool mz_zip_add_mem_to_archive_file_in_place(const char *pZip_filename, const char *pArchive_name, + const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags); + + The archive will be created if it doesn't already exist, otherwise it'll be appended to. + Note the appending is done in-place and is not an atomic operation, so if something goes wrong + during the operation it's possible the archive could be left without a central directory (although the local + file headers and file data will be fine, so the archive will be recoverable). + + For more complex archive modification scenarios: + 1. The safest way is to use a mz_zip_reader to read the existing archive, cloning only those bits you want to + preserve into a new archive using using the mz_zip_writer_add_from_zip_reader() function (which compiles the + compressed file data as-is). When you're done, delete the old archive and rename the newly written archive, and + you're done. This is safe but requires a bunch of temporary disk space or heap memory. + + 2. Or, you can convert an mz_zip_reader in-place to an mz_zip_writer using mz_zip_writer_init_from_reader(), + append new files as needed, then finalize the archive which will write an updated central directory to the + original archive. (This is basically what mz_zip_add_mem_to_archive_file_in_place() does.) There's a + possibility that the archive's central directory could be lost with this method if anything goes wrong, though. + + - ZIP archive support limitations: + No zip64 or spanning support. Extraction functions can only handle unencrypted, stored or deflated files. + Requires streams capable of seeking. + + * This is a header file library, like stb_image.c. To get only a header file, either cut and paste the + below header, or create miniz.h, #define MINIZ_HEADER_FILE_ONLY, and then include miniz.c from it. + + * Important: For best perf. be sure to customize the below macros for your target platform: + #define MINIZ_USE_UNALIGNED_LOADS_AND_STORES 1 + #define MINIZ_LITTLE_ENDIAN 1 + #define MINIZ_HAS_64BIT_REGISTERS 1 + + * On platforms using glibc, Be sure to "#define _LARGEFILE64_SOURCE 1" before including miniz.c to ensure miniz + uses the 64-bit variants: fopen64(), stat64(), etc. Otherwise you won't be able to process large files + (i.e. 32-bit stat() fails for me on files > 0x7FFFFFFF bytes). +*/ +#pragma once + + + + + +/* Defines to completely disable specific portions of miniz.c: + If all macros here are defined the only functionality remaining will be CRC-32, adler-32, tinfl, and tdefl. */ + +/* Define MINIZ_NO_STDIO to disable all usage and any functions which rely on stdio for file I/O. */ +/*#define MINIZ_NO_STDIO */ + +/* If MINIZ_NO_TIME is specified then the ZIP archive functions will not be able to get the current time, or */ +/* get/set file times, and the C run-time funcs that get/set times won't be called. */ +/* The current downside is the times written to your archives will be from 1979. */ +#define MINIZ_NO_TIME + +/* Define MINIZ_NO_ARCHIVE_APIS to disable all ZIP archive API's. */ +/*#define MINIZ_NO_ARCHIVE_APIS */ + +/* Define MINIZ_NO_ARCHIVE_WRITING_APIS to disable all writing related ZIP archive API's. */ +/*#define MINIZ_NO_ARCHIVE_WRITING_APIS */ + +/* Define MINIZ_NO_ZLIB_APIS to remove all ZLIB-style compression/decompression API's. */ +/*#define MINIZ_NO_ZLIB_APIS */ + +/* Define MINIZ_NO_ZLIB_COMPATIBLE_NAME to disable zlib names, to prevent conflicts against stock zlib. */ +#define MINIZ_NO_ZLIB_COMPATIBLE_NAMES + +/* Define MINIZ_NO_MALLOC to disable all calls to malloc, free, and realloc. + Note if MINIZ_NO_MALLOC is defined then the user must always provide custom user alloc/free/realloc + callbacks to the zlib and archive API's, and a few stand-alone helper API's which don't provide custom user + functions (such as tdefl_compress_mem_to_heap() and tinfl_decompress_mem_to_heap()) won't work. */ +/*#define MINIZ_NO_MALLOC */ + +#if defined(__TINYC__) && (defined(__linux) || defined(__linux__)) +/* TODO: Work around "error: include file 'sys\utime.h' when compiling with tcc on Linux */ +#define MINIZ_NO_TIME +#endif + +#define MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + +#include + +#include +#include +#include +#include +#include + +#include "common.h" +#include "ops/op_register.h" +#include "var_holder.h" +#include "profiler/simple_profiler.h" + +#if !defined(MINIZ_NO_TIME) && !defined(MINIZ_NO_ARCHIVE_APIS) +#include +#endif + +#if defined(_M_IX86) || defined(_M_X64) || defined(__i386__) || defined(__i386) || defined(__i486__) || defined(__i486) || defined(i386) || defined(__ia64__) || defined(__x86_64__) +/* MINIZ_X86_OR_X64_CPU is only used to help set the below macros. */ +#define MINIZ_X86_OR_X64_CPU 1 +#else +#define MINIZ_X86_OR_X64_CPU 0 +#endif + +#if (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) || MINIZ_X86_OR_X64_CPU +/* Set MINIZ_LITTLE_ENDIAN to 1 if the processor is little endian. */ +#define MINIZ_LITTLE_ENDIAN 1 +#else +#define MINIZ_LITTLE_ENDIAN 0 +#endif + +/* Set MINIZ_USE_UNALIGNED_LOADS_AND_STORES only if not set */ +#if !defined(MINIZ_USE_UNALIGNED_LOADS_AND_STORES) +#if MINIZ_X86_OR_X64_CPU +/* Set MINIZ_USE_UNALIGNED_LOADS_AND_STORES to 1 on CPU's that permit efficient integer loads and stores from unaligned addresses. */ +/* zdevito: ASAN doesn't like unligned loads and stores, and -O3 optimizes the unoptimized code pattern away anyawy */ +#define MINIZ_USE_UNALIGNED_LOADS_AND_STORES 0 +/* zdevito: ASAN doesn't like unligned loads and stores, and -O3 optimizes the unoptimized code pattern away anyawy */ +/*#define MINIZ_UNALIGNED_USE_MEMCPY*/ +#else +#define MINIZ_USE_UNALIGNED_LOADS_AND_STORES 0 +#endif +#endif + +#if defined(_M_X64) || defined(_WIN64) || defined(__MINGW64__) || defined(_LP64) || defined(__LP64__) || defined(__ia64__) || defined(__x86_64__) +/* Set MINIZ_HAS_64BIT_REGISTERS to 1 if operations on 64-bit integers are reasonably fast (and don't involve compiler generated calls to helper functions). */ +#define MINIZ_HAS_64BIT_REGISTERS 1 +#else +#define MINIZ_HAS_64BIT_REGISTERS 0 +#endif + +namespace jittor { + +/* ------------------- zlib-style API Definitions. */ + +/* For more compatibility with zlib, miniz.c uses unsigned long for some parameters/struct members. Beware: mz_ulong can be either 32 or 64-bits! */ +typedef unsigned long mz_ulong; + +/* mz_free() internally uses the MZ_FREE() macro (which by default calls free() unless you've modified the MZ_MALLOC macro) to release a block allocated from the heap. */ +void mz_free(void *p); + +#define MZ_ADLER32_INIT (1) +/* mz_adler32() returns the initial adler-32 value to use when called with ptr==NULL. */ +mz_ulong mz_adler32(mz_ulong adler, const unsigned char *ptr, size_t buf_len); + +#define MZ_CRC32_INIT (0) +/* mz_crc32() returns the initial CRC-32 value to use when called with ptr==NULL. */ +mz_ulong mz_crc32(mz_ulong crc, const unsigned char *ptr, size_t buf_len); + +/* Compression strategies. */ +enum +{ + MZ_DEFAULT_STRATEGY = 0, + MZ_FILTERED = 1, + MZ_HUFFMAN_ONLY = 2, + MZ_RLE = 3, + MZ_FIXED = 4 +}; + +/* Method */ +#define MZ_DEFLATED 8 + +/* Heap allocation callbacks. +Note that mz_alloc_func parameter types purpsosely differ from zlib's: items/size is size_t, not unsigned long. */ +typedef void *(*mz_alloc_func)(void *opaque, size_t items, size_t size); +typedef void (*mz_free_func)(void *opaque, void *address); +typedef void *(*mz_realloc_func)(void *opaque, void *address, size_t items, size_t size); + +/* Compression levels: 0-9 are the standard zlib-style levels, 10 is best possible compression (not zlib compatible, and may be very slow), MZ_DEFAULT_COMPRESSION=MZ_DEFAULT_LEVEL. */ +enum +{ + MZ_NO_COMPRESSION = 0, + MZ_BEST_SPEED = 1, + MZ_BEST_COMPRESSION = 9, + MZ_UBER_COMPRESSION = 10, + MZ_DEFAULT_LEVEL = 6, + MZ_DEFAULT_COMPRESSION = -1 +}; + +#define MZ_VERSION "10.1.0" +#define MZ_VERNUM 0xA100 +#define MZ_VER_MAJOR 10 +#define MZ_VER_MINOR 1 +#define MZ_VER_REVISION 0 +#define MZ_VER_SUBREVISION 0 + +#ifndef MINIZ_NO_ZLIB_APIS + +/* Flush values. For typical usage you only need MZ_NO_FLUSH and MZ_FINISH. The other values are for advanced use (refer to the zlib docs). */ +enum +{ + MZ_NO_FLUSH = 0, + MZ_PARTIAL_FLUSH = 1, + MZ_SYNC_FLUSH = 2, + MZ_FULL_FLUSH = 3, + MZ_FINISH = 4, + MZ_BLOCK = 5 +}; + +/* Return status codes. MZ_PARAM_ERROR is non-standard. */ +enum +{ + MZ_OK = 0, + MZ_STREAM_END = 1, + MZ_NEED_DICT = 2, + MZ_ERRNO = -1, + MZ_STREAM_ERROR = -2, + MZ_DATA_ERROR = -3, + MZ_MEM_ERROR = -4, + MZ_BUF_ERROR = -5, + MZ_VERSION_ERROR = -6, + MZ_PARAM_ERROR = -10000 +}; + +/* Window bits */ +#define MZ_DEFAULT_WINDOW_BITS 15 + +struct mz_internal_state; + +/* Compression/decompression stream struct. */ +typedef struct mz_stream_s +{ + const unsigned char *next_in; /* pointer to next byte to read */ + unsigned int avail_in; /* number of bytes available at next_in */ + mz_ulong total_in; /* total number of bytes consumed so far */ + + unsigned char *next_out; /* pointer to next byte to write */ + unsigned int avail_out; /* number of bytes that can be written to next_out */ + mz_ulong total_out; /* total number of bytes produced so far */ + + char *msg; /* error msg (unused) */ + struct mz_internal_state *state; /* internal state, allocated by zalloc/zfree */ + + mz_alloc_func zalloc; /* optional heap allocation function (defaults to malloc) */ + mz_free_func zfree; /* optional heap free function (defaults to free) */ + void *opaque; /* heap alloc function user pointer */ + + int data_type; /* data_type (unused) */ + mz_ulong adler; /* adler32 of the source or uncompressed data */ + mz_ulong reserved; /* not used */ +} mz_stream; + +typedef mz_stream *mz_streamp; + +/* Returns the version string of miniz.c. */ +const char *mz_version(void); + +/* mz_deflateInit() initializes a compressor with default options: */ +/* Parameters: */ +/* pStream must point to an initialized mz_stream struct. */ +/* level must be between [MZ_NO_COMPRESSION, MZ_BEST_COMPRESSION]. */ +/* level 1 enables a specially optimized compression function that's been optimized purely for performance, not ratio. */ +/* (This special func. is currently only enabled when MINIZ_USE_UNALIGNED_LOADS_AND_STORES and MINIZ_LITTLE_ENDIAN are defined.) */ +/* Return values: */ +/* MZ_OK on success. */ +/* MZ_STREAM_ERROR if the stream is bogus. */ +/* MZ_PARAM_ERROR if the input parameters are bogus. */ +/* MZ_MEM_ERROR on out of memory. */ +int mz_deflateInit(mz_streamp pStream, int level); + +/* mz_deflateInit2() is like mz_deflate(), except with more control: */ +/* Additional parameters: */ +/* method must be MZ_DEFLATED */ +/* window_bits must be MZ_DEFAULT_WINDOW_BITS (to wrap the deflate stream with zlib header/adler-32 footer) or -MZ_DEFAULT_WINDOW_BITS (raw deflate/no header or footer) */ +/* mem_level must be between [1, 9] (it's checked but ignored by miniz.c) */ +int mz_deflateInit2(mz_streamp pStream, int level, int method, int window_bits, int mem_level, int strategy); + +/* Quickly resets a compressor without having to reallocate anything. Same as calling mz_deflateEnd() followed by mz_deflateInit()/mz_deflateInit2(). */ +int mz_deflateReset(mz_streamp pStream); + +/* mz_deflate() compresses the input to output, consuming as much of the input and producing as much output as possible. */ +/* Parameters: */ +/* pStream is the stream to read from and write to. You must initialize/update the next_in, avail_in, next_out, and avail_out members. */ +/* flush may be MZ_NO_FLUSH, MZ_PARTIAL_FLUSH/MZ_SYNC_FLUSH, MZ_FULL_FLUSH, or MZ_FINISH. */ +/* Return values: */ +/* MZ_OK on success (when flushing, or if more input is needed but not available, and/or there's more output to be written but the output buffer is full). */ +/* MZ_STREAM_END if all input has been consumed and all output bytes have been written. Don't call mz_deflate() on the stream anymore. */ +/* MZ_STREAM_ERROR if the stream is bogus. */ +/* MZ_PARAM_ERROR if one of the parameters is invalid. */ +/* MZ_BUF_ERROR if no forward progress is possible because the input and/or output buffers are empty. (Fill up the input buffer or free up some output space and try again.) */ +int mz_deflate(mz_streamp pStream, int flush); + +/* mz_deflateEnd() deinitializes a compressor: */ +/* Return values: */ +/* MZ_OK on success. */ +/* MZ_STREAM_ERROR if the stream is bogus. */ +int mz_deflateEnd(mz_streamp pStream); + +/* mz_deflateBound() returns a (very) conservative upper bound on the amount of data that could be generated by deflate(), assuming flush is set to only MZ_NO_FLUSH or MZ_FINISH. */ +mz_ulong mz_deflateBound(mz_streamp pStream, mz_ulong source_len); + +/* Single-call compression functions mz_compress() and mz_compress2(): */ +/* Returns MZ_OK on success, or one of the error codes from mz_deflate() on failure. */ +int mz_compress(unsigned char *pDest, mz_ulong *pDest_len, const unsigned char *pSource, mz_ulong source_len); +int mz_compress2(unsigned char *pDest, mz_ulong *pDest_len, const unsigned char *pSource, mz_ulong source_len, int level); + +/* mz_compressBound() returns a (very) conservative upper bound on the amount of data that could be generated by calling mz_compress(). */ +mz_ulong mz_compressBound(mz_ulong source_len); + +/* Initializes a decompressor. */ +int mz_inflateInit(mz_streamp pStream); + +/* mz_inflateInit2() is like mz_inflateInit() with an additional option that controls the window size and whether or not the stream has been wrapped with a zlib header/footer: */ +/* window_bits must be MZ_DEFAULT_WINDOW_BITS (to parse zlib header/footer) or -MZ_DEFAULT_WINDOW_BITS (raw deflate). */ +int mz_inflateInit2(mz_streamp pStream, int window_bits); + +/* Quickly resets a compressor without having to reallocate anything. Same as calling mz_inflateEnd() followed by mz_inflateInit()/mz_inflateInit2(). */ +int mz_inflateReset(mz_streamp pStream); + +/* Decompresses the input stream to the output, consuming only as much of the input as needed, and writing as much to the output as possible. */ +/* Parameters: */ +/* pStream is the stream to read from and write to. You must initialize/update the next_in, avail_in, next_out, and avail_out members. */ +/* flush may be MZ_NO_FLUSH, MZ_SYNC_FLUSH, or MZ_FINISH. */ +/* On the first call, if flush is MZ_FINISH it's assumed the input and output buffers are both sized large enough to decompress the entire stream in a single call (this is slightly faster). */ +/* MZ_FINISH implies that there are no more source bytes available beside what's already in the input buffer, and that the output buffer is large enough to hold the rest of the decompressed data. */ +/* Return values: */ +/* MZ_OK on success. Either more input is needed but not available, and/or there's more output to be written but the output buffer is full. */ +/* MZ_STREAM_END if all needed input has been consumed and all output bytes have been written. For zlib streams, the adler-32 of the decompressed data has also been verified. */ +/* MZ_STREAM_ERROR if the stream is bogus. */ +/* MZ_DATA_ERROR if the deflate stream is invalid. */ +/* MZ_PARAM_ERROR if one of the parameters is invalid. */ +/* MZ_BUF_ERROR if no forward progress is possible because the input buffer is empty but the inflater needs more input to continue, or if the output buffer is not large enough. Call mz_inflate() again */ +/* with more input data, or with more room in the output buffer (except when using single call decompression, described above). */ +int mz_inflate(mz_streamp pStream, int flush); + +/* Deinitializes a decompressor. */ +int mz_inflateEnd(mz_streamp pStream); + +/* Single-call decompression. */ +/* Returns MZ_OK on success, or one of the error codes from mz_inflate() on failure. */ +int mz_uncompress(unsigned char *pDest, mz_ulong *pDest_len, const unsigned char *pSource, mz_ulong source_len); + +/* Returns a string description of the specified error code, or NULL if the error code is invalid. */ +const char *mz_error(int err); + +/* Redefine zlib-compatible names to miniz equivalents, so miniz.c can be used as a drop-in replacement for the subset of zlib that miniz.c supports. */ +/* Define MINIZ_NO_ZLIB_COMPATIBLE_NAMES to disable zlib-compatibility if you use zlib in the same project. */ +#ifndef MINIZ_NO_ZLIB_COMPATIBLE_NAMES +typedef unsigned char Byte; +typedef unsigned int uInt; +typedef mz_ulong uLong; +typedef Byte Bytef; +typedef uInt uIntf; +typedef char charf; +typedef int intf; +typedef void *voidpf; +typedef uLong uLongf; +typedef void *voidp; +typedef void *const voidpc; +#define Z_NULL 0 +#define Z_NO_FLUSH MZ_NO_FLUSH +#define Z_PARTIAL_FLUSH MZ_PARTIAL_FLUSH +#define Z_SYNC_FLUSH MZ_SYNC_FLUSH +#define Z_FULL_FLUSH MZ_FULL_FLUSH +#define Z_FINISH MZ_FINISH +#define Z_BLOCK MZ_BLOCK +#define Z_OK MZ_OK +#define Z_STREAM_END MZ_STREAM_END +#define Z_NEED_DICT MZ_NEED_DICT +#define Z_ERRNO MZ_ERRNO +#define Z_STREAM_ERROR MZ_STREAM_ERROR +#define Z_DATA_ERROR MZ_DATA_ERROR +#define Z_MEM_ERROR MZ_MEM_ERROR +#define Z_BUF_ERROR MZ_BUF_ERROR +#define Z_VERSION_ERROR MZ_VERSION_ERROR +#define Z_PARAM_ERROR MZ_PARAM_ERROR +#define Z_NO_COMPRESSION MZ_NO_COMPRESSION +#define Z_BEST_SPEED MZ_BEST_SPEED +#define Z_BEST_COMPRESSION MZ_BEST_COMPRESSION +#define Z_DEFAULT_COMPRESSION MZ_DEFAULT_COMPRESSION +#define Z_DEFAULT_STRATEGY MZ_DEFAULT_STRATEGY +#define Z_FILTERED MZ_FILTERED +#define Z_HUFFMAN_ONLY MZ_HUFFMAN_ONLY +#define Z_RLE MZ_RLE +#define Z_FIXED MZ_FIXED +#define Z_DEFLATED MZ_DEFLATED +#define Z_DEFAULT_WINDOW_BITS MZ_DEFAULT_WINDOW_BITS +#define alloc_func mz_alloc_func +#define free_func mz_free_func +#define internal_state mz_internal_state +#define z_stream mz_stream +#define deflateInit mz_deflateInit +#define deflateInit2 mz_deflateInit2 +#define deflateReset mz_deflateReset +#define deflate mz_deflate +#define deflateEnd mz_deflateEnd +#define deflateBound mz_deflateBound +#define compress mz_compress +#define compress2 mz_compress2 +#define compressBound mz_compressBound +#define inflateInit mz_inflateInit +#define inflateInit2 mz_inflateInit2 +#define inflateReset mz_inflateReset +#define inflate mz_inflate +#define inflateEnd mz_inflateEnd +#define uncompress mz_uncompress +#define crc32 mz_crc32 +#define adler32 mz_adler32 +#define MAX_WBITS 15 +#define MAX_MEM_LEVEL 9 +#define zError mz_error +#define ZLIB_VERSION MZ_VERSION +#define ZLIB_VERNUM MZ_VERNUM +#define ZLIB_VER_MAJOR MZ_VER_MAJOR +#define ZLIB_VER_MINOR MZ_VER_MINOR +#define ZLIB_VER_REVISION MZ_VER_REVISION +#define ZLIB_VER_SUBREVISION MZ_VER_SUBREVISION +#define zlibVersion mz_version +#define zlib_version mz_version() +#endif /* #ifndef MINIZ_NO_ZLIB_COMPATIBLE_NAMES */ + +#endif /* MINIZ_NO_ZLIB_APIS */ + +#pragma once + +/* ------------------- Types and macros */ +typedef unsigned char mz_uint8; +typedef signed short mz_int16; +typedef unsigned short mz_uint16; +typedef unsigned int mz_uint32; +typedef unsigned int mz_uint; +typedef int64_t mz_int64; +typedef uint64_t mz_uint64; +typedef int mz_bool; + +#define MZ_FALSE (0) +#define MZ_TRUE (1) + +/* Works around MSVC's spammy "warning C4127: conditional expression is constant" message. */ +#ifdef _MSC_VER +#define MZ_MACRO_END while (0, 0) +#else +#define MZ_MACRO_END while (0) +#endif + +#ifdef MINIZ_NO_STDIO +#define MZ_FILE void * +#else +#define MZ_FILE FILE +#endif /* #ifdef MINIZ_NO_STDIO */ + +#ifdef MINIZ_NO_TIME +typedef struct mz_dummy_time_t_tag +{ + int m_dummy; +} mz_dummy_time_t; +#define MZ_TIME_T mz_dummy_time_t +#else +#define MZ_TIME_T time_t +#endif + +#define MZ_ASSERT(x) assert(x) + +#ifdef MINIZ_NO_MALLOC +#define MZ_MALLOC(x) NULL +#define MZ_FREE(x) (void)x, ((void)0) +#define MZ_REALLOC(p, x) NULL +#else +#define MZ_MALLOC(x) malloc(x) +#define MZ_FREE(x) free(x) +#define MZ_REALLOC(p, x) realloc(p, x) +#endif + +#define MZ_MAX(a, b) (((a) > (b)) ? (a) : (b)) +#define MZ_MIN(a, b) (((a) < (b)) ? (a) : (b)) +#define MZ_CLEAR_OBJ(obj) memset(&(obj), 0, sizeof(obj)) + +#if MINIZ_USE_UNALIGNED_LOADS_AND_STORES && MINIZ_LITTLE_ENDIAN +#define MZ_READ_LE16(p) *((const mz_uint16 *)(p)) +#define MZ_READ_LE32(p) *((const mz_uint32 *)(p)) +#else +#define MZ_READ_LE16(p) ((mz_uint32)(((const mz_uint8 *)(p))[0]) | ((mz_uint32)(((const mz_uint8 *)(p))[1]) << 8U)) +#define MZ_READ_LE32(p) ((mz_uint32)(((const mz_uint8 *)(p))[0]) | ((mz_uint32)(((const mz_uint8 *)(p))[1]) << 8U) | ((mz_uint32)(((const mz_uint8 *)(p))[2]) << 16U) | ((mz_uint32)(((const mz_uint8 *)(p))[3]) << 24U)) +#endif + +#define MZ_READ_LE64(p) (((mz_uint64)MZ_READ_LE32(p)) | (((mz_uint64)MZ_READ_LE32((const mz_uint8 *)(p) + sizeof(mz_uint32))) << 32U)) + +#ifdef _MSC_VER +#define MZ_FORCEINLINE __forceinline +#elif defined(__GNUC__) +#define MZ_FORCEINLINE __inline__ __attribute__((__always_inline__)) +#else +#define MZ_FORCEINLINE inline +#endif + +extern void *miniz_def_alloc_func(void *opaque, size_t items, size_t size); +extern void miniz_def_free_func(void *opaque, void *address); +extern void *miniz_def_realloc_func(void *opaque, void *address, size_t items, size_t size); + +#define MZ_UINT16_MAX (0xFFFFU) +#define MZ_UINT32_MAX (0xFFFFFFFFU) + +#pragma once + + +/* ------------------- Low-level Compression API Definitions */ + +/* Set TDEFL_LESS_MEMORY to 1 to use less memory (compression will be slightly slower, and raw/dynamic blocks will be output more frequently). */ +#define TDEFL_LESS_MEMORY 0 + +/* tdefl_init() compression flags logically OR'd together (low 12 bits contain the max. number of probes per dictionary search): */ +/* TDEFL_DEFAULT_MAX_PROBES: The compressor defaults to 128 dictionary probes per dictionary search. 0=Huffman only, 1=Huffman+LZ (fastest/crap compression), 4095=Huffman+LZ (slowest/best compression). */ +enum +{ + TDEFL_HUFFMAN_ONLY = 0, + TDEFL_DEFAULT_MAX_PROBES = 128, + TDEFL_MAX_PROBES_MASK = 0xFFF +}; + +/* TDEFL_WRITE_ZLIB_HEADER: If set, the compressor outputs a zlib header before the deflate data, and the Adler-32 of the source data at the end. Otherwise, you'll get raw deflate data. */ +/* TDEFL_COMPUTE_ADLER32: Always compute the adler-32 of the input data (even when not writing zlib headers). */ +/* TDEFL_GREEDY_PARSING_FLAG: Set to use faster greedy parsing, instead of more efficient lazy parsing. */ +/* TDEFL_NONDETERMINISTIC_PARSING_FLAG: Enable to decrease the compressor's initialization time to the minimum, but the output may vary from run to run given the same input (depending on the contents of memory). */ +/* TDEFL_RLE_MATCHES: Only look for RLE matches (matches with a distance of 1) */ +/* TDEFL_FILTER_MATCHES: Discards matches <= 5 chars if enabled. */ +/* TDEFL_FORCE_ALL_STATIC_BLOCKS: Disable usage of optimized Huffman tables. */ +/* TDEFL_FORCE_ALL_RAW_BLOCKS: Only use raw (uncompressed) deflate blocks. */ +/* The low 12 bits are reserved to control the max # of hash probes per dictionary lookup (see TDEFL_MAX_PROBES_MASK). */ +enum +{ + TDEFL_WRITE_ZLIB_HEADER = 0x01000, + TDEFL_COMPUTE_ADLER32 = 0x02000, + TDEFL_GREEDY_PARSING_FLAG = 0x04000, + TDEFL_NONDETERMINISTIC_PARSING_FLAG = 0x08000, + TDEFL_RLE_MATCHES = 0x10000, + TDEFL_FILTER_MATCHES = 0x20000, + TDEFL_FORCE_ALL_STATIC_BLOCKS = 0x40000, + TDEFL_FORCE_ALL_RAW_BLOCKS = 0x80000 +}; + +/* High level compression functions: */ +/* tdefl_compress_mem_to_heap() compresses a block in memory to a heap block allocated via malloc(). */ +/* On entry: */ +/* pSrc_buf, src_buf_len: Pointer and size of source block to compress. */ +/* flags: The max match finder probes (default is 128) logically OR'd against the above flags. Higher probes are slower but improve compression. */ +/* On return: */ +/* Function returns a pointer to the compressed data, or NULL on failure. */ +/* *pOut_len will be set to the compressed data's size, which could be larger than src_buf_len on uncompressible data. */ +/* The caller must free() the returned block when it's no longer needed. */ +void *tdefl_compress_mem_to_heap(const void *pSrc_buf, size_t src_buf_len, size_t *pOut_len, int flags); + +/* tdefl_compress_mem_to_mem() compresses a block in memory to another block in memory. */ +/* Returns 0 on failure. */ +size_t tdefl_compress_mem_to_mem(void *pOut_buf, size_t out_buf_len, const void *pSrc_buf, size_t src_buf_len, int flags); + +/* Compresses an image to a compressed PNG file in memory. */ +/* On entry: */ +/* pImage, w, h, and num_chans describe the image to compress. num_chans may be 1, 2, 3, or 4. */ +/* The image pitch in bytes per scanline will be w*num_chans. The leftmost pixel on the top scanline is stored first in memory. */ +/* level may range from [0,10], use MZ_NO_COMPRESSION, MZ_BEST_SPEED, MZ_BEST_COMPRESSION, etc. or a decent default is MZ_DEFAULT_LEVEL */ +/* If flip is true, the image will be flipped on the Y axis (useful for OpenGL apps). */ +/* On return: */ +/* Function returns a pointer to the compressed data, or NULL on failure. */ +/* *pLen_out will be set to the size of the PNG image file. */ +/* The caller must mz_free() the returned heap block (which will typically be larger than *pLen_out) when it's no longer needed. */ +void *tdefl_write_image_to_png_file_in_memory_ex(const void *pImage, int w, int h, int num_chans, size_t *pLen_out, mz_uint level, mz_bool flip); +void *tdefl_write_image_to_png_file_in_memory(const void *pImage, int w, int h, int num_chans, size_t *pLen_out); + +/* Output stream interface. The compressor uses this interface to write compressed data. It'll typically be called TDEFL_OUT_BUF_SIZE at a time. */ +typedef mz_bool (*tdefl_put_buf_func_ptr)(const void *pBuf, int len, void *pUser); + +/* tdefl_compress_mem_to_output() compresses a block to an output stream. The above helpers use this function internally. */ +mz_bool tdefl_compress_mem_to_output(const void *pBuf, size_t buf_len, tdefl_put_buf_func_ptr pPut_buf_func, void *pPut_buf_user, int flags); + +enum +{ + TDEFL_MAX_HUFF_TABLES = 3, + TDEFL_MAX_HUFF_SYMBOLS_0 = 288, + TDEFL_MAX_HUFF_SYMBOLS_1 = 32, + TDEFL_MAX_HUFF_SYMBOLS_2 = 19, + TDEFL_LZ_DICT_SIZE = 32768, + TDEFL_LZ_DICT_SIZE_MASK = TDEFL_LZ_DICT_SIZE - 1, + TDEFL_MIN_MATCH_LEN = 3, + TDEFL_MAX_MATCH_LEN = 258 +}; + +/* TDEFL_OUT_BUF_SIZE MUST be large enough to hold a single entire compressed output block (using static/fixed Huffman codes). */ +#if TDEFL_LESS_MEMORY +enum +{ + TDEFL_LZ_CODE_BUF_SIZE = 24 * 1024, + TDEFL_OUT_BUF_SIZE = (TDEFL_LZ_CODE_BUF_SIZE * 13) / 10, + TDEFL_MAX_HUFF_SYMBOLS = 288, + TDEFL_LZ_HASH_BITS = 12, + TDEFL_LEVEL1_HASH_SIZE_MASK = 4095, + TDEFL_LZ_HASH_SHIFT = (TDEFL_LZ_HASH_BITS + 2) / 3, + TDEFL_LZ_HASH_SIZE = 1 << TDEFL_LZ_HASH_BITS +}; +#else +enum +{ + TDEFL_LZ_CODE_BUF_SIZE = 64 * 1024, + TDEFL_OUT_BUF_SIZE = (TDEFL_LZ_CODE_BUF_SIZE * 13) / 10, + TDEFL_MAX_HUFF_SYMBOLS = 288, + TDEFL_LZ_HASH_BITS = 15, + TDEFL_LEVEL1_HASH_SIZE_MASK = 4095, + TDEFL_LZ_HASH_SHIFT = (TDEFL_LZ_HASH_BITS + 2) / 3, + TDEFL_LZ_HASH_SIZE = 1 << TDEFL_LZ_HASH_BITS +}; +#endif + +/* The low-level tdefl functions below may be used directly if the above helper functions aren't flexible enough. The low-level functions don't make any heap allocations, unlike the above helper functions. */ +typedef enum { + TDEFL_STATUS_BAD_PARAM = -2, + TDEFL_STATUS_PUT_BUF_FAILED = -1, + TDEFL_STATUS_OKAY = 0, + TDEFL_STATUS_DONE = 1 +} tdefl_status; + +/* Must map to MZ_NO_FLUSH, MZ_SYNC_FLUSH, etc. enums */ +typedef enum { + TDEFL_NO_FLUSH = 0, + TDEFL_SYNC_FLUSH = 2, + TDEFL_FULL_FLUSH = 3, + TDEFL_FINISH = 4 +} tdefl_flush; + +/* tdefl's compression state structure. */ +typedef struct +{ + tdefl_put_buf_func_ptr m_pPut_buf_func; + void *m_pPut_buf_user; + mz_uint m_flags, m_max_probes[2]; + int m_greedy_parsing; + mz_uint m_adler32, m_lookahead_pos, m_lookahead_size, m_dict_size; + mz_uint8 *m_pLZ_code_buf, *m_pLZ_flags, *m_pOutput_buf, *m_pOutput_buf_end; + mz_uint m_num_flags_left, m_total_lz_bytes, m_lz_code_buf_dict_pos, m_bits_in, m_bit_buffer; + mz_uint m_saved_match_dist, m_saved_match_len, m_saved_lit, m_output_flush_ofs, m_output_flush_remaining, m_finished, m_block_index, m_wants_to_finish; + tdefl_status m_prev_return_status; + const void *m_pIn_buf; + void *m_pOut_buf; + size_t *m_pIn_buf_size, *m_pOut_buf_size; + tdefl_flush m_flush; + const mz_uint8 *m_pSrc; + size_t m_src_buf_left, m_out_buf_ofs; + mz_uint8 m_dict[TDEFL_LZ_DICT_SIZE + TDEFL_MAX_MATCH_LEN - 1]; + mz_uint16 m_huff_count[TDEFL_MAX_HUFF_TABLES][TDEFL_MAX_HUFF_SYMBOLS]; + mz_uint16 m_huff_codes[TDEFL_MAX_HUFF_TABLES][TDEFL_MAX_HUFF_SYMBOLS]; + mz_uint8 m_huff_code_sizes[TDEFL_MAX_HUFF_TABLES][TDEFL_MAX_HUFF_SYMBOLS]; + mz_uint8 m_lz_code_buf[TDEFL_LZ_CODE_BUF_SIZE]; + mz_uint16 m_next[TDEFL_LZ_DICT_SIZE]; + mz_uint16 m_hash[TDEFL_LZ_HASH_SIZE]; + mz_uint8 m_output_buf[TDEFL_OUT_BUF_SIZE]; +} tdefl_compressor; + +/* Initializes the compressor. */ +/* There is no corresponding deinit() function because the tdefl API's do not dynamically allocate memory. */ +/* pBut_buf_func: If NULL, output data will be supplied to the specified callback. In this case, the user should call the tdefl_compress_buffer() API for compression. */ +/* If pBut_buf_func is NULL the user should always call the tdefl_compress() API. */ +/* flags: See the above enums (TDEFL_HUFFMAN_ONLY, TDEFL_WRITE_ZLIB_HEADER, etc.) */ +tdefl_status tdefl_init(tdefl_compressor *d, tdefl_put_buf_func_ptr pPut_buf_func, void *pPut_buf_user, int flags); + +/* Compresses a block of data, consuming as much of the specified input buffer as possible, and writing as much compressed data to the specified output buffer as possible. */ +tdefl_status tdefl_compress(tdefl_compressor *d, const void *pIn_buf, size_t *pIn_buf_size, void *pOut_buf, size_t *pOut_buf_size, tdefl_flush flush); + +/* tdefl_compress_buffer() is only usable when the tdefl_init() is called with a non-NULL tdefl_put_buf_func_ptr. */ +/* tdefl_compress_buffer() always consumes the entire input buffer. */ +tdefl_status tdefl_compress_buffer(tdefl_compressor *d, const void *pIn_buf, size_t in_buf_size, tdefl_flush flush); + +tdefl_status tdefl_get_prev_return_status(tdefl_compressor *d); +mz_uint32 tdefl_get_adler32(tdefl_compressor *d); + +/* Create tdefl_compress() flags given zlib-style compression parameters. */ +/* level may range from [0,10] (where 10 is absolute max compression, but may be much slower on some files) */ +/* window_bits may be -15 (raw deflate) or 15 (zlib) */ +/* strategy may be either MZ_DEFAULT_STRATEGY, MZ_FILTERED, MZ_HUFFMAN_ONLY, MZ_RLE, or MZ_FIXED */ +mz_uint tdefl_create_comp_flags_from_zip_params(int level, int window_bits, int strategy); + +#ifndef MINIZ_NO_MALLOC +/* Allocate the tdefl_compressor structure in C so that */ +/* non-C language bindings to tdefl_ API don't need to worry about */ +/* structure size and allocation mechanism. */ +tdefl_compressor *tdefl_compressor_alloc(void); +void tdefl_compressor_free(tdefl_compressor *pComp); +#endif + +#pragma once + +/* ------------------- Low-level Decompression API Definitions */ + +/* Decompression flags used by tinfl_decompress(). */ +/* TINFL_FLAG_PARSE_ZLIB_HEADER: If set, the input has a valid zlib header and ends with an adler32 checksum (it's a valid zlib stream). Otherwise, the input is a raw deflate stream. */ +/* TINFL_FLAG_HAS_MORE_INPUT: If set, there are more input bytes available beyond the end of the supplied input buffer. If clear, the input buffer contains all remaining input. */ +/* TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF: If set, the output buffer is large enough to hold the entire decompressed stream. If clear, the output buffer is at least the size of the dictionary (typically 32KB). */ +/* TINFL_FLAG_COMPUTE_ADLER32: Force adler-32 checksum computation of the decompressed bytes. */ +enum +{ + TINFL_FLAG_PARSE_ZLIB_HEADER = 1, + TINFL_FLAG_HAS_MORE_INPUT = 2, + TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF = 4, + TINFL_FLAG_COMPUTE_ADLER32 = 8 +}; + +/* High level decompression functions: */ +/* tinfl_decompress_mem_to_heap() decompresses a block in memory to a heap block allocated via malloc(). */ +/* On entry: */ +/* pSrc_buf, src_buf_len: Pointer and size of the Deflate or zlib source data to decompress. */ +/* On return: */ +/* Function returns a pointer to the decompressed data, or NULL on failure. */ +/* *pOut_len will be set to the decompressed data's size, which could be larger than src_buf_len on uncompressible data. */ +/* The caller must call mz_free() on the returned block when it's no longer needed. */ +void *tinfl_decompress_mem_to_heap(const void *pSrc_buf, size_t src_buf_len, size_t *pOut_len, int flags); + +/* tinfl_decompress_mem_to_mem() decompresses a block in memory to another block in memory. */ +/* Returns TINFL_DECOMPRESS_MEM_TO_MEM_FAILED on failure, or the number of bytes written on success. */ +#define TINFL_DECOMPRESS_MEM_TO_MEM_FAILED ((size_t)(-1)) +size_t tinfl_decompress_mem_to_mem(void *pOut_buf, size_t out_buf_len, const void *pSrc_buf, size_t src_buf_len, int flags); + +/* tinfl_decompress_mem_to_callback() decompresses a block in memory to an internal 32KB buffer, and a user provided callback function will be called to flush the buffer. */ +/* Returns 1 on success or 0 on failure. */ +typedef int (*tinfl_put_buf_func_ptr)(const void *pBuf, int len, void *pUser); +int tinfl_decompress_mem_to_callback(const void *pIn_buf, size_t *pIn_buf_size, tinfl_put_buf_func_ptr pPut_buf_func, void *pPut_buf_user, int flags); + +struct tinfl_decompressor_tag; +typedef struct tinfl_decompressor_tag tinfl_decompressor; + +#ifndef MINIZ_NO_MALLOC +/* Allocate the tinfl_decompressor structure in C so that */ +/* non-C language bindings to tinfl_ API don't need to worry about */ +/* structure size and allocation mechanism. */ +tinfl_decompressor *tinfl_decompressor_alloc(void); +void tinfl_decompressor_free(tinfl_decompressor *pDecomp); +#endif + +/* Max size of LZ dictionary. */ +#define TINFL_LZ_DICT_SIZE 32768 + +/* Return status. */ +typedef enum { + /* This flags indicates the inflator needs 1 or more input bytes to make forward progress, but the caller is indicating that no more are available. The compressed data */ + /* is probably corrupted. If you call the inflator again with more bytes it'll try to continue processing the input but this is a BAD sign (either the data is corrupted or you called it incorrectly). */ + /* If you call it again with no input you'll just get TINFL_STATUS_FAILED_CANNOT_MAKE_PROGRESS again. */ + TINFL_STATUS_FAILED_CANNOT_MAKE_PROGRESS = -4, + + /* This flag indicates that one or more of the input parameters was obviously bogus. (You can try calling it again, but if you get this error the calling code is wrong.) */ + TINFL_STATUS_BAD_PARAM = -3, + + /* This flags indicate the inflator is finished but the adler32 check of the uncompressed data didn't match. If you call it again it'll return TINFL_STATUS_DONE. */ + TINFL_STATUS_ADLER32_MISMATCH = -2, + + /* This flags indicate the inflator has somehow failed (bad code, corrupted input, etc.). If you call it again without resetting via tinfl_init() it it'll just keep on returning the same status failure code. */ + TINFL_STATUS_FAILED = -1, + + /* Any status code less than TINFL_STATUS_DONE must indicate a failure. */ + + /* This flag indicates the inflator has returned every byte of uncompressed data that it can, has consumed every byte that it needed, has successfully reached the end of the deflate stream, and */ + /* if zlib headers and adler32 checking enabled that it has successfully checked the uncompressed data's adler32. If you call it again you'll just get TINFL_STATUS_DONE over and over again. */ + TINFL_STATUS_DONE = 0, + + /* This flag indicates the inflator MUST have more input data (even 1 byte) before it can make any more forward progress, or you need to clear the TINFL_FLAG_HAS_MORE_INPUT */ + /* flag on the next call if you don't have any more source data. If the source data was somehow corrupted it's also possible (but unlikely) for the inflator to keep on demanding input to */ + /* proceed, so be sure to properly set the TINFL_FLAG_HAS_MORE_INPUT flag. */ + TINFL_STATUS_NEEDS_MORE_INPUT = 1, + + /* This flag indicates the inflator definitely has 1 or more bytes of uncompressed data available, but it cannot write this data into the output buffer. */ + /* Note if the source compressed data was corrupted it's possible for the inflator to return a lot of uncompressed data to the caller. I've been assuming you know how much uncompressed data to expect */ + /* (either exact or worst case) and will stop calling the inflator and fail after receiving too much. In pure streaming scenarios where you have no idea how many bytes to expect this may not be possible */ + /* so I may need to add some code to address this. */ + TINFL_STATUS_HAS_MORE_OUTPUT = 2 +} tinfl_status; + +/* Initializes the decompressor to its initial state. */ +#define tinfl_init(r) \ + do \ + { \ + (r)->m_state = 0; \ + } \ + MZ_MACRO_END +#define tinfl_get_adler32(r) (r)->m_check_adler32 + +/* Main low-level decompressor coroutine function. This is the only function actually needed for decompression. All the other functions are just high-level helpers for improved usability. */ +/* This is a universal API, i.e. it can be used as a building block to build any desired higher level decompression API. In the limit case, it can be called once per every byte input or output. */ +tinfl_status tinfl_decompress(tinfl_decompressor *r, const mz_uint8 *pIn_buf_next, size_t *pIn_buf_size, mz_uint8 *pOut_buf_start, mz_uint8 *pOut_buf_next, size_t *pOut_buf_size, const mz_uint32 decomp_flags); + +/* Internal/private bits follow. */ +enum +{ + TINFL_MAX_HUFF_TABLES = 3, + TINFL_MAX_HUFF_SYMBOLS_0 = 288, + TINFL_MAX_HUFF_SYMBOLS_1 = 32, + TINFL_MAX_HUFF_SYMBOLS_2 = 19, + TINFL_FAST_LOOKUP_BITS = 10, + TINFL_FAST_LOOKUP_SIZE = 1 << TINFL_FAST_LOOKUP_BITS +}; + +typedef struct +{ + mz_uint8 m_code_size[TINFL_MAX_HUFF_SYMBOLS_0]; + mz_int16 m_look_up[TINFL_FAST_LOOKUP_SIZE], m_tree[TINFL_MAX_HUFF_SYMBOLS_0 * 2]; +} tinfl_huff_table; + +#if MINIZ_HAS_64BIT_REGISTERS +#define TINFL_USE_64BIT_BITBUF 1 +#else +#define TINFL_USE_64BIT_BITBUF 0 +#endif + +#if TINFL_USE_64BIT_BITBUF +typedef mz_uint64 tinfl_bit_buf_t; +#define TINFL_BITBUF_SIZE (64) +#else +typedef mz_uint32 tinfl_bit_buf_t; +#define TINFL_BITBUF_SIZE (32) +#endif + +struct tinfl_decompressor_tag +{ + mz_uint32 m_state, m_num_bits, m_zhdr0, m_zhdr1, m_z_adler32, m_final, m_type, m_check_adler32, m_dist, m_counter, m_num_extra, m_table_sizes[TINFL_MAX_HUFF_TABLES]; + tinfl_bit_buf_t m_bit_buf; + size_t m_dist_from_out_buf_start; + tinfl_huff_table m_tables[TINFL_MAX_HUFF_TABLES]; + mz_uint8 m_raw_header[4], m_len_codes[TINFL_MAX_HUFF_SYMBOLS_0 + TINFL_MAX_HUFF_SYMBOLS_1 + 137]; +}; + + +#pragma once + + +/* ------------------- ZIP archive reading/writing */ + +#ifndef MINIZ_NO_ARCHIVE_APIS + +enum +{ + /* Note: These enums can be reduced as needed to save memory or stack space - they are pretty conservative. */ + MZ_ZIP_MAX_IO_BUF_SIZE = 64 * 1024, + MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE = 512, + MZ_ZIP_MAX_ARCHIVE_FILE_COMMENT_SIZE = 512 +}; + +typedef struct +{ + /* Central directory file index. */ + mz_uint32 m_file_index; + + /* Byte offset of this entry in the archive's central directory. Note we currently only support up to UINT_MAX or less bytes in the central dir. */ + mz_uint64 m_central_dir_ofs; + + /* These fields are copied directly from the zip's central dir. */ + mz_uint16 m_version_made_by; + mz_uint16 m_version_needed; + mz_uint16 m_bit_flag; + mz_uint16 m_method; + +#ifndef MINIZ_NO_TIME + MZ_TIME_T m_time; +#endif + + /* CRC-32 of uncompressed data. */ + mz_uint32 m_crc32; + + /* File's compressed size. */ + mz_uint64 m_comp_size; + + /* File's uncompressed size. Note, I've seen some old archives where directory entries had 512 bytes for their uncompressed sizes, but when you try to unpack them you actually get 0 bytes. */ + mz_uint64 m_uncomp_size; + + /* Zip internal and external file attributes. */ + mz_uint16 m_internal_attr; + mz_uint32 m_external_attr; + + /* Entry's local header file offset in bytes. */ + mz_uint64 m_local_header_ofs; + + /* Size of comment in bytes. */ + mz_uint32 m_comment_size; + + /* MZ_TRUE if the entry appears to be a directory. */ + mz_bool m_is_directory; + + /* MZ_TRUE if the entry uses encryption/strong encryption (which miniz_zip doesn't support) */ + mz_bool m_is_encrypted; + + /* MZ_TRUE if the file is not encrypted, a patch file, and if it uses a compression method we support. */ + mz_bool m_is_supported; + + /* Filename. If string ends in '/' it's a subdirectory entry. */ + /* Guaranteed to be zero terminated, may be truncated to fit. */ + char m_filename[MZ_ZIP_MAX_ARCHIVE_FILENAME_SIZE]; + + /* Comment field. */ + /* Guaranteed to be zero terminated, may be truncated to fit. */ + char m_comment[MZ_ZIP_MAX_ARCHIVE_FILE_COMMENT_SIZE]; + +} mz_zip_archive_file_stat; + +typedef size_t (*mz_file_read_func)(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n); +typedef size_t (*mz_file_write_func)(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, size_t n); +typedef mz_bool (*mz_file_needs_keepalive)(void *pOpaque); + +struct mz_zip_internal_state_tag; +typedef struct mz_zip_internal_state_tag mz_zip_internal_state; + +typedef enum { + MZ_ZIP_MODE_INVALID = 0, + MZ_ZIP_MODE_READING = 1, + MZ_ZIP_MODE_WRITING = 2, + MZ_ZIP_MODE_WRITING_HAS_BEEN_FINALIZED = 3 +} mz_zip_mode; + +typedef enum { + MZ_ZIP_FLAG_CASE_SENSITIVE = 0x0100, + MZ_ZIP_FLAG_IGNORE_PATH = 0x0200, + MZ_ZIP_FLAG_COMPRESSED_DATA = 0x0400, + MZ_ZIP_FLAG_DO_NOT_SORT_CENTRAL_DIRECTORY = 0x0800, + MZ_ZIP_FLAG_VALIDATE_LOCATE_FILE_FLAG = 0x1000, /* if enabled, mz_zip_reader_locate_file() will be called on each file as its validated to ensure the func finds the file in the central dir (intended for testing) */ + MZ_ZIP_FLAG_VALIDATE_HEADERS_ONLY = 0x2000, /* validate the local headers, but don't decompress the entire file and check the crc32 */ + MZ_ZIP_FLAG_WRITE_ZIP64 = 0x4000, /* always use the zip64 file format, instead of the original zip file format with automatic switch to zip64. Use as flags parameter with mz_zip_writer_init*_v2 */ + MZ_ZIP_FLAG_WRITE_ALLOW_READING = 0x8000, + MZ_ZIP_FLAG_ASCII_FILENAME = 0x10000 +} mz_zip_flags; + +typedef enum { + MZ_ZIP_TYPE_INVALID = 0, + MZ_ZIP_TYPE_USER, + MZ_ZIP_TYPE_MEMORY, + MZ_ZIP_TYPE_HEAP, + MZ_ZIP_TYPE_FILE, + MZ_ZIP_TYPE_CFILE, + MZ_ZIP_TOTAL_TYPES +} mz_zip_type; + +/* miniz error codes. Be sure to update mz_zip_get_error_string() if you add or modify this enum. */ +typedef enum { + MZ_ZIP_NO_ERROR = 0, + MZ_ZIP_UNDEFINED_ERROR, + MZ_ZIP_TOO_MANY_FILES, + MZ_ZIP_FILE_TOO_LARGE, + MZ_ZIP_UNSUPPORTED_METHOD, + MZ_ZIP_UNSUPPORTED_ENCRYPTION, + MZ_ZIP_UNSUPPORTED_FEATURE, + MZ_ZIP_FAILED_FINDING_CENTRAL_DIR, + MZ_ZIP_NOT_AN_ARCHIVE, + MZ_ZIP_INVALID_HEADER_OR_CORRUPTED, + MZ_ZIP_UNSUPPORTED_MULTIDISK, + MZ_ZIP_DECOMPRESSION_FAILED, + MZ_ZIP_COMPRESSION_FAILED, + MZ_ZIP_UNEXPECTED_DECOMPRESSED_SIZE, + MZ_ZIP_CRC_CHECK_FAILED, + MZ_ZIP_UNSUPPORTED_CDIR_SIZE, + MZ_ZIP_ALLOC_FAILED, + MZ_ZIP_FILE_OPEN_FAILED, + MZ_ZIP_FILE_CREATE_FAILED, + MZ_ZIP_FILE_WRITE_FAILED, + MZ_ZIP_FILE_READ_FAILED, + MZ_ZIP_FILE_CLOSE_FAILED, + MZ_ZIP_FILE_SEEK_FAILED, + MZ_ZIP_FILE_STAT_FAILED, + MZ_ZIP_INVALID_PARAMETER, + MZ_ZIP_INVALID_FILENAME, + MZ_ZIP_BUF_TOO_SMALL, + MZ_ZIP_INTERNAL_ERROR, + MZ_ZIP_FILE_NOT_FOUND, + MZ_ZIP_ARCHIVE_TOO_LARGE, + MZ_ZIP_VALIDATION_FAILED, + MZ_ZIP_WRITE_CALLBACK_FAILED, + MZ_ZIP_TOTAL_ERRORS +} mz_zip_error; + +typedef struct mz_zip_archive /* note: added name so it can be forward declared */ +{ + mz_uint64 m_archive_size; + mz_uint64 m_central_directory_file_ofs; + + /* We only support up to UINT32_MAX files in zip64 mode. */ + mz_uint32 m_total_files; + mz_zip_mode m_zip_mode; + mz_zip_type m_zip_type; + mz_zip_error m_last_error; + + mz_uint64 m_file_offset_alignment; + + mz_alloc_func m_pAlloc; + mz_free_func m_pFree; + mz_realloc_func m_pRealloc; + void *m_pAlloc_opaque; + + mz_file_read_func m_pRead; + mz_file_write_func m_pWrite; + mz_file_needs_keepalive m_pNeeds_keepalive; + void *m_pIO_opaque; + + mz_zip_internal_state *m_pState; + +} mz_zip_archive; + +typedef struct +{ + mz_zip_archive *pZip; + mz_uint flags; + + int status; +#ifndef MINIZ_DISABLE_ZIP_READER_CRC32_CHECKS + mz_uint file_crc32; +#endif + mz_uint64 read_buf_size, read_buf_ofs, read_buf_avail, comp_remaining, out_buf_ofs, cur_file_ofs; + mz_zip_archive_file_stat file_stat; + void *pRead_buf; + void *pWrite_buf; + + size_t out_blk_remain; + + tinfl_decompressor inflator; + +} mz_zip_reader_extract_iter_state; + +/* -------- ZIP reading */ + +/* Inits a ZIP archive reader. */ +/* These functions read and validate the archive's central directory. */ +mz_bool mz_zip_reader_init(mz_zip_archive *pZip, mz_uint64 size, mz_uint flags); + +mz_bool mz_zip_reader_init_mem(mz_zip_archive *pZip, const void *pMem, size_t size, mz_uint flags); + +#ifndef MINIZ_NO_STDIO +/* Read a archive from a disk file. */ +/* file_start_ofs is the file offset where the archive actually begins, or 0. */ +/* actual_archive_size is the true total size of the archive, which may be smaller than the file's actual size on disk. If zero the entire file is treated as the archive. */ +mz_bool mz_zip_reader_init_file(mz_zip_archive *pZip, const char *pFilename, mz_uint32 flags); +mz_bool mz_zip_reader_init_file_v2(mz_zip_archive *pZip, const char *pFilename, mz_uint flags, mz_uint64 file_start_ofs, mz_uint64 archive_size); + +/* Read an archive from an already opened FILE, beginning at the current file position. */ +/* The archive is assumed to be archive_size bytes long. If archive_size is < 0, then the entire rest of the file is assumed to contain the archive. */ +/* The FILE will NOT be closed when mz_zip_reader_end() is called. */ +mz_bool mz_zip_reader_init_cfile(mz_zip_archive *pZip, MZ_FILE *pFile, mz_uint64 archive_size, mz_uint flags); +#endif + +/* Ends archive reading, freeing all allocations, and closing the input archive file if mz_zip_reader_init_file() was used. */ +mz_bool mz_zip_reader_end(mz_zip_archive *pZip); + +/* -------- ZIP reading or writing */ + +/* Clears a mz_zip_archive struct to all zeros. */ +/* Important: This must be done before passing the struct to any mz_zip functions. */ +void mz_zip_zero_struct(mz_zip_archive *pZip); + +mz_zip_mode mz_zip_get_mode(mz_zip_archive *pZip); +mz_zip_type mz_zip_get_type(mz_zip_archive *pZip); + +/* Returns the total number of files in the archive. */ +mz_uint mz_zip_reader_get_num_files(mz_zip_archive *pZip); + +mz_uint64 mz_zip_get_archive_size(mz_zip_archive *pZip); +mz_uint64 mz_zip_get_archive_file_start_offset(mz_zip_archive *pZip); +MZ_FILE *mz_zip_get_cfile(mz_zip_archive *pZip); + +/* Reads n bytes of raw archive data, starting at file offset file_ofs, to pBuf. */ +size_t mz_zip_read_archive_data(mz_zip_archive *pZip, mz_uint64 file_ofs, void *pBuf, size_t n); + +/* All mz_zip funcs set the m_last_error field in the mz_zip_archive struct. These functions retrieve/manipulate this field. */ +/* Note that the m_last_error functionality is not thread safe. */ +mz_zip_error mz_zip_set_last_error(mz_zip_archive *pZip, mz_zip_error err_num); +mz_zip_error mz_zip_peek_last_error(mz_zip_archive *pZip); +mz_zip_error mz_zip_clear_last_error(mz_zip_archive *pZip); +mz_zip_error mz_zip_get_last_error(mz_zip_archive *pZip); +const char *mz_zip_get_error_string(mz_zip_error mz_err); + +/* MZ_TRUE if the archive file entry is a directory entry. */ +mz_bool mz_zip_reader_is_file_a_directory(mz_zip_archive *pZip, mz_uint file_index); + +/* MZ_TRUE if the file is encrypted/strong encrypted. */ +mz_bool mz_zip_reader_is_file_encrypted(mz_zip_archive *pZip, mz_uint file_index); + +/* MZ_TRUE if the compression method is supported, and the file is not encrypted, and the file is not a compressed patch file. */ +mz_bool mz_zip_reader_is_file_supported(mz_zip_archive *pZip, mz_uint file_index); + +/* Retrieves the filename of an archive file entry. */ +/* Returns the number of bytes written to pFilename, or if filename_buf_size is 0 this function returns the number of bytes needed to fully store the filename. */ +mz_uint mz_zip_reader_get_filename(mz_zip_archive *pZip, mz_uint file_index, char *pFilename, mz_uint filename_buf_size); + +/* Attempts to locates a file in the archive's central directory. */ +/* Valid flags: MZ_ZIP_FLAG_CASE_SENSITIVE, MZ_ZIP_FLAG_IGNORE_PATH */ +/* Returns -1 if the file cannot be found. */ +int mz_zip_reader_locate_file(mz_zip_archive *pZip, const char *pName, const char *pComment, mz_uint flags); +int mz_zip_reader_locate_file_v2(mz_zip_archive *pZip, const char *pName, const char *pComment, mz_uint flags, mz_uint32 *file_index); + +/* Returns detailed information about an archive file entry. */ +mz_bool mz_zip_reader_file_stat(mz_zip_archive *pZip, mz_uint file_index, mz_zip_archive_file_stat *pStat); + +/* MZ_TRUE if the file is in zip64 format. */ +/* A file is considered zip64 if it contained a zip64 end of central directory marker, or if it contained any zip64 extended file information fields in the central directory. */ +mz_bool mz_zip_is_zip64(mz_zip_archive *pZip); + +/* Returns the total central directory size in bytes. */ +/* The current max supported size is <= MZ_UINT32_MAX. */ +size_t mz_zip_get_central_dir_size(mz_zip_archive *pZip); + +/* Extracts a archive file to a memory buffer using no memory allocation. */ +/* There must be at least enough room on the stack to store the inflator's state (~34KB or so). */ +mz_bool mz_zip_reader_extract_to_mem_no_alloc(mz_zip_archive *pZip, mz_uint file_index, void *pBuf, size_t buf_size, mz_uint flags, void *pUser_read_buf, size_t user_read_buf_size); +mz_bool mz_zip_reader_extract_file_to_mem_no_alloc(mz_zip_archive *pZip, const char *pFilename, void *pBuf, size_t buf_size, mz_uint flags, void *pUser_read_buf, size_t user_read_buf_size); + +/* Extracts a archive file to a memory buffer. */ +mz_bool mz_zip_reader_extract_to_mem(mz_zip_archive *pZip, mz_uint file_index, void *pBuf, size_t buf_size, mz_uint flags); +mz_bool mz_zip_reader_extract_file_to_mem(mz_zip_archive *pZip, const char *pFilename, void *pBuf, size_t buf_size, mz_uint flags); + +/* Extracts a archive file to a dynamically allocated heap buffer. */ +/* The memory will be allocated via the mz_zip_archive's alloc/realloc functions. */ +/* Returns NULL and sets the last error on failure. */ +void *mz_zip_reader_extract_to_heap(mz_zip_archive *pZip, mz_uint file_index, size_t *pSize, mz_uint flags); +void *mz_zip_reader_extract_file_to_heap(mz_zip_archive *pZip, const char *pFilename, size_t *pSize, mz_uint flags); + +/* Extracts a archive file using a callback function to output the file's data. */ +mz_bool mz_zip_reader_extract_to_callback(mz_zip_archive *pZip, mz_uint file_index, mz_file_write_func pCallback, void *pOpaque, mz_uint flags); +mz_bool mz_zip_reader_extract_file_to_callback(mz_zip_archive *pZip, const char *pFilename, mz_file_write_func pCallback, void *pOpaque, mz_uint flags); + +/* Extract a file iteratively */ +mz_zip_reader_extract_iter_state* mz_zip_reader_extract_iter_new(mz_zip_archive *pZip, mz_uint file_index, mz_uint flags); +mz_zip_reader_extract_iter_state* mz_zip_reader_extract_file_iter_new(mz_zip_archive *pZip, const char *pFilename, mz_uint flags); +size_t mz_zip_reader_extract_iter_read(mz_zip_reader_extract_iter_state* pState, void* pvBuf, size_t buf_size); +mz_bool mz_zip_reader_extract_iter_free(mz_zip_reader_extract_iter_state* pState); + +#ifndef MINIZ_NO_STDIO +/* Extracts a archive file to a disk file and sets its last accessed and modified times. */ +/* This function only extracts files, not archive directory records. */ +mz_bool mz_zip_reader_extract_to_file(mz_zip_archive *pZip, mz_uint file_index, const char *pDst_filename, mz_uint flags); +mz_bool mz_zip_reader_extract_file_to_file(mz_zip_archive *pZip, const char *pArchive_filename, const char *pDst_filename, mz_uint flags); + +/* Extracts a archive file starting at the current position in the destination FILE stream. */ +mz_bool mz_zip_reader_extract_to_cfile(mz_zip_archive *pZip, mz_uint file_index, MZ_FILE *File, mz_uint flags); +mz_bool mz_zip_reader_extract_file_to_cfile(mz_zip_archive *pZip, const char *pArchive_filename, MZ_FILE *pFile, mz_uint flags); +#endif + +#if 0 +/* TODO */ + typedef void *mz_zip_streaming_extract_state_ptr; + mz_zip_streaming_extract_state_ptr mz_zip_streaming_extract_begin(mz_zip_archive *pZip, mz_uint file_index, mz_uint flags); + uint64_t mz_zip_streaming_extract_get_size(mz_zip_archive *pZip, mz_zip_streaming_extract_state_ptr pState); + uint64_t mz_zip_streaming_extract_get_cur_ofs(mz_zip_archive *pZip, mz_zip_streaming_extract_state_ptr pState); + mz_bool mz_zip_streaming_extract_seek(mz_zip_archive *pZip, mz_zip_streaming_extract_state_ptr pState, uint64_t new_ofs); + size_t mz_zip_streaming_extract_read(mz_zip_archive *pZip, mz_zip_streaming_extract_state_ptr pState, void *pBuf, size_t buf_size); + mz_bool mz_zip_streaming_extract_end(mz_zip_archive *pZip, mz_zip_streaming_extract_state_ptr pState); +#endif + +/* This function compares the archive's local headers, the optional local zip64 extended information block, and the optional descriptor following the compressed data vs. the data in the central directory. */ +/* It also validates that each file can be successfully uncompressed unless the MZ_ZIP_FLAG_VALIDATE_HEADERS_ONLY is specified. */ +mz_bool mz_zip_validate_file(mz_zip_archive *pZip, mz_uint file_index, mz_uint flags); + +/* Validates an entire archive by calling mz_zip_validate_file() on each file. */ +mz_bool mz_zip_validate_archive(mz_zip_archive *pZip, mz_uint flags); + +/* Misc utils/helpers, valid for ZIP reading or writing */ +mz_bool mz_zip_validate_mem_archive(const void *pMem, size_t size, mz_uint flags, mz_zip_error *pErr); +mz_bool mz_zip_validate_file_archive(const char *pFilename, mz_uint flags, mz_zip_error *pErr); + +/* Universal end function - calls either mz_zip_reader_end() or mz_zip_writer_end(). */ +mz_bool mz_zip_end(mz_zip_archive *pZip); + +/* -------- ZIP writing */ + +#ifndef MINIZ_NO_ARCHIVE_WRITING_APIS + +/* Inits a ZIP archive writer. */ +/*Set pZip->m_pWrite (and pZip->m_pIO_opaque) before calling mz_zip_writer_init or mz_zip_writer_init_v2*/ +/*The output is streamable, i.e. file_ofs in mz_file_write_func always increases only by n*/ +mz_bool mz_zip_writer_init(mz_zip_archive *pZip, mz_uint64 existing_size); +mz_bool mz_zip_writer_init_v2(mz_zip_archive *pZip, mz_uint64 existing_size, mz_uint flags); + +mz_bool mz_zip_writer_init_heap(mz_zip_archive *pZip, size_t size_to_reserve_at_beginning, size_t initial_allocation_size); +mz_bool mz_zip_writer_init_heap_v2(mz_zip_archive *pZip, size_t size_to_reserve_at_beginning, size_t initial_allocation_size, mz_uint flags); + +#ifndef MINIZ_NO_STDIO +mz_bool mz_zip_writer_init_file(mz_zip_archive *pZip, const char *pFilename, mz_uint64 size_to_reserve_at_beginning); +mz_bool mz_zip_writer_init_file_v2(mz_zip_archive *pZip, const char *pFilename, mz_uint64 size_to_reserve_at_beginning, mz_uint flags); +mz_bool mz_zip_writer_init_cfile(mz_zip_archive *pZip, MZ_FILE *pFile, mz_uint flags); +#endif + +/* Converts a ZIP archive reader object into a writer object, to allow efficient in-place file appends to occur on an existing archive. */ +/* For archives opened using mz_zip_reader_init_file, pFilename must be the archive's filename so it can be reopened for writing. If the file can't be reopened, mz_zip_reader_end() will be called. */ +/* For archives opened using mz_zip_reader_init_mem, the memory block must be growable using the realloc callback (which defaults to realloc unless you've overridden it). */ +/* Finally, for archives opened using mz_zip_reader_init, the mz_zip_archive's user provided m_pWrite function cannot be NULL. */ +/* Note: In-place archive modification is not recommended unless you know what you're doing, because if execution stops or something goes wrong before */ +/* the archive is finalized the file's central directory will be hosed. */ +mz_bool mz_zip_writer_init_from_reader(mz_zip_archive *pZip, const char *pFilename); +mz_bool mz_zip_writer_init_from_reader_v2(mz_zip_archive *pZip, const char *pFilename, mz_uint flags); + +/* Adds the contents of a memory buffer to an archive. These functions record the current local time into the archive. */ +/* To add a directory entry, call this method with an archive name ending in a forwardslash with an empty buffer. */ +/* level_and_flags - compression level (0-10, see MZ_BEST_SPEED, MZ_BEST_COMPRESSION, etc.) logically OR'd with zero or more mz_zip_flags, or just set to MZ_DEFAULT_COMPRESSION. */ +mz_bool mz_zip_writer_add_mem(mz_zip_archive *pZip, const char *pArchive_name, const void *pBuf, size_t buf_size, mz_uint level_and_flags); + +/* Like mz_zip_writer_add_mem(), except you can specify a file comment field, and optionally supply the function with already compressed data. */ +/* uncomp_size/uncomp_crc32 are only used if the MZ_ZIP_FLAG_COMPRESSED_DATA flag is specified. */ +mz_bool mz_zip_writer_add_mem_ex(mz_zip_archive *pZip, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, + mz_uint64 uncomp_size, mz_uint32 uncomp_crc32); + +mz_bool mz_zip_writer_add_mem_ex_v2(mz_zip_archive *pZip, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, + mz_uint64 uncomp_size, mz_uint32 uncomp_crc32, MZ_TIME_T *last_modified, const char *user_extra_data_local, mz_uint user_extra_data_local_len, + const char *user_extra_data_central, mz_uint user_extra_data_central_len); + +/* Adds the contents of a file to an archive. This function also records the disk file's modified time into the archive. */ +/* File data is supplied via a read callback function. User mz_zip_writer_add_(c)file to add a file directly.*/ +mz_bool mz_zip_writer_add_read_buf_callback(mz_zip_archive *pZip, const char *pArchive_name, mz_file_read_func read_callback, void* callback_opaque, mz_uint64 size_to_add, + const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, const char *user_extra_data_local, mz_uint user_extra_data_local_len, + const char *user_extra_data_central, mz_uint user_extra_data_central_len); + +#ifndef MINIZ_NO_STDIO +/* Adds the contents of a disk file to an archive. This function also records the disk file's modified time into the archive. */ +/* level_and_flags - compression level (0-10, see MZ_BEST_SPEED, MZ_BEST_COMPRESSION, etc.) logically OR'd with zero or more mz_zip_flags, or just set to MZ_DEFAULT_COMPRESSION. */ +mz_bool mz_zip_writer_add_file(mz_zip_archive *pZip, const char *pArchive_name, const char *pSrc_filename, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags); + +/* Like mz_zip_writer_add_file(), except the file data is read from the specified FILE stream. */ +mz_bool mz_zip_writer_add_cfile(mz_zip_archive *pZip, const char *pArchive_name, MZ_FILE *pSrc_file, mz_uint64 size_to_add, + const MZ_TIME_T *pFile_time, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, const char *user_extra_data_local, mz_uint user_extra_data_local_len, + const char *user_extra_data_central, mz_uint user_extra_data_central_len); +#endif + +/* Adds a file to an archive by fully cloning the data from another archive. */ +/* This function fully clones the source file's compressed data (no recompression), along with its full filename, extra data (it may add or modify the zip64 local header extra data field), and the optional descriptor following the compressed data. */ +mz_bool mz_zip_writer_add_from_zip_reader(mz_zip_archive *pZip, mz_zip_archive *pSource_zip, mz_uint src_file_index); + +/* Finalizes the archive by writing the central directory records followed by the end of central directory record. */ +/* After an archive is finalized, the only valid call on the mz_zip_archive struct is mz_zip_writer_end(). */ +/* An archive must be manually finalized by calling this function for it to be valid. */ +mz_bool mz_zip_writer_finalize_archive(mz_zip_archive *pZip); + +/* Finalizes a heap archive, returning a poiner to the heap block and its size. */ +/* The heap block will be allocated using the mz_zip_archive's alloc/realloc callbacks. */ +mz_bool mz_zip_writer_finalize_heap_archive(mz_zip_archive *pZip, void **ppBuf, size_t *pSize); + +/* Ends archive writing, freeing all allocations, and closing the output file if mz_zip_writer_init_file() was used. */ +/* Note for the archive to be valid, it *must* have been finalized before ending (this function will not do it for you). */ +mz_bool mz_zip_writer_end(mz_zip_archive *pZip); + +/* -------- Misc. high-level helper functions: */ + +/* mz_zip_add_mem_to_archive_file_in_place() efficiently (but not atomically) appends a memory blob to a ZIP archive. */ +/* Note this is NOT a fully safe operation. If it crashes or dies in some way your archive can be left in a screwed up state (without a central directory). */ +/* level_and_flags - compression level (0-10, see MZ_BEST_SPEED, MZ_BEST_COMPRESSION, etc.) logically OR'd with zero or more mz_zip_flags, or just set to MZ_DEFAULT_COMPRESSION. */ +/* TODO: Perhaps add an option to leave the existing central dir in place in case the add dies? We could then truncate the file (so the old central dir would be at the end) if something goes wrong. */ +mz_bool mz_zip_add_mem_to_archive_file_in_place(const char *pZip_filename, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags); +mz_bool mz_zip_add_mem_to_archive_file_in_place_v2(const char *pZip_filename, const char *pArchive_name, const void *pBuf, size_t buf_size, const void *pComment, mz_uint16 comment_size, mz_uint level_and_flags, mz_zip_error *pErr); + +/* Reads a single file from an archive into a heap block. */ +/* If pComment is not NULL, only the file with the specified comment will be extracted. */ +/* Returns NULL on failure. */ +void *mz_zip_extract_archive_file_to_heap(const char *pZip_filename, const char *pArchive_name, size_t *pSize, mz_uint flags); +void *mz_zip_extract_archive_file_to_heap_v2(const char *pZip_filename, const char *pArchive_name, const char *pComment, size_t *pSize, mz_uint flags, mz_zip_error *pErr); + +#endif /* #ifndef MINIZ_NO_ARCHIVE_WRITING_APIS */ + +#endif /* MINIZ_NO_ARCHIVE_APIS */ + +} // namespace jittor + +#include "common.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" + +namespace jittor { + +// @pyjt(ZipFile) +struct ZipFile { + std::unique_ptr zip_archive; + char mode; + // @pyjt(__init__) + inline ZipFile(const string& filename, const string& mode="r") { + zip_archive = std::make_unique(); + memset(zip_archive.get(), 0, sizeof(mz_zip_archive)); + if (mode == "r") { + this->mode = 'r'; + if (!mz_zip_reader_init_file(zip_archive.get(), filename.c_str(), 0)) + zip_archive = nullptr; + } else if (mode == "w") { + this->mode = 'w'; + if (!mz_zip_writer_init_file_v2(zip_archive.get(), filename.c_str(), 0, MZ_ZIP_FLAG_WRITE_ZIP64)) { + zip_archive = nullptr; + } + } + if (!zip_archive) + throw std::runtime_error("Failed to open zip file: " + filename); + } + // @pyjt(__dealloc__) + inline ~ZipFile() { + if (zip_archive) { + if (mode == 'w') { + mz_zip_writer_finalize_archive(zip_archive.get()); + mz_zip_writer_end(zip_archive.get()); + } else { + mz_zip_reader_end(zip_archive.get()); + } + } + } + + // @pyjt(valid) + inline int valid() { return !!zip_archive; } + + // @pyjt(list) + inline map list() { + map files; + int n = mz_zip_reader_get_num_files(zip_archive.get()); + for (int i=0; i(); + size_t key = mz_zip_reader_locate_file(zip_archive.get(), filename.c_str(), nullptr, 0); + mz_zip_archive_file_stat stat; + CHECK(mz_zip_reader_file_stat(zip_archive.get(), key, &stat)); + auto var = make_empty({stat.m_uncomp_size >> dtype.dsize_()}, dtype); + auto vh = std::make_unique(var); + void* memptr = (void*)vh->raw_ptr(); + mz_zip_reader_extract_to_mem(zip_archive.get(), key, memptr, stat.m_uncomp_size, 0); + #if HAS_CUDA + if (use_cuda_bk) { + use_cuda = use_cuda_bk; + migrate_to_gpu(vh->var, get_allocator()); + } + #endif + return vh.release(); + } +}; + + +} \ No newline at end of file diff --git a/python/jittor/src/misc/nan_checker.cc b/python/jittor/src/misc/nan_checker.cc new file mode 100644 index 00000000..00fb34aa --- /dev/null +++ b/python/jittor/src/misc/nan_checker.cc @@ -0,0 +1,176 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include "misc/nan_checker.h" +#ifdef IS_CUDA +#include "misc/cuda_flags.h" +#include +#include +#ifndef IS_ROCM +#include +#endif +#include "helper_cuda.h" +#endif +#include "mem/allocator.h" +#include "op.h" + +namespace jittor { + + +#ifdef IS_CUDA +EXTERN_LIB vector check_nan_float16(__half* ptr, int64 num); +#ifndef IS_ROCM +EXTERN_LIB vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num); +#endif +EXTERN_LIB vector check_nan_float32(float32* ptr, int64 num); +EXTERN_LIB vector check_nan_float64(float64* ptr, int64 num); +#endif + +void dump_var(Var* v, string name) { + std::stringstream ss; + ss << name << v->id << v->dtype() << v->shape << ".bin"; + name = ss.str(); + LOGe << "dump" << v << "to" << name; + char* buffer = new char[v->size]; + #ifdef IS_ROCM + hipMemcpy(buffer, v->mem_ptr, v->size, hipMemcpyDefault); + #elif IS_CUDA + cudaMemcpy(buffer, v->mem_ptr, v->size, cudaMemcpyDefault); + #else + std::memcpy(buffer, v->mem_ptr, v->size); + #endif + std::fstream file(name, std::ios::out | std::ios::binary); + file.write(buffer, v->size); + file.close(); + delete[] buffer; +} + + +bool check_nan(Var* v, Op* op) { + if (!v->dtype().is_float() || v->num == 0) return true; + if (v->input() && ( + v->input()->name() == string("empty") || + v->input()->name() == string("setitem"))) + return true; + #ifdef IS_CUDA + if (v->allocator->is_cuda()) { + vector nan_index; + if (v->dtype() == ns_float16) { + nan_index = check_nan_float16((__half*)v->mem_ptr, v->num); + } + #ifndef IS_ROCM + if (v->dtype() == ns_bfloat16) { + nan_index = check_nan_bfloat16((__nv_bfloat16*)v->mem_ptr, v->num); + } + #endif + if (v->dtype() == ns_float32) { + nan_index = check_nan_float32((float32*)v->mem_ptr, v->num); + } else + if (v->dtype() == ns_float64) { + nan_index = check_nan_float64((float64*)v->mem_ptr, v->num); + } + if (nan_index[0]) { + LOGe << "detect nan count:" << nan_index[0]; + + /* dump nan var for analysis + python code for parse dump file: + + import numpy as np + + def load_var(filename): + dtype = "float16" + shape = filename.split('[')[1].split(']')[0] + shape = tuple(int(s) for s in shape.split(',')[:-1]) + with open(filename, 'rb') as f: + array = np.fromfile(f, dtype=dtype) + return array.reshape(shape) + + in0 = load_var("/tmp/input13736float16[4096,11008,].bin") + in1 = load_var("/tmp/input26930float16[32768,11008,].bin") + out0 = load_var("/tmp/output26938float16[32768,4096,].bin") + + */ + if (getenv("DUMP_NAN_INPUT") && getenv("DUMP_NAN_INPUT") == string("1")) { + for (Var* v : op->inputs()) + dump_var(v, "/tmp/input"); + for (Var* v : op->outputs()) + dump_var(v, "/tmp/output"); + } + + for (int i=0; iinputs()) { + icnt ++; + if (index >= input->num) continue; + if (input->dtype() == ns_float16) { + auto* ptr = input->ptr<__half>(); + __half value; + cudaMemcpy(&value, ptr+index, sizeof(__half), cudaMemcpyDeviceToHost); + // LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; + } else + #ifndef IS_ROCM + if (input->dtype() == ns_bfloat16) { + auto* ptr = input->ptr<__nv_bfloat16>(); + __nv_bfloat16 value; + cudaMemcpy(&value, ptr+index, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost); + LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value; + } else + #endif + if (input->dtype() == ns_float32) { + auto* ptr = input->ptr(); + float32 value; + cudaMemcpy(&value, ptr+index, sizeof(float32), cudaMemcpyDeviceToHost); + LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << value; + } else + if (input->dtype() == ns_float64) { + auto* ptr = input->ptr(); + float64 value; + cudaMemcpy(&value, ptr+index, sizeof(float64), cudaMemcpyDeviceToHost); + LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << value; + } + } + LOGf << "detect nan count:" << nan_index[0]; + } + } + ASSERT(cudaDeviceSynchronize()==0) << "detect nan or inf at" << v; + ASSERT(cudaGetLastError() == 0); + } else + #endif + { + if (v->dtype() == ns_float32) { + auto* __restrict__ ptr = v->ptr(); + auto num = v->num; + bool ok = true; + int64 i=0; + for (; idtype() == ns_float64) { + auto* __restrict__ ptr = v->ptr(); + auto num = v->num; + bool ok = true; + int64 i=0; + for (; i +#include + +#include "helper_cuda.h" +#include +//TODO:FIX in ROCM +#ifndef IS_ROCM +#include +#endif + +namespace jittor { + +#define MAX_NAN_REPORT 10 + +inline __device__ void print_nan(float v, int64 i, int* cnt) { + auto x = atomicAdd(cnt, 1); + if (x 60000.f + ) + #endif + print_nan(float(ptr[i]), i, cnt); + } +} +#ifndef IS_ROCM +__global__ void _check_nan_bfloat16(__nv_bfloat16* __restrict__ ptr, int64 num, int* cnt) { + int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x; + if (i 60000.f + ) + #endif + print_nan(float(ptr[i]), i, cnt); + } +} +#endif + +__global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num, int* cnt) { + int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x; + if (i report_nan() { + vector buffer(MAX_NAN_REPORT+1); + auto ptr = check_nan_get_device_ptr(); + cudaMemcpy(buffer.data(), ptr, 4+4*MAX_NAN_REPORT, cudaMemcpyDeviceToHost); + cudaMemset(ptr, 0, 4); + return buffer; +} + +vector check_nan_float64(float64* ptr, int64 num) { + int block_num = std::max((int64)1, (num-1)/1024+1); + int thread_num = std::min((int64)1024, num); + _check_nan_float64<<>>(ptr, num, check_nan_get_device_ptr()); + return report_nan(); +} + +vector check_nan_float32(float32* ptr, int64 num) { + int block_num = std::max((int64)1, (num-1)/1024+1); + int thread_num = std::min((int64)1024, num); + _check_nan_float32<<>>(ptr, num, check_nan_get_device_ptr()); + return report_nan(); +} + +vector check_nan_float16(__half* ptr, int64 num) { + int block_num = std::max((int64)1, (num-1)/1024+1); + int thread_num = std::min((int64)1024, num); + _check_nan_float16<<>>(ptr, num, check_nan_get_device_ptr()); + return report_nan(); +} +#ifndef IS_ROCM +vector check_nan_bfloat16(__nv_bfloat16* ptr, int64 num) { + int block_num = std::max((int64)1, (num-1)/1024+1); + int thread_num = std::min((int64)1024, num); + _check_nan_bfloat16<<>>(ptr, num, check_nan_get_device_ptr()); + return report_nan(); +} +#endif +#endif + +} \ No newline at end of file diff --git a/python/jittor/src/misc/nan_checker.h b/python/jittor/src/misc/nan_checker.h new file mode 100644 index 00000000..e05a4bc2 --- /dev/null +++ b/python/jittor/src/misc/nan_checker.h @@ -0,0 +1,14 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" +#include "var.h" + +namespace jittor { + +bool check_nan(Var* v, Op* op); + +} \ No newline at end of file diff --git a/python/jittor/src/misc/nano_string.cc b/python/jittor/src/misc/nano_string.cc new file mode 100644 index 00000000..6aa041fa --- /dev/null +++ b/python/jittor/src/misc/nano_string.cc @@ -0,0 +1,241 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "misc/nano_string.h" + +namespace jittor { + +#define FOR_ALL_TYPES(m) \ + m(bool) \ + m(int8) \ + m(int16) \ + m(int32) \ + m(int64) \ + m(uint8) \ + m(uint16) \ + m(uint32) \ + m(uint64) \ + m(float32) \ + m(float64) + +#ifdef _MSC_VER +inline int ffs(int i) { + int j=0; + while (i) j++,i/=2; + return j; +} +#define map_size(T) {#T, ffs(sizeof(T))-1}, +#else +#define map_size(T) {#T, __builtin_ffs(sizeof(T))-1}, +#endif + +unordered_map dsize_map = {FOR_ALL_TYPES(map_size)}; + +// TODO: make all static + +#define map_is_float(T) {#T, std::is_floating_point::value}, +static unordered_map is_float_map = {FOR_ALL_TYPES(map_is_float)}; + +#define map_is_unsigned(T) {#T, std::is_unsigned::value}, +static unordered_map is_unsigned = {FOR_ALL_TYPES(map_is_unsigned)}; + +static unordered_set is_bool = { + "bool", + "logical_not", + "less", + "less_equal", + "greater", + "greater_equal", + "equal", + "not_equal", + "logical_and", + "logical_or", + "logical_xor", +}; + +static unordered_set unary_ops = { + "abs", + "negative", + "logical_not", + "bitwise_not", + "log", + "exp", + "sqrt", + "round", + "floor", + "ceil", + "round_int", + "floor_int", + "ceil_int", + "cast", + "sin", + "asin", + "sinh", + "asinh", + "tan", + "atan", + "tanh", + "atanh", + "cos", + "acos", + "cosh", + "acosh", + "sigmoid", + "erf", + "erfinv" +}; + +static unordered_set float_ops = { + "log", + "exp", + "sqrt", + "mean", + "divide", + "sin", + "asin", + "sinh", + "asinh", + "tan", + "atan", + "tanh", + "atanh", + "cos", + "acos", + "cosh", + "acosh", + "sigmoid", + "erf", + "erfinv" +}; +static unordered_set int_ops = { + "round_int", + "floor_int", + "ceil_int", + "floor_divide", +}; + +static unordered_set binary_ops = { + "pow", + "maximum", + "minimum", + "add", + "subtract", + "multiply", + "divide", + "floor_divide", + "mod", + "less", + "less_equal", + "greater", + "greater_equal", + "equal", + "not_equal", + "left_shift", + "right_shift", + "logical_and", + "logical_or", + "logical_xor", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + "mean", +}; + + +static unordered_set white_ops = { + // "log", + "exp", + "pow", +}; + +static unordered_set no_need_back_in = { + "void", + "cast", + "negative", + "add", + "subtract", + "mean", +}; + +static unordered_set no_need_back_out = { + "void", + "cast", + "negative", + "add", + "subtract", + "multiply", + "divide", +}; + +#define DEFINE_NS(T) NanoString ns_##T; +FOR_ALL_NS(DEFINE_NS); + +unordered_map __string_to_ns; +char __ns_to_string[ns_max_size*ns_max_len]; +int __ns_len[ns_max_size]; + +static void init_ns() { + dsize_map["float16"] = 1; + is_float_map["float16"] = 1; + is_unsigned["float16"] = 0; + dsize_map["bfloat16"] = 1; + is_float_map["bfloat16"] = 1; + is_unsigned["bfloat16"] = 0; + NanoString::ns_t i=0; + auto func = [&](const char* name, NanoString& ns) { + ns.set(NanoString::_index, i++, NanoString::_index_nbits); + if (dsize_map.count(name)) { + ns.set(NanoString::_type, NanoString::_dtype, NanoString::_type_nbits); + ns.set(NanoString::_bool, is_bool.count(name)); + ns.set(NanoString::_int, !is_float_map.at(name)); + ns.set(NanoString::_unsigned, is_unsigned.count(name)); + ns.set(NanoString::_float, is_float_map.at(name)); + ns.set(NanoString::_dsize, dsize_map.at(name), NanoString::_dsize_nbits); + } else + if (unary_ops.count(name)) { + ns.set(NanoString::_type, NanoString::_unary, NanoString::_type_nbits); + ns.set(NanoString::_bool, is_bool.count(name)); + ns.set(NanoString::_int, int_ops.count(name)); + ns.set(NanoString::_float, float_ops.count(name)); + } else + if (binary_ops.count(name)) { + ns.set(NanoString::_type, NanoString::_binary, NanoString::_type_nbits); + ns.set(NanoString::_bool, is_bool.count(name)); + ns.set(NanoString::_int, int_ops.count(name)); + ns.set(NanoString::_float, float_ops.count(name)); + } + ns.set(NanoString::_white_list, white_ops.count(name)); + ns.set(NanoString::_no_need_back_in, no_need_back_in.count(name)); + ns.set(NanoString::_no_need_back_out, no_need_back_out.count(name)); + __string_to_ns[name] = ns; + auto name2 = ns.to_cstring(); + int len=0; + for (;;len++) { + name2[len] = name[len]; + if (!name[len]) break; + } + __ns_len[i-1] = len; + }; + #define INIT_NS(T) func(#T, ns_##T); + FOR_ALL_NS(INIT_NS); + ASSERT(i<=(1<. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +constexpr int ns_max_size = 256; +constexpr int ns_max_len = 16; + +#define FOR_ALL_NS(m) \ +\ + m(void) \ + m(bool) \ + m(int8) \ + m(int16) \ + m(int32) \ + m(int64) \ + m(uint8) \ + m(uint16) \ + m(uint32) \ + m(uint64) \ + m(float16) \ + m(float32) \ + m(float64) \ + m(bfloat16) \ +\ + m(pow) \ + m(maximum) \ + m(minimum) \ + m(add) \ + m(subtract) \ + m(multiply) \ + m(divide) \ + m(floor_divide) \ + m(mod) \ + m(less) \ + m(less_equal) \ + m(greater) \ + m(greater_equal) \ + m(equal) \ + m(not_equal) \ + m(left_shift) \ + m(right_shift) \ + m(logical_and) \ + m(logical_or) \ + m(logical_xor) \ + m(bitwise_and) \ + m(bitwise_or) \ + m(bitwise_xor) \ + m(mean) \ +\ + m(abs) \ + m(negative) \ + m(logical_not) \ + m(bitwise_not) \ + m(log) \ + m(exp) \ + m(sqrt) \ + m(round) \ + m(floor) \ + m(ceil) \ + m(round_int) \ + m(floor_int) \ + m(ceil_int) \ + m(cast) \ + \ + m(sin) \ + m(asin) \ + m(sinh) \ + m(asinh) \ + m(tan) \ + m(atan) \ + m(tanh) \ + m(atanh) \ + m(cos) \ + m(acos) \ + m(cosh) \ + m(acosh) \ + m(erf) \ + m(erfinv) \ + m(sigmoid) \ + \ + m(uniform) \ + m(normal) \ + +struct NanoString; +#define DECLEAR_NS(T) EXTERN_LIB NanoString ns_##T; +FOR_ALL_NS(DECLEAR_NS); + + +EXTERN_LIB unordered_map __string_to_ns; +EXTERN_LIB char __ns_to_string[]; +EXTERN_LIB int __ns_len[]; + +// @pyjt(NanoString) +struct NanoString { + typedef uint32 ns_t; + enum Flags { + // bit0~7: index + _index=0, _index_nbits=7, + _n=_index_nbits, + + // bit0-1: type + _type=_n, _type_nbits=2, + _other=0, _dtype=1, _unary=2, _binary=3, + // bit2: is bool + _bool=_n+2, + // bit3: is int + _int=_n+3, + // bit4: is unsigned + _unsigned=_n+4, + // bit5: is float + _float=_n+5, + // bit6-7: dsize(1,2,4,8 byte) + _dsize=_n+6, _dsize_nbits=2, + // bit8: white list + _white_list=_n+8, + // bit9: backward opt + _no_need_back_in=_n+9, + _no_need_back_out=_n+10, + }; + ns_t data=0; + + inline void set(Flags f, ns_t a=1, ns_t nbits=1) { + ns_t mask = (((1u<>f) & ((1u<second.data; + } + // @pyjt(__init__) + inline NanoString(const NanoString& other) : data(other.data) {} + inline NanoString(const string& s) : NanoString(s.c_str()) {} + // @pyjt(__repr__) + inline const char* to_cstring() const + { return __ns_to_string+index()*ns_max_len; } + inline char* to_cstring() + { return __ns_to_string+index()*ns_max_len; } + operator uint32() const { return data; } +}; + +// @pyjt(NanoString.__eq__) +inline bool eq(const NanoString& a, const NanoString& b) { + return a.data == b.data; +} + +// @pyjt(NanoString.__ne__) +inline bool ne(const NanoString& a, const NanoString& b) { + return a.data != b.data; +} + +inline bool operator==(const NanoString& a, const NanoString& b) { + return a.data == b.data; +} +inline bool operator!=(const NanoString& a, const NanoString& b) { + return a.data != b.data; +} + +inline std::ostream& operator<<(std::ostream& os, const NanoString& v) { + return os << v.to_cstring(); +} + +EXTERN_LIB int amp_reg; +constexpr int amp_prefer32 = 1; +constexpr int amp_prefer16 = 2; +constexpr int amp_keep_reduce = 4; +constexpr int amp_keep_white = 8; +constexpr int amp_array_prefer = 16; + +inline NanoString float_dtype(int dsize_, bool has_scalar=false, bool has_bf16=false) { + if (!has_scalar) { + if (amp_reg & amp_prefer32) + return ns_float32; + if (amp_reg & amp_prefer16) + return has_bf16 ? ns_bfloat16 : ns_float16; + } + return (dsize_ == 3) ? ns_float64 : + (dsize_ == 2 ) ? ns_float32 : + has_bf16 ? ns_bfloat16 : ns_float16; +} + +inline NanoString int_dtype(int dsize_) { + return (dsize_ == 3) ? ns_int64 : + (dsize_ == 2) ? ns_int32 : + (dsize_ == 1) ? ns_int16 : ns_int8; +} + +inline NanoString dtype_infer(NanoString x, NanoString y, bool xscalar=false, bool yscalar=false) { + int dsize_ = std::max(x.dsize_(), y.dsize_()); + if (xscalar) dsize_ = y.dsize_(); + if (yscalar) dsize_ = x.dsize_(); + bool is_float = x.is_float() || y.is_float(); + bool has_bf16 = x==ns_bfloat16 || y==ns_bfloat16; + if (is_float) + return float_dtype(dsize_, xscalar||yscalar, has_bf16); + else { + return int_dtype(dsize_); + } +} + +// @pyjt(binary_dtype_infer) +inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y, bool xscalar=false, bool yscalar=false) { + if (op.is_bool()) return ns_bool; + int dsize_ = std::max(x.dsize_(), y.dsize_()); + if (xscalar) dsize_ = y.dsize_(); + if (yscalar) dsize_ = x.dsize_(); + bool is_float = !op.is_int() && + (x.is_float() || y.is_float() || op.is_float()); + bool has_bf16 = x==ns_bfloat16 || y==ns_bfloat16; + if (is_float) { + if (op.is_white() && !(amp_reg & amp_keep_white)) + return (dsize_ == 3) ? ns_float64 : ns_float32; + return float_dtype(dsize_, xscalar||yscalar, has_bf16); + } else { + if (x.is_bool() && y.is_bool()) return ns_bool; + return int_dtype(dsize_); + } +} + +inline NanoString unary_dtype_infer(NanoString op, NanoString x) { + if (op.is_bool()) return ns_bool; + int dsize_ = x.dsize_(); + if (op.is_float()) { + if (op.is_white() && !(amp_reg & amp_keep_white)) + return (dsize_ == 3) ? ns_float64 : ns_float32; + return float_dtype(dsize_, false, x==ns_bfloat16); + } + if (op.is_int()) return int_dtype(dsize_); + return x; +} + +inline NanoString reduce_dtype_infer(NanoString op, NanoString x) { + bool is_float = x.is_float() || op.is_float(); + int dsize_ = x.dsize_(); + if (is_float) { + if (amp_reg & amp_keep_reduce) + return float_dtype(dsize_, false, x==ns_bfloat16); + return (dsize_ == 3) ? ns_float64 : ns_float32; + } else { + return x; + } +} + +} diff --git a/python/jittor/src/misc/nano_vector.h b/python/jittor/src/misc/nano_vector.h new file mode 100644 index 00000000..815856e5 --- /dev/null +++ b/python/jittor/src/misc/nano_vector.h @@ -0,0 +1,306 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "misc/intrin.h" + +namespace jittor { + +struct Slice { + int64 start, stop, step, mask; + inline void fill(int64 size) { + if (step>0) { + if (mask&2) + stop = size; + else if (stop<0) + stop += size; + else + stop = std::min(size, stop); + } else { + if (mask&1) start = size-1; + if (mask&2) + stop = -1; + else if (stop<0) + stop = std::max((int64)0, stop+size); + } + if (start<0) start += size; + mask = 0; + ASSERT(start==stop || (start>=0 && stop>=-1 && start> (64 - lzcnt(b)); +} + +// @pyjt(NanoVector) +struct NanoVector { + int64 data=0, offset=0; + + enum { + size_nbits=4, + offset_nbits=6, + }; + + // @pyjt(__init__) + inline NanoVector() {} + inline NanoVector(std::nullptr_t) {} + // @pyjt(__init__) + inline NanoVector(const NanoVector& nv) : data(nv.data), offset(nv.offset) {} + + inline void clear() { data = offset = 0; } + + // @pyjt(__len__, __map_len__) + inline int size() const { + return offset & ((1<> (size_nbits+i*offset_nbits)) + & ((1<=0 && i> ((64-nbits)&63); + } + + // @pyjt(__map_getitem__) + inline NanoVector slice(Slice slice) { + slice.fill(size()); + NanoVector v; + if (slice.step>0) { + for (int i=slice.start; ioperator[](i)); + } else { + for (int i=slice.start; i>slice.stop; i+=slice.step) + v.push_back(this->operator[](i)); + } + return v; + } + + inline int64 operator[](int i) const { + int pre_offset = i ? get_offset(i-1) : 0; + int next_offset = get_offset(i); + int nbits = next_offset - pre_offset; + return (data << ((64-next_offset)&63)) + >> ((64-nbits)&63); + } + + // @pyjt(__init__) + inline NanoVector(const vector& v) { + for (auto a : v) push_back_check_overflow(a); + } + +#ifdef __linux__ + inline NanoVector(const vector& v) { + for (auto a : v) push_back_check_overflow((int64)a); + } +#endif + + template + inline static NanoVector make(const TMakeV* v, int n) { + NanoVector nv; + for (int i=0; i + NanoVector(Args... args) { + auto f = [&](int64 c) { push_back(c); }; + // Brace-enclosed initializers + int dummy[] = {(f(args), 0)...}; + (void)dummy; + } + + struct Iter { + const NanoVector* self; + int index; + inline int64 operator*() { return self->at(index); } + inline Iter& operator++() { index++; return *this; } + inline Iter operator+(int i) { return {self, i+index}; } + inline bool operator!=(Iter& other) { return index != other.index; } + }; + + inline Iter begin() { return {this, 0}; } + inline Iter end() { return {this, size()}; } + + inline void pop_back() { offset--; data &= (1ll<=get_nbits(v)); + set_data(v, nbits, pre_offset); + } + + inline vector to_vector() const { + vector v(size()); + for (int i=0; i + void _unpack(int i, int& x, Args&&... args) { + x = this->operator[](i); + _unpack(i+1, std::forward(args)...); + } + + template + void unpack(Args&&... args) { + _unpack(0, std::forward(args)...); + } +}; + + +// @pyjt(NanoVector.__add__) +inline NanoVector add(NanoVector self, NanoVector other) { + for (int i=0; ipush_back_check_overflow(other[i]); + return self; +} + +inline std::ostream& operator<<(std::ostream& os, const NanoVector& v) { + os << '['; + for (int i=0; i struct hash { +inline std::size_t operator()(jittor::NanoVector const& s) const noexcept { + std::size_t h1 = std::hash{}(s.data); + std::size_t h2 = std::hash{}(s.offset); + return h1 ^ (h2 << 1); +} +}; +} + diff --git a/python/jittor/src/misc/ring_buffer.cc b/python/jittor/src/misc/ring_buffer.cc new file mode 100644 index 00000000..3406058c --- /dev/null +++ b/python/jittor/src/misc/ring_buffer.cc @@ -0,0 +1,114 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#ifndef _WIN32 +#include +#endif +#include "common.h" +#include "misc/ring_buffer.h" + +namespace jittor { + +RingBuffer::RingBuffer(uint64 size, bool multiprocess) : m(multiprocess), cv(multiprocess) { + int i=0; + for (;(1ll<size = size_mask+1; + size_bit = i; + l = r = is_wait = is_stop = 0; + is_multiprocess = multiprocess; +} + +void RingBuffer::stop() { + MutexScope _(m); + is_stop = 1; + cv.notify(); +} + +RingBuffer::~RingBuffer() { + stop(); +} + + +RingBuffer* RingBuffer::make_ring_buffer(uint64 size, bool multiprocess, uint64 buffer, bool init) { + int i=0; + for (;(1ll<size; + auto is_multiprocess = rb->is_multiprocess; + if (init) + rb->~RingBuffer(); + if (is_multiprocess) { + #ifndef _WIN32 + munmap(rb, total_size); + #else + if (!buffer) + free((void*)rb); + // this buffer is not owned by this obj + #endif + (void)total_size; + } else { + free((void*)rb); + } +} + +// test + +JIT_TEST(ring_buffer_benchmark) { + size_t n = 1ll << 20; + size_t size = 1<<15; + // size_t n = 1ll << 30; + // size_t size = 1<<20; + // size_t n = 1ll << 10; + // size_t size = 1<<5; + RingBuffer* rb = RingBuffer::make_ring_buffer(size, 0); + std::thread p([&]() { + for (size_t i=0; ipush_t(i); + } + }); + auto start = std::chrono::high_resolution_clock::now(); + size_t s = 0; + for (size_t i=0; ipop_t(); + s += x; + } + auto finish = std::chrono::high_resolution_clock::now(); + auto tt = std::chrono::duration_cast(finish-start).count(); + p.join(); + expect_error([&]() { rb->push(size+1); }); + RingBuffer::free_ring_buffer(rb); + + LOGi << tt << tt*1.0/n; + LOGi << s << (n*(n-1)/2); + ASSERTop(s,==,(n*(n-1)/2)); + ASSERTop(tt*1.0/n,<=,100); +} + +} diff --git a/python/jittor/src/misc/ring_buffer.h b/python/jittor/src/misc/ring_buffer.h new file mode 100644 index 00000000..c9605c7e --- /dev/null +++ b/python/jittor/src/misc/ring_buffer.h @@ -0,0 +1,255 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#ifdef _MSC_VER +#include +#else +#include +#endif +#include +#include "common.h" + +namespace jittor { + +struct RingBuffer { + +#ifdef _MSC_VER + struct Mutex { + HANDLE handle; + inline Mutex(bool multiprocess=0) { + } + + inline void lock() { + } + + inline void unlock() { + } + inline ~Mutex() { + } + }; + struct MutexScope { + Mutex* m; + inline MutexScope(Mutex& m) : m(&m) { m.lock(); } + inline ~MutexScope() { m->unlock(); } + }; + + struct Cond { + inline Cond(bool multiprocess=0) { + } + + inline void wait(MutexScope& m) { + } + + inline void notify() { + } + }; +#else + struct Mutex { + pthread_mutex_t m; + inline Mutex(bool multiprocess=0) { + pthread_mutexattr_t attr; + pthread_mutexattr_init(&attr); + if (multiprocess) + pthread_mutexattr_setpshared(&attr, PTHREAD_PROCESS_SHARED); + ASSERT(0 == pthread_mutex_init((pthread_mutex_t*)&m, &attr)); + } + + inline ~Mutex() { + pthread_mutex_destroy(&m); + } + + inline void lock() { + pthread_mutex_lock(&m); + } + + inline void unlock() { + pthread_mutex_unlock(&m); + } + }; + struct MutexScope { + Mutex* m; + inline MutexScope(Mutex& m) : m(&m) { m.lock(); } + inline ~MutexScope() { m->unlock(); } + }; + + struct Cond { + pthread_cond_t cv; + inline Cond(bool multiprocess=0) { + pthread_condattr_t attr; + pthread_condattr_init(&attr); + if (multiprocess) + pthread_condattr_setpshared(&attr, PTHREAD_PROCESS_SHARED); + ASSERT(0 == pthread_cond_init((pthread_cond_t*)&cv, &attr)); + } + + inline ~Cond() { + // a dirty hack + // ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination + // cv.__data.__wrefs = 0; + #ifdef __linux__ + cv.__data = {0}; + #endif + pthread_cond_destroy(&cv); + } + + inline void wait(MutexScope& m) { + pthread_cond_wait(&cv, &m.m->m); + } + + inline void notify() { + pthread_cond_signal(&cv); + } + }; +#endif + + uint64 size; + uint64 size_mask; + uint64 size_bit; + volatile uint64 l; + volatile uint64 r; + volatile bool is_wait; + volatile bool is_stop; + bool is_multiprocess; + Mutex m; + Cond cv; + char _ptr; + + RingBuffer(uint64 size, bool multiprocess=false); + ~RingBuffer(); + void stop(); + static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess, uint64 buffer=0, bool init=true); + static void free_ring_buffer(RingBuffer* rb, uint64 buffer=0, bool init=true); + + inline void clear() { l = r = is_stop = 0; } + + inline void wait() { + if (is_stop) { + throw std::runtime_error("stop"); + } + { + MutexScope _(m); + if (is_wait) { + cv.notify(); + is_wait = 0; + } + is_wait = 1; + cv.wait(_); + } + } + + inline void notify() { + MutexScope _(m); + cv.notify(); + is_wait = 0; + } + + inline void push(uint64 size, uint64& __restrict__ offset) { + auto rr = offset; + auto rr_next = rr + size; + auto c1 = rr >> size_bit; + auto c2 = (rr_next-1) >> size_bit; + if (c1 != c2) { + // if cross boundary + rr = c2 << size_bit; + rr_next = rr + size; + } + CHECK(rr_next <= r+this->size) << "Buffer size too small, please increase buffer size. Current size:" + << this->size << "Required size:" << rr_next - r; + while (rr_next > l + this->size) { + wait(); + } + offset = rr_next; + } + + inline void commit_push(uint64 offset) { + r = offset; + if (is_wait) + notify(); + } + + inline void pop(uint64 size, uint64& __restrict__ offset) { + auto ll = offset; + auto ll_next = ll + size; + auto c1 = ll >> size_bit; + auto c2 = (ll_next-1) >> size_bit; + if (c1 != c2) { + // if cross boundary + ll = c2 << size_bit; + ll_next = ll + size; + } + while (ll_next > r) { + ASSERT(size<=this->size); + wait(); + } + offset = ll_next; + } + + inline void commit_pop(uint64 offset) { + l = offset; + if (is_wait) + notify(); + } + + inline uint64 push(uint64 size) { + auto offset = r; + push(size, offset); + return offset; + } + inline uint64 pop(uint64 size) { + auto offset = l; + pop(size, offset); + return offset; + } + + inline char* get_ptr(uint64 size, uint64 offset) { return ((&_ptr)+((offset-size)&size_mask)); } + + template + inline T& get(uint64 offset) { return *(T*)((&_ptr)+((offset-sizeof(T))&size_mask)); } + + template + inline void push_t(const T& data, uint64& __restrict__ offset) { + push(sizeof(T), offset); + get(offset) = data; + } + + template + inline T& pop_t(uint64& __restrict__ offset) { + pop(sizeof(T), offset); + return get(offset); + } + + inline void push_string(const string& data, uint64& __restrict__ offset) { + push_t(data.size(), offset); + push(data.size(), offset); + auto ptr = get_ptr(data.size(), offset); + std::memcpy(ptr, data.c_str(), data.size()); + } + + inline string pop_string(uint64& __restrict__ offset) { + auto size = pop_t(offset); + pop(size, offset); + auto ptr = get_ptr(size, offset); + return string(ptr, size); + } + + template + inline void push_t(const T& data) { + auto offset = push(sizeof(T)); + get(offset) = data; + commit_push(offset); + } + + template + inline T pop_t() { + auto offset = pop(sizeof(T)); + T data = get(offset); + commit_pop(offset); + return data; + } +}; + +} diff --git a/python/jittor/src/misc/stack_vector.h b/python/jittor/src/misc/stack_vector.h new file mode 100644 index 00000000..62280842 --- /dev/null +++ b/python/jittor/src/misc/stack_vector.h @@ -0,0 +1,56 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "misc/nano_vector.h" + +namespace jittor { + +template +struct StackVector { + int n; + T a[N+1]; + inline T& front() { return a[0]; } + inline T& back() { return a[n-1]; } + inline int size() { return n;} + inline T* data() { return a;} + inline StackVector(int n=0) : n(n) {} + + struct Iter { + const StackVector* self; + int index; + inline T operator*() { return self->at(index); } + inline Iter& operator++() { index++; return *this; } + inline Iter operator+(int i) { return {self, i+index}; } + inline bool operator!=(Iter& other) { return index != other.index; } + }; + + inline Iter begin() { return {this, 0}; } + inline Iter end() { return {this, size()}; } + inline T& operator[](int i) { return a[i]; } + + inline void pop_back() { n--; } + inline void push_back(T v) { if (n +inline std::ostream& operator<<(std::ostream& os, const StackVector& v) { + os << '['; + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once + +#if defined(__clang__) +#include +#elif defined(__GNUC__) +#include +#endif + +#include "common.h" + +namespace jittor { + +#if __cplusplus < 201400L || defined(IS_ACL) +using string_view = string; +#elif defined(__clang__) +// using std::string_view; +using string_view = string; +#elif defined(__GNUC__) +using std::experimental::string_view; +#else +using std::string_view; +#endif + +template +struct string_view_map { + typedef typename std::unordered_map umap_t; + typedef typename umap_t::iterator iter_t; + umap_t umap; + vector holder; + + iter_t find(string_view sv) { + return umap.find(sv); + } + + iter_t begin() { return umap.begin(); } + iter_t end() { return umap.end(); } + + const T& at(string_view sv) { return umap.at(sv); } + size_t size() { return umap.size(); } + + T& operator[](string_view sv) { + auto iter = find(sv); + if (iter != end()) return iter->second; + holder.emplace_back(sv); + string_view nsv = holder.back(); + return umap[nsv]; + } +}; + + +} // jittor diff --git a/python/jittor/src/node.h b/python/jittor/src/node.h new file mode 100644 index 00000000..669ae85d --- /dev/null +++ b/python/jittor/src/node.h @@ -0,0 +1,242 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "misc/nano_string.h" +#include "misc/nano_vector.h" +#include "pybind/py_var_tracer.h" + +namespace jittor { + +EXTERN_LIB unordered_map lived_nodes; +EXTERN_LIB unordered_map lived_nodes_id; +EXTERN_LIB int64 total_node; +EXTERN_LIB int64 nt; +EXTERN_LIB vector free_buffer; +EXTERN_LIB uint8 node_order; + +inline static Node* get_node(int64 id) +{ return lived_nodes_id.count(id) ? lived_nodes_id[id] : nullptr; } + +struct NodeFlags { + typedef uint32 nf_t; + nf_t flags=0; + enum Flags { + // bit0: is_var + _var=0, + // bit1: state + _finished=1, + // bit2: stop grad + _stop_grad=2, + // bit3: is fetch + _fetch=3, + // bit4: node order low + _node_order_low=4, + _node_order_high=5, + _n=6, + + // var related flags + _force_fuse=_n+0, + _stop_fuse=_n+1, + _needed_by_backward=_n+3, + _out_hint=_n+4, + _th_require_grad=_n+5, + _is_scalar=_n+5, + _is_swapped=_n+6, + + // op related flags + // bit0: support cpu + _cpu=_n+0, + // bit1: support cuda + _cuda=_n+1, + // bit2: forward op + _forwarded=_n+2, + // bit3: vary shape op + _vary_shape=_n+3, + // bit4~5: op type + _op_type=_n+4, _op_type_nbits=2, + // bit6: backprop grad at ones + _grads=_n+6, + // bit7: has graph optimize + _has_gopt=_n+7, + // bit8: has vary input + _has_vary_input=_n+8, + _manual_set_vnbb = _n+9, + // bit9: prefer 32 bit + _prefer_32=_n+10, + // force 16 bit + _prefer_16=_prefer_32+1, + // reduce keep type unchange + _reduce_keep=_prefer_32+2, + _custom_flag = _prefer_32+6, + }; + + inline void set(Flags f, int a=1, int nbits=1) { + nf_t mask = (((1u<>f) & ((1u<::iterator back; + input_t(Node* n) : node(n) {} + operator Node*() { return node; } + operator Op*() { return (Op*)node; } + operator Var*() { return (Var*)node; } + }; + struct output_t { + Node* node; + list::iterator back; + int index; + output_t(Node* n, int i) : node(n), index(i) {} + operator Node*() { return node; } + operator Op*() { return (Op*)node; } + operator Var*() { return (Var*)node; } + operator var_output_t() { return {(Op*)node, index}; } + }; + NodeFlags flags; + NanoString ns; + inline bool is_var() const { return flags.get(NodeFlags::_var); } + inline bool is_stop_grad() const { return flags.get(NodeFlags::_stop_grad); } + inline bool is_finished() const { return flags.get(NodeFlags::_finished); } + // forward_liveness can propergate forward(from input to output) + // f1. var_holder contrib one forward_liveness + // f2. var ptr contrib one forward_liveness + // f3. input(has_grad and f>0) contrib one forward_liveness + int forward_liveness = 0; + // forward_liveness can propergate backward(from output to input) + // b1. var ptr contrib one backward_liveness + // b2. var holder contrib one backward_liveness + // b3. output(b>0) contrib one backward_liveness + int backward_liveness = 0; + // pending liveness can propergate backward(from output to input) + // p1: pending and f>0 and b>0 contrib pending_liveness + // p2: output(p>0 and pending) contrib pending_liveness + int pending_liveness = 0; + inline bool need_free() + { return !pending_liveness && (!forward_liveness || !backward_liveness); } + + int custom_data; + int64 tflag = 0; + int64 id; + list _inputs; + list _outputs; + + int64 order() { + if (flags.get(NodeFlags::_node_order_low)) return 0; + if (flags.get(NodeFlags::_node_order_high)) return 1ll<<60; + return id; + } + + inline Node() { + id = ++total_node; + #ifdef NODE_MEMCHECK + lived_nodes_id[id] = this; + lived_nodes[(void*)this] = id; + #endif + flags.set(NodeFlags::_node_order_low, node_order, 2); + } + inline virtual ~Node() { + #ifdef NODE_MEMCHECK + lived_nodes_id.erase(id); + lived_nodes.erase((void*)this); + #endif + if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this); + } + + inline Var* var() { return (Var*)this; } + inline Op* op() { return (Op*)this; } + inline Node* node() { return this; } + void free(); + // this function is used for debug memory + inline bool exist() const { + #ifdef NODE_MEMCHECK + return lived_nodes.count((void*)this); + #else + return true; + #endif + } + void memcheck_all_exist() const; + // release from counter and memory checker + void __release(); + #define CHECK_NODE_EXIST(node) \ + ASSERT(node->exist()) << "Node("#node")" << (void*)node << "not exist." + #define CHECK_EXIST CHECK_NODE_EXIST(this) + #define CHECK_NODE_EXIST2(a,b) \ + CHECK_NODE_EXIST(a); CHECK_NODE_EXIST(b); + #define CHECK_NODE_EXIST3(a,b,c) \ + CHECK_NODE_EXIST2(a,b); CHECK_NODE_EXIST(c); + + inline Caster inputs() { CHECK_EXIST; return &_inputs; } + inline Caster outputs() { CHECK_EXIST; return &_outputs; } + inline Node* input(uint i) { + CHECK_EXIST; + auto iter = _inputs.begin(); + while (i--) iter++; + return iter->node; + } + inline Node* output(uint i) { + CHECK_EXIST; + auto iter = _outputs.begin(); + while (i--) iter++; + return iter->node; + } + + void release_inputs(); + void set_inputs(list nodes); + void add_inputs(const vector& nodes); + void add_inputs(const vector& nodes); + void release_forward_liveness(); + void own_forward_liveness(); + void release_backward_liveness(); + void own_backward_liveness(); + void release_pending_liveness(); + void own_pending_liveness(); + void release_both_liveness(); + void own_both_liveness(); + void finish_pending_liveness(); + void set_stop_grad(); +}; + +struct SetupFreeBuffer { + +bool outside; +inline SetupFreeBuffer() { + outside = !nt; + if (outside) { + nt = ++tflag_count; + } +} + +inline ~SetupFreeBuffer() { + if (outside) { + for (int i=0; i +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "common.h" +#include "var_holder.h" +#include "ops/array_op.h" + +namespace jittor { + +struct NumpyResult; + +struct NumpyFunc { + typedef NumpyResult R; + std::function callback; + std::function deleter; + std::function inc_ref; + NumpyFunc() = default; + NumpyFunc(NumpyFunc&& other) : callback(other.callback), deleter(other.deleter), inc_ref(other.inc_ref) { + other.callback = nullptr; + other.deleter = nullptr; + other.inc_ref = nullptr; + }; + NumpyFunc(const NumpyFunc& other) : callback(other.callback), deleter(other.deleter), inc_ref(other.inc_ref) { + inc_ref(); + }; + NumpyFunc(std::function&& callback) : callback(move(callback)) {} + NumpyFunc(std::function&& callback, std::function&& deleter) + : callback(move(callback)), deleter(move(deleter)) {}; + NumpyFunc(std::function&& callback, std::function&& deleter, std::function&& inc_ref) + : callback(move(callback)), deleter(move(deleter)), inc_ref(move(inc_ref)) {}; + ~NumpyFunc() { + if (deleter) { + deleter(); + } + } + void operator =(NumpyFunc&& other) { this->~NumpyFunc(); new (this) NumpyFunc(move(other)); } +}; + +struct NumpyResult { + map> varrays; + map ints; + map arrays; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/op.cc b/python/jittor/src/op.cc new file mode 100644 index 00000000..b430149b --- /dev/null +++ b/python/jittor/src/op.cc @@ -0,0 +1,331 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "node.h" +#include "op.h" +#include "var.h" +#include "op_compiler.h" +#include "profiler/profiler.h" +#include "mem/allocator.h" +#include "misc/cuda_flags.h" +#include "pybind/py_var_tracer.h" +#include "executor.h" +#include "var_holder.h" +#include "fused_op.h" + +namespace jittor { + +DECLARE_FLAG(string, cache_path); +// DECLARE_FLAG(uint8, th_mode); +extern uint8 th_mode; + +DEFINE_FLAG(int, try_use_32bit_index, 0, + "If not overflow, try to use 32 bit type as index type."); + +string_view_map jit_ops; +string_view_map jit_key_mapper; + +int64 Op::number_of_lived_ops = 0; + +Op::Op() { + flags.set(NodeFlags::_var, 0); + flags.set(NodeFlags::_cpu, 1); + flags.flags |= ((amp_reg & 63) << NodeFlags::_prefer_32); + number_of_lived_ops++; + if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this); +} + +Op::~Op() { + number_of_lived_ops--; +} + +void Op::forward(Var* input) { + flags.set(NodeFlags::_forwarded); + outputs_holder.emplace_back(input); +} + +VarPtr Op::duplicate() { + return nullptr; +} + +VarPtr Op::grad(Var* out, Var* dout, Var* v, int v_index) { + LOGw << "Grad of" << name() << "return zeros"; + return nullptr; +} + +void Op::grads(Var** douts, VarPtr* dins) { + LOGw << "Grads of" << name() << "return zeros"; +} + +Var* Op::create_output(NanoVector shape, NanoString dtype) { + VarPtr vp(shape, dtype); + Var* output = vp.ptr; + outputs_holder.emplace_back(move(vp)); + return output; +} + +void Op::init() { + infer_shape(); + bool manual_set_vnbb = flags.get(NodeFlags::_manual_set_vnbb) + || _inputs.size()==0 + || (_outputs.size()==1 && _outputs.front().node->is_stop_grad()); + for (Var* v : inputs()) { + if (!manual_set_vnbb) { + v->flags.set(NodeFlags::_needed_by_backward); + } + } + Var* need_sync = nullptr; + for (Var* v : outputs()) { + if (!manual_set_vnbb) + v->flags.set(NodeFlags::_needed_by_backward); + if (v->num < 0) + need_sync = v; + } + if (need_sync) { + exe.run_sync(vector({need_sync}), false); + CHECK(need_sync->num >= 0) << need_sync << "'s shape is error"; + } + if (th_mode) { + bool stop_grad = true; + for (Var* v : inputs()) { + if (!v->is_stop_grad()) { + stop_grad = false; + break; + } + } + if (stop_grad) + for (Var* v : outputs()) { + v->set_stop_grad(); + } + } +} + +void Op::compile_optimize(string& src) {} + +void Op::infer_shape() {} +void Op::run() {} +void Op::jit_prepare(JK& jk) {} +void Op::graph_optimize() {} + +string Op::name_ex() const { + string a=name(); + if (ns.data) { + a += '.'; + a += ns.to_cstring(); + } + return a; +} + +string Op::get_jit_key(JK& jk) { + jk.clear(); + do_jit_prepare(jk); + return jk.to_string(); +} + +vector> Op::get_jit_define() { + return parse_jit_keys(get_jit_key(get_jk())); +} + +string Op::get_hash_name() { + string hash_name; + std::stringstream ss; + JK& jk = get_jk(); + do_prepare(jk); + ss << std::hex << std::hash()(jk.to_string()); + hash_name = ss.str(); + return hash_name; +} + +void Op::do_jit_prepare(JK& jk) { + memcheck_all_exist(); + jk << name(); + auto pre_size = jk.size; + jit_prepare(jk); + if (jk.size == pre_size) { + // not a jit op + bool has_cuda = flags.get(NodeFlags::_cuda); + bool has_cpu = flags.get(NodeFlags::_cpu); + CHECK(has_cuda || has_cpu); + if (has_cuda && has_cpu && !use_cuda) + flags.set(NodeFlags::_cuda, 0); + jk.clear(); + } else { + bool use_int64_t = false; + // TODO: fused op do not have inputs, + // check use_cuda_op from outputs may not be enough + bool use_cuda_op = use_cuda; + for (Var* var : inputs()) { + if (var->num >= std::numeric_limits::max()) + use_int64_t = true; + } + for (Var* var : outputs()) { + if (var->num >= std::numeric_limits::max()) + use_int64_t = true; + } + jk << "«JIT:1"; + if (use_cuda_op && flags.get(NodeFlags::_cuda)) { + jk << "«JIT_cuda:1"; + flags.set(NodeFlags::_cpu, 0); + // TODO: 64bit index in CUDA + // use_int64_t = false; + } else { + if (use_cuda==2) { + if (flags.get(NodeFlags::_cuda)) + LOGf << "Op" << name() >> "'s vars are not allocated in cuda"; + else + LOGf << "Op" << name() << "doesn't have cuda version"; + } + ASSERT(flags.get(NodeFlags::_cpu)) + << "Op" << name() << "doesn't have cpu version"; + jk << "«JIT_cpu:1"; + flags.set(NodeFlags::_cuda, 0); + } + if (try_use_32bit_index) use_int64_t = false; + if (use_int64_t) + jk << "«index_t:int64"; + else + jk << "«index_t:int32"; + } + jk.finilize(); +} + +void Op::do_prepare(JK& jk){ + jk.clear(); + do_jit_prepare(jk); +} + +void Op::do_run_after_prepare(JK& jk) { + if (!jk.empty()) + jit_run(jk); + else + run(); +} + +void Op::do_run() { + JK& jk = get_jk(); + do_prepare(jk); + do_run_after_prepare(jk); +} + +string Op::get_filename_from_jit_key(const string& jit_key, const string& suffix) { + auto iter = jit_key_mapper.find(jit_key); + string s = iter==jit_key_mapper.end() ? jit_key : iter->second; + std::stringstream ss; + if (s.size() > 100) { + ss << s.substr(0, 90) << "...hash_" + << std::hex << std::hash()(s); + } else { + ss << s << "_hash_" << + std::hex << std::hash()(s); + } + s = ss.str(); + for (char& c : s) { + if (!((c>='a' && c<='z') || (c>='A' && c<='Z') || (c>='0' && c<='9'))) + c = '_'; + } + #ifndef _WIN32 + string filename = cache_path + "/jit/"; + #else + string filename = cache_path + "\\jit\\"; + #endif + filename += s; + filename += "_op"; + filename += suffix; + return filename; +} + +// convert xxx.yyy -> xxx +string Op::op_name_to_file_name(const string& s) { + auto pos = s.find('.'); + return pos == string::npos ? s : s.substr(0, pos); +} +// convert xxx_xxx -> XxxXxx +string Op::file_name_to_class_name(const string& s) { + char prev = '_'; + string res; + res.reserve(s.size()); + for (char c : s) { + if (c != '_') { + if (prev == '_') + res += c-'a'+'A'; + else + res += c; + } + prev = c; + } + return res; +} + +void Op::jit_run(JK& jk) { + const char* jit_key = jk.to_cstring(); + auto iter = jit_ops.find(jit_key); + if (iter != jit_ops.end()) { + LOGvvv << "Jit op key found:" << jit_key << "jit op entry:" << (void*)iter->second; + Profiler::record_and_run(iter->second, this, jit_key); + return; + } + LOGvv << "Jit op key not found:" << jit_key; + // compile JIT op + string prev_jit_key = jit_key; + auto op_entry = OpCompiler::do_compile(this); + string new_jit_key = get_jit_key(jk); + jit_ops[new_jit_key] = jit_ops[prev_jit_key] = op_entry; + jit_key_mapper[prev_jit_key] = new_jit_key; + LOGvv << "Get jit op entry:" << (void*)op_entry; + Profiler::record_and_run(op_entry, this, new_jit_key.c_str()); +} + +void Op::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) { + in = out = compute = 0; + for (auto& e : _inputs) { + auto var = e.node->var(); + if (e.back->index<0) continue; + in += var->size; + compute = std::max(compute, (uint64_t)var->num); + } + for (auto& e : _outputs) { + auto var = e.node->var(); + if (e.index<0) continue; + out += var->size; + compute = std::max(compute, (uint64_t)var->num); + } +} + +std::ostream& operator<<(std::ostream& os, const Op* op) { + if (!op) return os << "Op(0)"; + os << "Op(" << op->id + << ':' << op->forward_liveness + << ':' << op->backward_liveness + << ':' << op->pending_liveness + << ":i" << op->_inputs.size() + << ":o" << op->_outputs.size() + << ":s" << op->is_finished() + << ":g" << !op->is_stop_grad() + << "," << op->name_ex(); + if (op->_outputs.size()>1) + os << "->..."; + else if (op->_outputs.size() == 1) { + auto v = (Var*)op->_outputs.front().node; + if (v->name.size()) + os << "->" << v->name; + else + os << "->" << v->id; + } + os << ')'; + if (trace_py_var) { + os << '{'; + print_node_trace(op, os); + os << '}'; + } + if (op->name_ex() == "fused") { + os << ((FusedOp*)op)->ops; + } + return os; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/op.h b/python/jittor/src/op.h new file mode 100644 index 00000000..32d0116f --- /dev/null +++ b/python/jittor/src/op.h @@ -0,0 +1,73 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "node.h" +#include "jit_key.h" +#include "misc/string_view_map.h" + +namespace jittor { + +enum OpType {other=0, element=1, broadcast=2, reduce=3}; +struct Op : Node { + vector outputs_holder; + static int64 number_of_lived_ops; + + inline Caster inputs() { CHECK_EXIST; return &_inputs; } + inline Caster outputs() { CHECK_EXIST; return &_outputs; } + inline Var* input(uint i) { return Node::input(i)->var(); } + inline Var* output(uint i) { return Node::output(i)->var(); } + inline uint type() const { CHECK_EXIST; return flags.get(NodeFlags::_op_type, NodeFlags::_op_type_nbits); } + inline void set_type(OpType t) { CHECK_EXIST; flags.set(NodeFlags::_op_type, t, NodeFlags::_op_type_nbits); } + + Var* create_output(NanoVector shape, NanoString dtype); + void init(); + + // Op::forward should be call in constructor + // A forwarded operator will suicide in after constructor + void forward(Var* input); + static string get_filename_from_jit_key(const string& jit_key, const string& suffix); + static string op_name_to_file_name(const string& s); + static string file_name_to_class_name(const string& s); + Op(); + ~Op(); + + virtual VarPtr grad(Var* out, Var* dout, Var* v, int v_index); + virtual void grads(Var** douts, VarPtr* dins); + virtual void infer_shape(); + virtual void run(); + virtual void jit_prepare(JK& jk); + virtual void do_jit_prepare(JK& jk); + virtual const char* name() const = 0; + virtual void statistics(uint64_t& in, uint64_t& out, uint64_t& compute); + virtual void do_prepare(JK& jk); + virtual void do_run_after_prepare(JK& jk); + virtual void do_run(); + virtual VarPtr duplicate(); + virtual void compile_optimize(string& src); + virtual void graph_optimize(); + void jit_run(JK& jk); + + string name_ex() const; + string get_jit_key(JK& jk); + vector> get_jit_define(); + string get_hash_name(); +}; + +std::ostream& operator<<(std::ostream& os, const Op* var); + +EXTERN_LIB string_view_map jit_ops; +// jit_key_mapper: map origin jit_key -> tuned jit_key +EXTERN_LIB string_view_map jit_key_mapper; + +#ifdef JIT + #define DECLARE_jit_run void jit_run(); +#else + #define DECLARE_jit_run void jit_prepare(JK& jk) override; +#endif + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/op_compiler.cc b/python/jittor/src/op_compiler.cc new file mode 100644 index 00000000..f2f15a9f --- /dev/null +++ b/python/jittor/src/op_compiler.cc @@ -0,0 +1,1102 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "op.h" +#include "fused_op.h" +#include "op_compiler.h" +#include "jit_compiler.h" +#include "utils/cache_compile.h" +#include "opt/tuner_manager.h" +#include "utils/str_utils.h" +#include "ops/op_register.h" +#include "ops/array_op.h" +#include "lock.h" +#include "opt/expr.h" +#include "pyjt/py_caller.h" + +namespace jittor { + +DECLARE_FLAG(string, jittor_path); + +using namespace jit_compiler; + +static bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; } + +void OpCompiler::get_op_var_by_name(const string& name, uint& op_id, uint& opvar_id, Op*& op, Var*& var) { + // name: op{id}_{varname} + ASSERT(name.size()>3 && name[0]=='o' && name[1]=='p'); + uint j=2; + while (j2); + op_id = std::stoi(name.substr(2, j-2)); + ASSERT(op_members.size() > op_id); + bool found = false; + for (opvar_id=0 ;opvar_id < op_members[op_id].size(); opvar_id++) { + if (op_members[op_id][opvar_id] == name) { + found = true; + break; + } + } + op = this->op->ops[op_id]; + ASSERT(found && opvar_id < op->inputs().size() + op->outputs().size()); + if (opvar_id >= op->inputs().size()) { + auto iter = op->outputs().begin(); + for (uint t=op->inputs().size(); tinputs().begin(); + for (uint t=0; tinputs()) { + if (i==var) { + found = 1; + break; + } + var_id++; + } + if (!found) + for (Var* o : op->outputs()) { + if (o==var) { + found = 1; + break; + } + var_id++; + } + ASSERT(found); + ASSERT(this->op); + ASSERT(this->op->context); + auto opid = this->op->context->node_id.at(op); + ASSERT(opid<(int)op_members.size()); + auto& v = op_members[opid]; + ASSERT(var_id < v.size()); + return v[var_id]; +} + +string OpCompiler::get_name_by_op_input(Op* op, uint i) { + return op_members.at(this->op->get_node_id(op)).at(i); +} + +string OpCompiler::get_name_by_op_output(Op* op, uint i) { + return op_members.at(this->op->get_node_id(op)).at(i+op->inputs().size()); +} + +bool OpCompiler::op_exist(Op* op) { + return op_members.at(this->op->get_node_id(op)).size(); +} + +int OpCompiler::total_member_count() { + int member_count=0; + int i = 0; + for (auto& v : op_members) { + // array need a extra local var + if (op->ops[i]->name()==string("array")) + member_count += 1; + if (op->ops[i]->name()==string("safe_clip")) + member_count += 2; + member_count += v.size(); + i += 1; + } + return member_count; +} + +int64 OpCompiler::eval(const string& expr, const unordered_map& vars) { + if (expr.find("@") != string::npos) { + string new_expr; + for (size_t i=0; isecond; + i = k-1; + } + } + } + return eval(new_expr, vars); + } + auto e = expr::make(expr); + e->dfs([&](expr::Expr* s) { + if (s->is_sym()) { + auto iter = vars.find(s->str); + ASSERT(iter!=vars.end()) << "Jit var " << s->str << " not found."; + auto e = expr::make(iter->second); + s->swap(e.get()); + } + }); + e = e->eval(); + ASSERT(e->is(expr::_int)); + return e->as_int(); +} + +void load_macros(const string& src, unordered_map& macros) { + LOGvvvv << "load_macros" << src; + for (size_t i=0; ip ? src.substr(p,q-p) : ""; + auto args = "<"+ (r+1"; + // header body + body = args + body; + auto header = src.substr(k,r-k); + LOGvvvv << "header:" << header << "body:" << body; + macros[header] = body; + i = q; + } + } +} + +string expand_op_search(const vector& args) { + for (auto op_type : op_types) { + string ret = op_type->expand_op(args); + if (ret.size()) + return ret; + } + LOGf << "No expand op pattern found for args:" << args; + return ""; +} + +void expand_macro(const string& macro, const vector& args, string& new_src) { + LOGvvvv << "expand_macro" << macro << "args:" << args; + if (macro.size() == 0 || macro[0] != '<') { + new_src += macro; + return; + } + auto i = macro.find(">"); + ASSERT(i != string::npos); + // body + // j k i + unordered_map args_map; + for (uint j=1, l=0; jsecond]; + } + i = j-1; + continue; + } + new_src += macro[i]; + } +} + +string precompile(unordered_map defs, string src, unordered_map& macros) { + string new_src; + new_src.reserve(src.size()); + // dirty fix windows \r\n change line + for (auto& c : src) + if (c == '\r') c = '\n'; + for (size_t i=0; i=6 && inc.substr(inc.size()-6) == "defs.h") { + LOGvvvv << "Found defs include" << inc; + auto src_path = join(jittor_path, "src"); + src_path = join(src_path, inc); + auto inc_src = read_all(_to_winstr(src_path)); + // load_macros from include src + precompile(defs, inc_src, macros); + // we do not include defs.h + i = l; + continue; + } + } else + if (j-i==7 && src.substr(i,j-i) == "#define") { + load_macros(src.substr(i,l-i), macros); + } else + // #ifdef JITxxx + // #else + // #endif + if (((j-i==6 && src.substr(i,j-i) == "#ifdef") || + (j-i==7 && src.substr(i,j-i) == "#ifndef")) && startswith(src, "JIT", k)) { + bool is_ifndef = j-i==7; + string key = src.substr(k, l-k); + // find pair #endif and #else + int presum = 1; + size_t prev = l+1, ii = prev; + string block, else_block; + while (ii < src.size()) { + if (startswith(src, "#if", ii)) { + presum++; + ii += 3; + continue; + } + if (startswith(src, "#else", ii)) { + auto next_ii = ii+5; + // remove ' ' or '\n' after #else + if (next_ii comma; + vector args; + size_t l = k+1; + if (expr == "for" || expr == "if" || expr == "expand_macro" || + expr == "expand_op" || + expr == "is_def" || expr == "python" || + (k=,4u) << "Jit error: for missing arguments."; + string vi = args[0]; + string vl = args[1]; + string vr = args[2]; + string vs = args[3]; + auto vil = OpCompiler::eval(vl, defs); + auto vir = OpCompiler::eval(vr, defs); + int step = 1; + if (args.size() >= 5) { + step = OpCompiler::eval(vs, defs); + vs = args[4]; + for (int i=5; i> "[" >> vil >> "," >> vir >> "," >> step >> "]"; + int total_step = 0; + for (auto vii=vil; vii!=vir; vii+=step) { + total_step ++; + ASSERT(total_step < 1000) << "Too much step."; + new_defs[vi] = S(vii); + new_src += precompile(new_defs, vs, macros); + } + i = l-1; + continue; + } else + if (expr == "if") { + // syntax: @if(cond, true[, false]) + // ij k l + ASSERT(args.size()>=2u && args.size()<=3u) + << "Jit error: if wrong arguments."; + string vcond = args[0]; + string vtrue = args[1]; + string vfalse = args.size() == 3u ? args[2] : ""; + int cond = OpCompiler::eval(vcond, defs); + new_src += precompile(defs, cond?vtrue:vfalse, macros); + i = l-1; + continue; + } else + if (expr == "is_def") { + ASSERT(args.size()==1) + << "Jit error: is_def wrong arguments."; + string vdef = args[0]; + vdef = precompile(defs, vdef, macros); + if (defs.count(vdef) || macros.count(vdef)) + new_src += "1"; + else + new_src += "0"; + i = l-1; + continue; + } else + if (expr == "expand_macro") { + // syntax: @expand_macro(macro, args) + // ij k l + for (auto& arg : args) { + uint p=0; + while (psecond, args, ns); + } + new_src += precompile(defs, ns, macros); + i = l-1; + continue; + } else + if (expr == "expand_op") { + // syntax: @expand_op(args) + for (auto& arg : args) { + uint p=0; + while (p=1u) + << "Jit error: define wrong arguments."; + new_src += "#define "; + auto key = precompile(defs, args[0], macros); + string value, src; + new_src += key; + if (args.size()>=2) { + new_src += " "; + string all_args = args[1]; + for (int i=2; ib + // a_type->b_type + // a_dim -> b_dim + // for i in a_dim: + // a_shapei -> b_shapei + // a_stridei -> b_stridei + CHECK(args.size()==2u) + << "Jit error: alias wrong arguments."; + auto key = strip(precompile(defs, args[0], macros)); + auto value = strip(precompile(defs, args[1], macros)); + CHECK(defs.count(value+"_dim")) << '"' >> value >> '"' << "not exsit"; + int dim = std::stoi(defs.at(value+"_dim")); + vector keys = {"", "p", "dim", "type"}; + for (int i=0; i e0_p[i0*e0_stride0+i1*e0_stride1+...] + ASSERT(expr.size()); + + int nid=(int)expr.size(); + while (nid && isdigit(expr[nid-1])) nid--; + string prefix = expr.substr(0, nid); + string suffix = expr.substr(nid); + string dim; + if (expr == "x" && defs.count("XDIM")) { + dim = "XDIM"; + prefix = "x"; + } else + if (prefix == "e") { + // TODO: unify interface + prefix = "extras" + suffix; + dim = "EDIM" + suffix; + } else { + prefix = expr+"_"; + dim = prefix + "dim"; + } + CHECK(macros.count(dim)) << expr << "not exsit" << macros; + CHECKop(macros.at(dim),==,S(args.size())) << expr << "dimension not matched"; + std::stringstream ss; + ss << prefix << "p["; + for (uint ii=0; iisecond, macros); + i = k-1; + continue; + } else if (src[j]=='@') { + // seperater syntex: @@ + i++; + continue; + } else + LOGf << "Jit error: Invalid syntax."; + } else + new_src += src[i]; + } catch (std::exception& e) { + int il = i, ir = i; + while (il>0 && src[il-1] != '\n') il--; + while (ir+1> "\nJit compiler error:\n" >> this_line; + } + } + return new_src; +} + +string OpCompiler::precompile(const unordered_map& defs, const string& src) { + unordered_map macros = defs; + return jittor::precompile(defs, src, macros); +} + +string OpCompiler::get_jit_src(Op* op) { + string name = op->name(); + string name2 = Op::op_name_to_file_name(name); + string name3 = Op::file_name_to_class_name(name2); + if (name == "fused") { + string src = get_fused_src((FusedOp*)op); + ASSERT(src.size()); + return src; + } + auto op_info = get_op_info(name); + auto& src_path = op_info.source_path; + + string begin_src = "", end_src = ""; + // source that need to be added after the last #include statement + string after_include_src = ""; + auto jit_define = op->get_jit_define(); + for (auto &t : jit_define) { + // don't add CODE in define + // this allowed comment exsit in CODE + if (t.first == "CODE" || t.first == "HEADER") + continue; + string src = "#define " + t.first + " "; + for (char c : t.second) { + if (c=='\n') src += '\\'; + src += c; + } + src += '\n'; + if (startswith(t.first, "JIT")) + begin_src += src; + else + after_include_src += src; + } + ASSERT(file_exist(_to_winstr(src_path))) << src_path; + LOGvvv << "Read from" << src_path; + string src = read_all(_to_winstr(src_path)); + ASSERT(src.size()) << "Source read failed:" << src_path; + + unordered_map defs(jit_define.begin(), jit_define.end()); + LOGvvv << "Precompile with key:" << defs; + src = precompile(defs, src); + + // find the last occur of #include "..."\n + auto pos = src.rfind("#include"); + if (pos == string::npos) pos=0; + else { + // find \n + pos = src.find("\n", pos); + if (pos == string::npos) + pos = src.size(); + else + pos++; + } + + string new_src = begin_src + src.substr(0, pos) + + after_include_src + src.substr(pos) + "\n" + end_src; + return new_src; +} + +string OpCompiler::get_fused_src(FusedOp* op) { + vector op_srcs; + vector relay_switch(op->context->vrm.relay_groups.size()); + for (uint i=0; iloop_options->count(relay_key) && + op->loop_options->at(relay_key) == 1) + relay_switch[i] = 1; + } + auto relay_source = op->context->vrm.get_op_relay_info(relay_switch); + std::set> relayed; + for (uint oi=0; oiops.size(); oi++) { + // relay group id, pair id + auto p = relay_source[oi]; + if (p.first != -1) { + if (relayed.count(p)) { + op_srcs.push_back(""); + continue; + } + relayed.insert(p); + auto src = op->context->vrm.get_relay_src(p.first, p.second); + op_srcs.push_back(src); + // op_srcs.push_back(get_relayed_src(src)); + continue; + } + Op* opi = op->ops[oi]; + string src = get_jit_src(opi); + op_srcs.push_back(move(src)); + } + return OpCompiler::__get_fused_src(op->ops, op_srcs, op_members); +} + +static void fix_op_member( + const vector& ops, + vector>& op_members +) { + // fill op member: [in0, in1, ... inN, fill, fill, out0, out1, ...] + for (int i=0; iinputs().size() + op->outputs().size(); + auto& member = op_members.at(i); + if (!member.size()) { + continue; + } + ASSERT(member.size() <= var_num); + while (member.size() < var_num) { + member.insert(member.end() - op->outputs().size(), "__fill__"); + } + } +} + +string OpCompiler::__get_fused_src( + const vector& ops, + const vector& op_srcs, + vector>& op_members +) { + string fused_begin; + string fused_includes; + string fused_defines; + string fused_kernel_args; + string fused_kernel; + // definitions of fused_begin + map defs; + unordered_set kernel_args; + op_members = vector>(op_srcs.size()); + fused_begin += "#define JIT 1\n"; + defs["JIT"] = "1"; + const string pattern = "::jit_run() {"; + // TODO: better check member + const unordered_set members = { + "x", "y", "z", "cond", "output", "extras" + }; + const unordered_set scalar_members = { + "left", "right" + }; + const unordered_set unchanged = { + "for", "const", "auto", "get_random_engine", + "int", "float", "bool", "CHECK", "STRINGIZE", + "void", "__restrict__", "if", "true", "false", + "Op", "Var", "Node", "itof", "assert", "ASSERT", + "float64" + }; + auto not_change = [&](const string& s) -> bool { + if (unchanged.count(s)) return true; + for (auto op_type : op_types) + if (op_type->types.count(s)) + return true; + return (s.find("::") != string::npos) || (s.find("LOG") != string::npos); + }; + // regex find XxxXxxOp::jit_run + std::regex e(R"([^]*\s(\S*)Op::jit_run[^]*)"); + for (uint oi=0; oiname()==string("array")) { + string op_name = "op" + S(oi); + string arg_name = op_name + "_output"; + string argp_name = op_name + "_outputp"; + string T = ((ArrayOp*)ops[oi])->output->dtype().to_cstring(); + + fused_kernel_args += precompile({{"oi",S(oi)}, {"T", T}}, R"( + Var* op@oi@@_output = ((ArrayOp*)(ops[@oi]))->output; + @T op@oi@@_outputv = ((ArrayOp*)(ops[@oi]))->ptr<@T>()[0]; + )"); + + + fused_kernel += precompile({{"oi",S(oi)}, {"T", T}}, R"( + @T* op@oi@@_outputp = op@oi@@_output->ptr<@T>(); + op@oi@@_outputp[0] = op@oi@@_outputv; + )"); + + + + fused_includes += "#include \"ops/array_op.h\"\n"; + op_members[oi].push_back(arg_name); + // auto opi = (ArrayOp*)(ops[i]); + // auto opi_output = opi->output; + // auto* opi_outputp = opi_output->ptr(); + // opi_outputp[0] = ((T*)(opi->buffer.get()))[0]; + continue; + } + std::smatch cm; + std::regex_match(src, cm, e); + ASSERT(cm.size()>=2) << src; + string name3 = cm[1]; + for (uint i=0; i') {} else + if (members.count(var) || scalar_members.count(var)) { + bool is_member = members.count(var); + string arg_name = "op" + S(oi) + "_" + var; + if (l" + var; + fused_kernel_args += ";\n"; + kernel_args.insert(arg_name); + if (is_member) + op_members[oi].push_back(arg_name); + } + fused_kernel += arg_name; + j = l-1; + continue; + } else + fused_kernel += "op" + S(oi) + "_"; + for (uint p=j; p> fused_kernel; + + auto fused_src = fused_begin + fused_includes + + "\n#include \n" + + "\n#include \"fused_op.h\"\n" + + fused_defines + '\n' + + "void jittor::FusedOp::jit_run() {\n" + fused_kernel + "\n}\n"; + + // we assume the member name is in lexicographical order + // for (auto& v : op_members) std::sort(v.begin(), v.end()); + + return fused_src; +} + +string OpCompiler::get_src() { + if (op==nullptr) return src; + for (const auto& p : *op->loop_options) + if (startswith(p.first, "relay")) { + // return get jit src if has relay op + return get_jit_src(op); + } + return src; +} + +OpCompiler::OpCompiler(Op* op) { + _op = op; + this->op = op->name()==string("fused") ? (FusedOp*)op : nullptr; + src = get_jit_src(op); +} + +jit_op_entry_t OpCompiler::compile(const string& jit_key, const string& src) { + // add extra flags for custom ops + bool is_cuda = _op->flags.get(NodeFlags::_cuda); + auto op_info = get_op_info(_op->name()); + string extra_flags = op_info.extra_flags; + for (auto v : _op->outputs()) + if (v->loop_options) + for (auto& kv : v->loop_options.data()) { + if (kv.second && startswith(kv.first, "FLAGS:")) + extra_flags += " " + kv.first.substr(6) + " "; + } + return jit_compiler::compile(jit_key, src, is_cuda, extra_flags); +} + +jit_op_entry_t (*do_compile_hook)(Op*) = nullptr; + +jit_op_entry_t do_compile_inner(Op* op) { + OpCompiler oc(op); + string* src = &oc.src; + for (auto op_type : op_types) + op_type->post_pass(&oc); + string src_after_passes; + // if is fused op + if (oc.op) { + TunerManager tm(&oc); + src_after_passes = tm.tune(); + src = &src_after_passes; + } + op->compile_optimize(*src); + auto ret = oc.compile(op->get_jit_key(get_jk()), *src); + return ret; +} + +jit_op_entry_t OpCompiler::do_compile(Op* op) { + jittor::lock_guard lg; + if (do_compile_hook) return do_compile_hook(op); + return do_compile_inner(op); +} + +} diff --git a/python/jittor/src/op_compiler.h b/python/jittor/src/op_compiler.h new file mode 100644 index 00000000..8d6a9013 --- /dev/null +++ b/python/jittor/src/op_compiler.h @@ -0,0 +1,50 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +// @pyjt(op_compiler) +// @attrs(submodule) +struct OpCompiler { + // origin op ptr + Op* _op; + // if _op is a fused_op then op==_op, else op==nullptr + FusedOp* op; + // op source + string src; + // only available when op is fused op + // op_members[i][j] represents i-th op's j-th member + vector> op_members; + + OpCompiler(Op*); + string get_src(); + void get_op_var_by_name(const string& name, uint& op_id, uint& opvar_id, Op*& op, Var*& var); + string get_name_by_op_var(Op* op, Var* var); + string get_name_by_op_input(Op* op, uint i); + string get_name_by_op_output(Op* op, uint i); + // op may be relay and not exist + bool op_exist(Op* op); + int total_member_count(); + + string get_jit_src(Op* op); + string get_fused_src(FusedOp* op); + jit_op_entry_t compile(const string& jit_key, const string& src); + static string __get_fused_src( + const vector& ops, + const vector& op_srcs, + vector>& op_members + ); + // @pyjt(eval) + static int64 eval(const string& expr, const unordered_map& vars); + // @pyjt(precompile) + static string precompile(const unordered_map& defs, const string& src); + static jit_op_entry_t do_compile(Op* op); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/ops/arg_reduce_op.cc b/python/jittor/src/ops/arg_reduce_op.cc new file mode 100644 index 00000000..5f735b96 --- /dev/null +++ b/python/jittor/src/ops/arg_reduce_op.cc @@ -0,0 +1,212 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/arg_reduce_op.h" +#include +#include "executor.h" +#include "misc/cuda_flags.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT + +#ifdef HAS_CUDA +static auto make_array = get_op_info("array") + .get_constructor(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_transpose = get_op_info("transpose") + .get_constructor(); +#endif + +static auto make_index = get_op_info("index") + .get_constructor(); +static auto make_reshape = get_op_info("reshape") + .get_constructor(); +static auto make_reindex_reduce = get_op_info("reindex_reduce") + .get_constructor&&, vector&&, vector&&>(); + +ArgReduceOp::ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims) + : x(x), op(op), dim(dim), keepdims(keepdims) { + if (this->dim == -1) + this->dim = x->shape.size() - 1; + dim = this->dim; + #ifdef HAS_CUDA + if (use_cuda) { + static auto cub_arg_reduce = has_op("cub_arg_reduce") ? + get_op_info("cub_arg_reduce").get_constructor, Var*, Var*, NanoString, bool>() + : nullptr; + if (cub_arg_reduce) { + int dims = x->shape.size(); + vector axes; + axes.reserve(dims); + for (int i = 0; i < dims; ++i) + if (i != dim) + axes.push_back(i); + axes.push_back(dim); + auto tranpose1 = make_transpose(x, axes); + + int m = 1; + for (int i = 0; i < dims - 1; ++i) { + m *= tranpose1->shape[i]; + } + int n = tranpose1->shape[dims - 1]; + auto one = make_array(&n, 1, ns_int32); + auto offsets1 = make_index({m+1}, 0, ns_int32); + auto offsets = make_binary(one, offsets1, ns_multiply); + auto var = cub_arg_reduce(tranpose1, offsets, op, keepdims); + if (keepdims) { + vector axes2; + axes2.reserve(dims); + for (int i = 0; i < dims; ++i) { + if (i == dim) axes2.push_back(dims - 1); + if (i < dims - 1) axes2.push_back(i); + } + auto tranpose2_0 = make_transpose(var[0], axes2); + auto tranpose2_1 = make_transpose(var[1], axes2); + forward(tranpose2_0); + forward(tranpose2_1); + } else { + auto tranpose2_0 = var[0]; + auto tranpose2_1 = var[1]; + forward(tranpose2_0); + forward(tranpose2_1); + } + return; + } + } + #endif + y = create_output(nullptr, ns_int32); + y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); +} +VarPtr ArgReduceOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) { + // Do not have grad to extras input + if (v_index) return nullptr; + vector shape; + shape.reserve(v->shape.size()); + for (int i = 0; i < v->shape.size(); ++i) + if (i == dim) { + shape.push_back(1); + } else { + shape.push_back(v->shape[i]); + } + auto reshape1 = make_reshape(dout, NanoVector(shape)); + auto reshapey = make_reshape(y, shape); + + vector indexes; + vector indexes_; + vector indexes__; + // auto& shape = v->shape; + for (int i = 0; i < shape.size(); ++i) { + if (i == dim) { + indexes.push_back(reshapey); + } else { + indexes.push_back(make_index(shape, i, ns_int32)); + } + indexes_.push_back(indexes.back()); + } + + string temp; + temp.reserve(6+3*shape.size()); // @e0(i0,i1) + temp += "@e0("; + for (uint i=0; ishape, move(indexes__), {}, move(indexes_)); +} + +VarPtr ArgReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return get_grad(out, dout, v, v_index, dim, y); +} + +void ArgReduceOp::infer_shape() { + ASSERTop(dim,>=,0); + ASSERTop(dim,<,(int)x->shape.size()); + NanoVector shape; + for (int i = 0; i < x->shape.size(); ++i) { + if (i == dim) { + if (keepdims) + shape.push_back(1); + } else { + shape.push_back(x->shape[i]); + } + } + if (shape.size() == 0) + shape.push_back(1); + y->set_shape(shape); + y_key->set_shape(shape); +} + +void ArgReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«XDIM=" << JK::hex1(x->shape.size()); + jk << "«YDIM=" << JK::hex1(y->shape.size()); + jk << "«KEEPDIMS:" << (keepdims ? '1' : '0'); + jk << "«DIM=" << JK::hex1(dim); + jk << "«CMP:" << (op==ns_minimum ? "<" : ">"); +} + +#else // JIT +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-variable" +void ArgReduceOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + // define x shape + @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];) + // define x stride + index_t xstride@{XDIM-1} = 1; + @for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) + + // define y shape + @for(i, 0, YDIM, index_t yshape@i = y->shape[@i];) + // define y stride + index_t ystride@{YDIM-1} = 1; + @for(i, YDIM-2, -1, -1, auto ystride@i = ystride@{i+1} * yshape@{i+1};) + + auto* __restrict__ yp = y->ptr(); + auto* __restrict__ y_keyp = y_key->ptr(); + + @for(d, 0, DIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) + @for(d, DIM+1, XDIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) { + auto yid = 0@for(d, 0, DIM, + i@d * ystride@d); + @if(KEEPDIMS, yid += 0 @for(d, DIM + 1, XDIM, + i@d * ystride@d), yid += 0 @for(d, DIM + 1, XDIM, + i@d * ystride@{d-1})); + + auto x0id = 0@for(d, 0, DIM, + i@d * xstride@d); + x0id += 0 @for(d, DIM + 1, XDIM, + i@d * xstride@d); + + y_keyp[yid] = xp[x0id]; + yp[yid] = 0; + + for (index_t i@DIM=0; i@DIM < xshape@DIM; i@DIM++){ + auto xid = @for(d, 0, XDIM, + i@d * xstride@d); + if (xp[xid]@CMP@@y_keyp[yid]) { + y_keyp[yid] = xp[xid]; + yp[yid] = i@DIM; + } + } + } +} +#endif // JIT + +} // jittor diff --git a/python/jittor/src/ops/arg_reduce_op.h b/python/jittor/src/ops/arg_reduce_op.h new file mode 100644 index 00000000..fb1de253 --- /dev/null +++ b/python/jittor/src/ops/arg_reduce_op.h @@ -0,0 +1,57 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct ArgReduceOp : Op { + Var* x, * y, * y_key; + NanoString op; + int dim; + bool keepdims; + + /** + Returns the indices of the maximum / minimum of the input across a dimension. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] op: "max" or "min". + + * [in] dim: int. Specifies which dimension to be reduced. + + * [in] keepdims: bool. Whether the output has ``dim`` retained or not. + + ---------------- + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 3)) + >>> x + jt.Var([[4 2 5] + [6 7 1]], dtype=int32) + >>> jt.arg_reduce(x, 'max', dim=1, keepdims=False) + [jt.Var([2 1], dtype=int32), jt.Var([5 7], dtype=int32)] + >>> jt.arg_reduce(x, 'min', dim=1, keepdims=False) + [jt.Var([1 2], dtype=int32), jt.Var([2 1], dtype=int32)] + */ + // @attrs(multiple_outputs) + ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims); + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + static VarPtr get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y); + void infer_shape() override; + + const char* name() const override { return "arg_reduce"; } + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/src/ops/argsort_op.cc b/python/jittor/src/ops/argsort_op.cc new file mode 100644 index 00000000..4fd440d0 --- /dev/null +++ b/python/jittor/src/ops/argsort_op.cc @@ -0,0 +1,174 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/argsort_op.h" +#include +#include "executor.h" +#include "misc/cuda_flags.h" +#include "ops/op_register.h" +namespace jittor { + +#ifndef JIT + +static auto make_index = get_op_info("index") + .get_constructor(); +static auto make_reindex_reduce = get_op_info("reindex_reduce") + .get_constructor&&, vector&&, vector&&>(); + +#ifdef HAS_CUDA +static auto make_array = get_op_info("array") + .get_constructor(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_transpose = get_op_info("transpose") + .get_constructor(); +#endif + +ArgsortOp::ArgsortOp(Var* x, int dim, bool descending, NanoString dtype) + : x(x), dim(dim), descending(descending) { + if (this->dim == -1) + this->dim = x->shape.size() - 1; + dim = this->dim; + #ifdef HAS_CUDA + if (use_cuda) { + static std::vector(*cub_argsort)(Var*, Var*, Var*, bool, NanoString) = nullptr; + if (!cub_argsort && has_op("cub_argsort")) { + cub_argsort = get_op_info("cub_argsort") + .get_constructor, Var*, Var*, Var*, bool, NanoString>(); + } + if (cub_argsort) { + int dims = x->shape.size(); + vector axes; + axes.reserve(dims); + for (int i = 0; i < dims; ++i) + if (i != dim) + axes.push_back(i); + axes.push_back(dim); + auto tranpose1 = make_transpose(x, axes); + + auto indexes = make_index(tranpose1->shape, dims - 1, ns_int32); + int m = 1; + for (int i = 0; i < dims - 1; ++i) { + m *= tranpose1->shape[i]; + } + int n = tranpose1->shape[dims - 1]; + auto one = make_array(&n, 1, ns_int32); + auto offsets1 = make_index({m+1}, 0, ns_int32); + auto offsets = make_binary(one, offsets1, ns_multiply); + auto var = cub_argsort(tranpose1, indexes, offsets, descending, dtype); + vector axes2; + axes2.reserve(dims); + for (int i = 0; i < dims; ++i) { + if (i == dim) axes2.push_back(dims - 1); + if (i < dims - 1) axes2.push_back(i); + } + auto tranpose2_0 = make_transpose(var[0], axes2); + auto tranpose2_1 = make_transpose(var[1], axes2); + forward(tranpose2_0); + forward(tranpose2_1); + return; + } + } + #endif + y = create_output(nullptr, dtype); + y_key = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); + y->flags.set(NodeFlags::_needed_by_backward); +} + +VarPtr ArgsortOp::get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y) { + // Do not have grad to extras input + if (v_index) return nullptr; + vector indexes; + vector indexes_; + vector indexes__; + auto& shape = v->shape; + for (int i = 0; i < v->shape.size(); ++i) { + if (i == dim) { + indexes.push_back(y); + } else { + indexes.push_back(make_index(v->shape, i, ns_int32)); + } + indexes_.push_back(indexes.back()); + } + + string temp; + temp.reserve(6+3*shape.size()); // @e0(i0,i1) + temp += "@e0("; + for (uint i=0; ishape, move(indexes__), {}, move(indexes_)); +} + +VarPtr ArgsortOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return get_grad(out, dout, v, v_index, dim, y); +} + +void ArgsortOp::infer_shape() { + ASSERTop(dim,>=,0); + ASSERTop(dim,<,(int)x->shape.size()); + y->set_shape(x->shape); + y_key->set_shape(x->shape); +} + +void ArgsortOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«XDIM=" << JK::hex1(x->shape.size()); + jk << "«DIM=" << JK::hex1(dim); + jk << "«CMP:" << (descending ? '>' : '<'); +} + +#else // JIT +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-variable" +void ArgsortOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + // define x shape + @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];) + // define x stride + index_t xstride@{XDIM-1} = 1; + @for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) + + auto* __restrict__ yp = y->ptr(); + auto* __restrict__ y_keyp = y_key->ptr(); + std::vector tempx(xshape@DIM); + std::vector tempy(xshape@DIM); + + @for(d, 0, DIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) + @for(d, DIM+1, XDIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) { + for (index_t i@DIM=0; i@DIM < xshape@DIM; i@DIM++){ + auto xid = @for(d, 0, XDIM, + i@d * xstride@d); + tempx[i@DIM] = xp[xid]; + tempy[i@DIM] = i@DIM; + } + std::sort(tempy.begin(), tempy.end(), [&](Ty i, Ty j) -> bool { return tempx[i]@CMP@@tempx[j]; }); + + for (index_t i@DIM=0; i@DIM < xshape@DIM; i@DIM++){ + auto xid = @for(d, 0, XDIM, + i@d * xstride@d); + yp[xid] = tempy[i@DIM]; + y_keyp[xid] = tempx[tempy[i@DIM]]; + } + } +} +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/argsort_op.h b/python/jittor/src/ops/argsort_op.h new file mode 100644 index 00000000..c82ecad9 --- /dev/null +++ b/python/jittor/src/ops/argsort_op.h @@ -0,0 +1,71 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct ArgsortOp : Op { + Var* x, * y, * y_key; + string cmp; + int dim; + bool descending; + /** + Argsort Operator Perform an indirect sort by given key or compare function. + + x is input, y is output index, satisfy: + + x[y[0]] <= x[y[1]] <= x[y[2]] <= ... <= x[y[n]] + + or + + key(y[0]) <= key(y[1]) <= key(y[2]) <= ... <= key(y[n]) + + or + + compare(y[0], y[1]) && compare(y[1], y[2]) && ... + + * [in] x: input var for sort + + * [in] dim: sort alone which dim + + * [in] descending: the elements are sorted in descending order or not(default False). + + * [in] dtype: type of return indexes + + * [out] index: index have the same size with sorted dim + + * [out] value: sorted value + + + Example:: + + index, value = jt.argsort([11,13,12]) + # return [0 2 1], [11 12 13] + index, value = jt.argsort([11,13,12], descending=True) + # return [1 2 0], [13 12 11] + index, value = jt.argsort([[11,13,12], [12,11,13]]) + # return [[0 2 1],[1 0 2]], [[11 12 13],[11 12 13]] + index, value = jt.argsort([[11,13,12], [12,11,13]], dim=0) + # return [[0 1 0],[1 0 1]], [[11 11 12],[12 13 13]] + + */ + // @attrs(multiple_outputs) + ArgsortOp(Var* x, int dim=-1, bool descending=false, NanoString dtype=ns_int32); + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + static VarPtr get_grad(Var* out, Var* dout, Var* v, int v_index, int dim, Var* y); + void infer_shape() override; + + const char* name() const override { return "argsort"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/array_op.cc b/python/jittor/src/ops/array_op.cc new file mode 100644 index 00000000..0841a288 --- /dev/null +++ b/python/jittor/src/ops/array_op.cc @@ -0,0 +1,125 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "mem/allocator.h" +#include "mem/allocator/cuda_dual_allocator.h" +#include "event_queue.h" +#endif +#include +#include +#include "var.h" +#include "ops/array_op.h" +#include "misc/cuda_flags.h" +#include "mem/allocator.h" +#include "mem/swap.h" + +namespace jittor { + +#ifdef HAS_CUDA +#pragma GCC visibility push(hidden) +namespace array_local { +cudaStream_t stream; +cudaEvent_t event; + +struct Init { +Init() { + if (!get_device_count()) return; + checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming)); +} +~Init() { + if (!get_device_count()) return; + peekCudaErrors(cudaDeviceSynchronize()); + peekCudaErrors(cudaStreamDestroy(stream)); + peekCudaErrors(cudaEventDestroy(event)); +} +} init; + +} +using namespace array_local; + +#endif + +ArrayOp::ArrayOp(const void* ptr, NanoVector shape, NanoString dtype) + : ArrayOp(ArrayArgs{ptr, shape, dtype}) {} + +DECLARE_FLAG(int, use_cuda_host_allocator); + +ArrayOp::ArrayOp(ArrayArgs&& args) { + output = create_output(args.shape, args.dtype); + NanoVector shape = output->shape; + if (shape.size() == 1 && shape[0] == 1) { + output->flags.set(NodeFlags::_force_fuse); + output->flags.set(NodeFlags::_is_scalar); + set_type(OpType::element); + } + #ifdef HAS_CUDA + if (use_cuda && !save_mem && !use_cuda_host_allocator) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + if (!output->flags.get(NodeFlags::_force_fuse)) { + // free prev allocation first + event_queue.flush(); + // alloc new allocation + auto size = output->size; + new (&allocation) Allocation(&cuda_dual_allocator, size); + auto host_ptr = cuda_dual_allocator.get_dual_allocation(allocation.allocation).host_ptr; + std::memcpy(host_ptr, args.ptr, output->size); + return; + } + } + #endif + // TODO: args.buffer too many copy + new (&allocation) Allocation(cpu_allocator, output->size); + std::memcpy(allocation.ptr, args.ptr, output->size); +} + +void ArrayOp::jit_prepare(JK& jk) { + if (output->flags.get(NodeFlags::_force_fuse)) { + jk << "«T:" << output->dtype(); + + // fill or find cbuffer for const var pass + if (output->dtype().dsize() == 4) { + auto x = std::abs(ptr()[0]); + auto y = std::abs(ptr()[0]); + auto z = ptr()[0]; + if ((x<=2) || (y==1.0f || y==2.0f)) + jk << "«o:" << z; + } + // end of fill cbuffer + } +} + +void ArrayOp::run() { + #ifdef HAS_CUDA + if (allocation.allocator == &cuda_dual_allocator) { + auto host_ptr = cuda_dual_allocator.get_dual_allocation(allocation.allocation).host_ptr; + checkCudaErrors(cudaMemcpyAsync( + allocation.ptr, host_ptr, allocation.size, cudaMemcpyHostToDevice, stream)); + checkCudaErrors(cudaEventRecord(event, stream)); + checkCudaErrors(cudaStreamWaitEvent(0, event, 0)); + // delay free this allocation + allocation.allocator = &delay_free; + } + #endif + // free prev allocation and move into it + auto o = output; + if (save_mem) + free_with_swap(o); + else + o->allocator->free(o->mem_ptr, o->size, o->allocation); + + o->mem_ptr = allocation.ptr; + allocation.ptr = nullptr; + o->allocator = allocation.allocator; + o->allocation = allocation.allocation; + if (save_mem) registe_swap(o); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/array_op.h b/python/jittor/src/ops/array_op.h new file mode 100644 index 00000000..385de4db --- /dev/null +++ b/python/jittor/src/ops/array_op.h @@ -0,0 +1,40 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" +#include "mem/allocator.h" + +typedef struct _object PyObject; + +namespace jittor { + +struct ArrayArgs { + const void* ptr; + NanoVector shape; + NanoString dtype; + unique_ptr buffer; +}; + +struct ArrayOp : Op { + Var* output; + Allocation allocation; + // @pybind(None) + ArrayOp(const void* ptr, NanoVector shape, NanoString dtype=ns_float32); + + // @pybind(array_) + ArrayOp(ArrayArgs&& args); + + ArrayOp(PyObject* obj); + template + inline T* ptr() { return (T*)allocation.ptr; } + + const char* name() const override { return "array"; } + void run() override; + void jit_prepare(JK& jk) override; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/binary_op.cc b/python/jittor/src/ops/binary_op.cc new file mode 100644 index 00000000..848d40a4 --- /dev/null +++ b/python/jittor/src/ops/binary_op.cc @@ -0,0 +1,573 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT +static auto make_array = get_op_info("array") + .get_constructor(); +static auto make_broadcast_to = get_op_info("broadcast_to") + .get_constructor(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); +static auto make_ternary = get_op_info("ternary") + .get_constructor(); +static auto make_number = get_op_info("number") + .get_constructor(); + +unordered_set binary_ops = { + /** + Computes ``x^y``, element-wise. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(pow, __pow__) + "pow", + + /** + Returns the element-wise maximum of ``x`` and ``y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + "maximum", + + /** + Returns the element-wise minimum of ``x`` and ``y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + "minimum", + + /** + Element-wise adds ``x`` and ``y`` and returns a new Var. + + This operation is equivalent to ``x + y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(add, __add__) + "add", + + /** + Element-wise subtract ``y`` from ``x`` and returns a new Var. + + This operation is equivalent to ``x - y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(subtract, __sub__, sub) + "subtract", + + /** + Element-wise muliplies ``x`` with ``y`` and returns a new Var. + + This operation is equivalent to ``x * y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(multiply, __mul__, mul) + "multiply", + + /** + Element-wise divide ``x`` by ``y`` and returns a new Var. + + This operation is equivalent to ``x / y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.empty((3,), dtype=jt.int32) + >>> a + jt.Var([707406378 707406378 707406378], dtype=int32) + >>> b = jt.empty((3,), dtype=jt.int32) + >>> b + jt.Var([674510453 171649398 538976288], dtype=int32) + >>> jt.divide(a, b) + jt.Var([1.0487701 4.1212287 1.3125001], dtype=float32) + >>> a / b + jt.Var([1.0487701 4.1212287 1.3125001], dtype=float32) + + .. note :: + returns float value even if the dtype of input Vars are both integers. + @see jt.ops.floor_divide() for floor division. + */ + // @pybind(divide, __truediv__, div) + "divide", + + /** + Element-wise divide ``x`` by ``y`` and returns the floor of the result. + + This operation is equivalent to ``x // y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.randint(1, 10, (3,), dtype=jt.int32) + >>> a + jt.Var([9 2 7], dtype=int32) + >>> b = jt.randint(1, 10, (3,), dtype=jt.int32) + >>> b + jt.Var([6 4 6], dtype=int32) + >>> jt.floor_divide(a, b) + jt.Var([1 0 1], dtype=int32) + >>> a // b + jt.Var([1 0 1], dtype=int32) + */ + // @pybind(floor_divide, __floordiv__) + "floor_divide", + + /** + Returns the element-wise remainder of division. + + This operation is equivalent to ``x % y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + ---------------- + + Example-1:: + >>> a = jt.rand(3) + >>> a + jt.Var([0.3989529 0.20159635 0.22973768], dtype=float32) + >>> b = jt.rand(3) + >>> b + jt.Var([0.20121202 0.7704864 0.5654395 ], dtype=float32) + >>> jt.mod(a, b) + jt.Var([0.19774088 0.20159635 0.22973768], dtype=float32) + >>> a % b + jt.Var([0.19774088 0.20159635 0.22973768], dtype=float32) + */ + // @pybind(mod, __mod__) + "mod", + + /** + Returns ``x < y`` element-wise. + + This operation is equivalent to ``x < y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(less, __lt__) + "less", + + /** + Returns ``x <= y`` element-wise. + + This operation is equivalent to ``x <= y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(less_equal, __le__) + "less_equal", + + /** + Returns ``x > y`` element-wise. + + This operation is equivalent to ``x > y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(greater, __gt__) + "greater", + + /** + Returns ``x >= y`` element-wise. + + This operation is equivalent to ``x >= y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(greater_equal, __ge__) + "greater_equal", + + /** + Returns ``x == y`` element-wise. + + This operation is equivalent to ``x == y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(equal, __eq__) + "equal", + + /** + Returns ``x != y`` element-wise. + + This operation is equivalent to ``x != y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var. + + * [in] y: the second input, a python number or jt.Var. + + */ + // @pybind(not_equal, __ne__) + "not_equal", + + /** + Shifts the bits of ``x`` to the left by ``y``. + + Bits are shifted to the left by appending ``y`` 0s at the right of ``x``. + This operation is equivalent to ``x << y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var (int32 or int64). + + * [in] y: the second input, a python number or jt.Var (int32 or int64). + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 10, shape=(3,)) + >>> a + jt.Var([7 6 7], dtype=int32) + >>> b = jt.randint(0, 10, shape=(3,)) + >>> b + jt.Var([3 9 8], dtype=int32) + >>> jt.left_shift(a, b) + jt.Var([ 56 3072 1792], dtype=int32) + >>> a << b + jt.Var([ 56 3072 1792], dtype=int32) + */ + // @pybind(left_shift, __lshift__) + "left_shift", + + /** + Shifts the bits of ``x`` to the right by ``y``. + + This operation is equivalent to ``x >> y``. + + ---------------- + + * [in] x: the first input, a python number or jt.Var (int32 or int64). + + * [in] y: the second input, a python number or jt.Var (int32 or int64). + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 1024, shape=(3,)) + >>> a + jt.Var([439 113 92], dtype=int32) + >>> b = jt.randint(0, 10, shape=(3,)) + >>> b + jt.Var([6 8 4], dtype=int32) + >>> jt.right_shift(a, b) + jt.Var([6 0 5], dtype=int32) + */ + // @pybind(right_shift, __rshift__) + "right_shift", + + /** + Returns the element-wise logical AND of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var. + + */ + "logical_and", + + /** + Returns the element-wise logical OR of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var. + + */ + "logical_or", + + /** + Returns the element-wise logical XOR of the inputs. + + ---------------- + + * [in] x: the first input, jt.Var. + + * [in] y: the second input, jt.Var. + + */ + "logical_xor", + + /** + Computes the bitwise AND of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean). + + */ + // @pybind(bitwise_and, __and__) + "bitwise_and", + + /** + Computes the bitwise OR of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean). + + */ + // @pybind(bitwise_or, __or__) + "bitwise_or", + + /** + Computes the bitwise XOR of x and y. + + ---------------- + + * [in] x: the first input, jt.Var (integal or boolean). + + * [in] y: the second input, jt.Var (integal or boolean). + + */ + // @pybind(bitwise_xor, __xor__) + "bitwise_xor", +}; + +BinaryOp::BinaryOp(Var* x, Var* y, NanoString op) : x(x), y(y) { + auto xdim = x->shape.size(); + auto ydim = y->shape.size(); + bool need_broadcast = xdim != ydim; + for (size_t i=0; ishape[xdim-i-1]; + auto yshape = y->shape[ydim-i-1]; + if ((xshape == 1 || yshape == 1) && (xshape != yshape)) { + need_broadcast = true; + continue; + } + CHECKop(xshape,==,yshape) << "Shape not match, x:" >> x->to_string() + << " y:" >> y->to_string(); + } + if (need_broadcast) { + auto xp = make_broadcast_to(x, y, {}); + auto yp = make_broadcast_to(y, x, {}); + auto zp = make_binary(xp, yp, op); + forward(zp); + return; + } + + #ifdef IS_ACL + if (x->dtype() != y->dtype()) { + auto dtype = binary_dtype_infer(ns_add, x->ns, y->ns, 0, 0); + auto xp = make_unary(x, dtype); + auto yp = make_unary(y, dtype); + auto zp = make_binary(xp, yp, op); + forward(zp); + return; + } + #endif + + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::element); + ns = op; + ASSERT(ns.is_binary()); + z = create_output(x->shape, binary_dtype_infer(op, x->ns, y->ns, x->flags.get(NodeFlags::_is_scalar), y->flags.get(NodeFlags::_is_scalar))); + bool bin = ns.get(NanoString::_no_need_back_in); + bool bout = ns.get(NanoString::_no_need_back_out); + if (bin || bout) { + flags.set(NodeFlags::_manual_set_vnbb); + if (!bin) { + if (!(y->is_stop_grad() && (op==ns_multiply || op==ns_divide))) + x->flags.set(NodeFlags::_needed_by_backward); + if (!(x->is_stop_grad() && (op==ns_multiply))) + y->flags.set(NodeFlags::_needed_by_backward); + } + if (!bout) { + z->flags.set(NodeFlags::_needed_by_backward); + } + } +} + +VarPtr dirty_clone_broadcast(Var* v) { + Op* op = v->input(); + // dirty fix conv duplicated + if (op && !v->is_finished() && v->shape.size() > 4 && op->type() == OpType::broadcast) { + auto vp = op->duplicate(); + if (vp) { + // TODO: loop options should be set to op, rather than var + if (v->loop_options) + vp->loop_options = v->loop_options; + return vp; + } + } + return v; +} + +VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) { + if (ns == ns_add) return dout; + if (ns == ns_subtract) { + if (v_index == 0) + return dout; + else + return make_unary(dout, ns_negative); + } + if (ns == ns_multiply) { + if (v_index == 0) + return make_binary(dirty_clone_broadcast(y), dirty_clone_broadcast(dout), ns_multiply); + else + return make_binary(dirty_clone_broadcast(x), dirty_clone_broadcast(dout), ns_multiply); + } + if (ns == ns_divide) { + if (v_index == 0) + return make_binary(dout, y, ns_divide); + else { + // dy = -dz*x / y^2 + auto ndz = make_unary(dout, ns_negative); + auto ndzx = make_binary(ndz, x, ns_multiply); + auto y2 = make_binary(y, y, ns_multiply); + return make_binary(ndzx, y2, ns_divide); + } + } + if (ns == ns_mod) { + if (v_index == 0) + return dout; + else { + auto a = make_unary(make_binary(x, y, ns_divide), ns_floor); + return make_unary(a, ns_negative); + } + } + if (ns == ns_maximum || ns == ns_minimum) { + auto zeros = make_number(0, dout); + auto cond = make_binary(y, z, ns_equal); + if (v_index==1) + return make_ternary(cond, dout, zeros); + else + return make_ternary(cond, zeros, dout); + } + if (ns == ns_pow) { + if (v_index == 0) { + // dout * y * x^(y-1) + auto d = make_binary(dout, y, ns_multiply); + // auto ones = make_number(1, dout); + int number = 1; + auto ones = make_array(&number, 1, ns_int32); + auto y_1 = make_binary(y, ones, ns_subtract); + auto x_y_1 = make_binary(x, y_1, ns_pow); + return make_binary(d, x_y_1, ns_multiply); + } else { + // dout * x^y * log(x) + auto log_x = make_unary(x, ns_log); + auto x_y_log_x = make_binary(z, log_x, ns_multiply); + return make_binary(dout, x_y_log_x, ns_multiply); + } + } + return nullptr; +} + +void BinaryOp::infer_shape() { +} + +void BinaryOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype() + << "«Ty:" << y->dtype() + << "«Tz:" << z->dtype() + << "«OP:" << ns; +} + +#else // JIT +void BinaryOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + auto* __restrict__ zp = z->ptr(); + index_t num = z->num; + for (index_t i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct BinaryOp : Op { + Var* x, * y, * z; + BinaryOp(Var* x, Var* y, NanoString p); + + const char* name() const override { return "binary"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/broadcast_to_op.cc b/python/jittor/src/ops/broadcast_to_op.cc new file mode 100644 index 00000000..a9b0184a --- /dev/null +++ b/python/jittor/src/ops/broadcast_to_op.cc @@ -0,0 +1,197 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "var.h" +#include "ops/broadcast_to_op.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT +static auto make_reduce = get_op_info("reduce") + .get_constructor(); +static auto make_broadcast = get_op_info("broadcast_to") + .get_constructor(); +static auto make_broadcast2 = get_op_info("broadcast_to") + .get_constructor(); + +BroadcastToOp::BroadcastToOp(Var* x, Var* y, NanoVector dims) : x(x), y(y) { + auto count = dims.size(); + // forward x if don't need broadcast + if (y->num>=0 && !count && !need_broadcast(x, y->shape)) { + forward(x); + return; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + set_type(OpType::broadcast); + z = create_output(NanoVector(), x->dtype()); + bcast_mask = 0; + keepdims_mask = 0; + auto ydim = std::max(x->shape.size(), y->shape.size()-count)+count; + for (auto dim : dims) { + if (dim<0) dim += ydim; + CHECK(dim>=0 && dimnum>=0 && !count && !need_broadcast(x, y->shape)) { + forward(x); + return; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::broadcast); + z = create_output(NanoVector(), x->dtype()); + bcast_mask = dims_mask; + this->keepdims_mask = keepdims_mask; +} + +BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, uint dims_mask, uint keepdims_mask) : x(x), y(nullptr), shape(shape) { + auto count = __builtin_popcount(dims_mask); + if (!count) { + forward(x); + return; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::broadcast); + z = create_output(NanoVector(), x->dtype()); + bcast_mask = dims_mask; + this->keepdims_mask = keepdims_mask; +} + +BroadcastToOp::BroadcastToOp(Var* x, NanoVector shape, NanoVector dims) : x(x), y(nullptr), shape(shape) { + auto count = dims.size(); + // forward x if don't need broadcast + if (!count && !need_broadcast(x, shape)) { + forward(x); + return; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::broadcast); + CHECKop(shape.size(),>,0u) << "Number of shape should greater than 0."; + for (auto v : shape) + CHECKop(v,>=,0u) << "Shape should greater than 0."; + z = create_output(nullptr, x->dtype()); + bcast_mask = 0; + keepdims_mask = 0; + auto ydim = std::max(x->shape.size(), shape.size()-count)+count; + for (auto dim : dims) { + if (dim<0) dim += ydim; + CHECK(dim>=0 && dimshape.size() < shape.size()) return true; + for (uint i=shape.size()-1, j=x->shape.size()-1; ishape[j] != shape[i] && shape[i] != 1)) return true; + return false; +} + +VarPtr BroadcastToOp::duplicate() { + if (y) + return make_broadcast(x, y, bcast_mask, keepdims_mask); + else + return make_broadcast2(x, shape, bcast_mask, keepdims_mask); +} + +VarPtr BroadcastToOp::grad(Var* out, Var* dout, Var* v, int v_index) { + if (v_index==1) return nullptr; + if (bcast_mask==0) return dout; + VarPtr dv = make_reduce(dout, ns_add, bcast_mask, keepdims_mask); + return dv; +} + +void BroadcastToOp::infer_shape() { + if (y && y->num>=0) { + // shape of y is already solved, we can remove deps + LOGvvvv << "Remove broadcast y deps" << y; + shape = y->shape; + set_inputs({x}); + y = nullptr; + } + auto yshapes = y ? y->shape : shape; + auto xdim = x->shape.size(); + auto ydim = yshapes.size(); + auto count = __builtin_popcount(bcast_mask&~keepdims_mask); + auto zdim = std::max(uint64(xdim), uint64(ydim-count)) + count; + + #ifdef _WIN32 + int64 zz[10]; + #else + int64 zz[zdim]; + #endif + for (int i=zdim-1, xi = xdim-1, yi = ydim-1; i>=0; i--) { + bool bx = xi>=0; + bool by = yi>=0; + auto xshape = bx ? x->shape[xi] : 1; + auto yshape = by ? yshapes[yi] : 1; + if (bcast_mask>>i&1) { + yi--; + if (keepdims_mask>>i&1) xi--; + zz[i] = yshape; + continue; + } + auto mask = ((xshape==1 && (yshape!=1 || !bx))&1) << i; + bcast_mask |= mask; + if (bx) keepdims_mask |= mask; + int64 zs; + if ((xshape == 1 || yshape == 1) && (xshape != yshape)) { + zs = xshape * yshape; + } else { + CHECKop(xshape,==,yshape) << "Shape not match" << x->shape << yshapes << bcast_mask; + zs = xshape; + } + zz[i] = zs; + xi--, yi--; + } + + NanoVector zshape; + for (int i=0; iset_shape(zshape); + z->flags.set(NodeFlags::_is_scalar, x->flags.get(NodeFlags::_is_scalar)); + LOGvvv << "Broadcast x(" >> x >> ") shape" << yshapes << "-> z(" >> z >> ")"; +} + +void BroadcastToOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype() + << "«DIM=" << JK::hex1(z->shape.size()) + << "«BCAST=" << JK::hex(bcast_mask); +} + +#else // JIT +void BroadcastToOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ zp = z->ptr(); + // define z shape + @for(i, 0, DIM, index_t zshape@i = z->shape[@i];) + // define z stride + index_t zstride@{DIM-1} = 1; + @for(i, DIM-2, -1, -1, auto zstride@i = zstride@{i+1} * zshape@{i+1};) + // define x stride + index_t xstride@{DIM-1} = 1; + @for(i, DIM-2, -1, -1, auto xstride@i = xstride@{i+1} * @if(BCAST>>(i+1)&1,1,zshape@{i+1});) + // generate d-for loop + @for(d, 0, DIM, for (index_t i@d=0; i@d < zshape@d; i@d++)) { + auto zid = @for(d, 0, DIM, + i@d * zstride@d); + auto xid = @for(d, 0, DIM, + @if(BCAST>>d&1,0,i@d) * xstride@d); + zp[zid] = xp[xid]; + } +} +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/broadcast_to_op.h b/python/jittor/src/ops/broadcast_to_op.h new file mode 100644 index 00000000..4f7c617b --- /dev/null +++ b/python/jittor/src/ops/broadcast_to_op.h @@ -0,0 +1,101 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct BroadcastToOp : Op { + Var* x, * y, * z; + NanoVector shape; + uint16 bcast_mask; + uint16 keepdims_mask; + + /** + Broadcast ``x`` to a given shape. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] shape: the output shape. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> jt.broadcast(x, shape=(2, 3, 2), dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + */ + // @pybind(broadcast) + BroadcastToOp(Var* x, NanoVector shape, NanoVector dims=NanoVector()); + + /** + Broadcast ``x`` to the same shape as ``y``. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] y: the reference jt.Var. + + * [in] dims: specifies the new dimension in the output shape, an integer array. + + ---------------- + + .. note:: + jt.broadcast_var(x, y, dims) is an alias of jt.broadcast(x, y, dims) + + Example-1:: + >>> x = jt.randint(0, 10, shape=(2, 2)) + >>> x + jt.Var([[8 1] + [7 6]], dtype=int32) + >>> y = jt.randint(0, 10, shape=(2, 3, 2)) + >>> jt.broadcast(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + >>> jt.broadcast_var(x, y, dims=[1]) + jt.Var([[[8 1] + [8 1] + [8 1]], + [[7 6] + [7 6] + [7 6]]], dtype=int32) + */ + // @pybind(broadcast,broadcast_var) + BroadcastToOp(Var* x, Var* y, NanoVector dims=NanoVector()); + // @pybind(None) + BroadcastToOp(Var* x, Var* y, uint dims_mask, uint keepdims_mask); + // @pybind(None) + BroadcastToOp(Var* x, NanoVector shape, uint dims_mask, uint keepdims_mask); + + bool need_broadcast(const Var* x, const NanoVector& shape); + + const char* name() const override { return "broadcast_to"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + VarPtr duplicate() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/candidate_op.cc b/python/jittor/src/ops/candidate_op.cc new file mode 100644 index 00000000..e4ade706 --- /dev/null +++ b/python/jittor/src/ops/candidate_op.cc @@ -0,0 +1,133 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/candidate_op.h" +#ifdef JIT_cuda +#include "executor.h" +#endif + +namespace jittor { + +#ifndef JIT +CandidateOp::CandidateOp(Var* x, string&& fail_cond, NanoString dtype) : x(x), fail_cond(move(fail_cond)) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + y = create_output(nullptr, dtype); +} + +void CandidateOp::infer_shape() { + y->set_shape({-x->shape[0]}); +} + +void CandidateOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«FUNC:" << fail_cond; + jk << "«XDIM=" << JK::hex1(x->shape.size()); +} + +#else // JIT + +#ifdef JIT_cuda + +__global__ static void candidate_kernel( + @for(i, 0, XDIM, 1, index_t xshape@i, ) + Tx* __restrict__ xp, + Ty* __restrict__ yp, + bool* __restrict__ maskp, + int* __restrict__ np +) { + int n=0; + int tid = threadIdx.x; + int tnum = blockDim.x; + + // define cond stride + index_t xstride@{XDIM-1} = 1; + @for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) + + // generate d-for loop + for (index_t i=0; i < xshape0; i++) { + __syncthreads(); + if (!maskp[i]) continue; + if (tid == 0) { + yp[n] = i; + n++; + } + for (index_t j=i+1+tid; j < xshape0; j+=tnum) { + if (@FUNC) maskp[j] = 0; + } + } + if (tid == 0) { + np[0] = n; + } +} + + +void CandidateOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + // define cond shape + @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];) + + // define ys + auto* __restrict__ yp = y->ptr(); + size_t n_allocation; + int* np = (int*)exe.temp_allocator->alloc(4, n_allocation); + size_t mask_allocation; + bool* maskp = (bool*)exe.temp_allocator->alloc(xshape0, mask_allocation); + checkCudaErrors(cudaMemsetAsync(maskp, 1, xshape0)); + + candidate_kernel<<<1, std::max(1, std::min(1024, xshape0)) >>>( + @for(i, 0, XDIM, 1, xshape@i, ) + xp, + yp, + maskp, + np + ); + + int n=0; + // checkCudaErrors(cudaDeviceSynchronize()); + checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDeviceToHost)); + y->set_shape({n}); + exe.temp_allocator->free(np, 4, n_allocation); + exe.temp_allocator->free(maskp, xshape0, mask_allocation); +} +#else +void CandidateOp::jit_run() { + using namespace std; + auto* __restrict__ xp = x->ptr(); + // define cond shape + @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];) + // define cond stride + index_t xstride@{XDIM-1} = 1; + @for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) + + // define ys + auto* __restrict__ yp = y->ptr(); + int64 n=0; + + // generate d-for loop + for (index_t i=0; i < xshape0; i++) { + bool pass = true; + for (index_t j_=0; j_ < n; j_++) { + index_t j = yp[j_]; + if (@FUNC) { + pass = false; + break; + } + } + if (pass) { + yp[n] = i; + n++; + } + } + y->set_shape({n}); +} +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/src/ops/candidate_op.h b/python/jittor/src/ops/candidate_op.h new file mode 100644 index 00000000..b6a59351 --- /dev/null +++ b/python/jittor/src/ops/candidate_op.h @@ -0,0 +1,66 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CandidateOp : Op { + Var* x; + Var* y; + string fail_cond; + /** + Candidate Operator Perform an indirect candidate filter by given a fail condition. + + x is input, y is output index, satisfy:: + + not fail_cond(y[0], y[1]) and + not fail_cond(y[0], y[2]) and not fail_cond(y[1], y[2]) and + ... + ... and not fail_cond(y[m-2], y[m-1]) + + Where m is number of selected candidates. + + Pseudo code:: + + y = [] + for i in range(n): + pass = True + for j in y: + if (@fail_cond): + pass = false + break + if (pass): + y.append(i) + return y + + * [in] x: input var for filter + + * [in] fail_cond: code for fail condition + + * [in] dtype: type of return indexes + + * [out] index: . + + Example:: + + jt.candidate(jt.random(100,2), '(@x(j,0)>@x(i,0))or(@x(j,1)>@x(i,1))') + # return y satisfy: + # x[y[0], 0] <= x[y[1], 0] and x[y[1], 0] <= x[y[2], 0] and ... and x[y[m-2], 0] <= x[y[m-1], 0] and + # x[y[0], 1] <= x[y[1], 1] and x[y[1], 1] <= x[y[2], 1] and ... and x[y[m-2], 1] <= x[y[m-1], 1] + */ + CandidateOp(Var* x, string&& fail_cond, NanoString dtype=ns_int32); + void infer_shape() override; + + const char* name() const override { return "candidate"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/clone_op.cc b/python/jittor/src/ops/clone_op.cc new file mode 100644 index 00000000..47866c6c --- /dev/null +++ b/python/jittor/src/ops/clone_op.cc @@ -0,0 +1,47 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/array_op.h" +#include "ops/op_register.h" +#include "ops/clone_op.h" + +namespace jittor { + +static auto make_clone = get_op_info("clone") + .get_constructor(); + +CloneOp::CloneOp(Var* x) : x(x) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + y = create_output(nullptr, x->dtype()); + if (x->name.ptr) + y->name = x->name; +} + +VarPtr CloneOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return make_clone(dout); +} + +void CloneOp::infer_shape() { + y->set_shape(x->shape); + y->share_with(x); +} + +VarPtr detach(Var* x) { + auto y = make_clone(x); + y->input()->set_stop_grad(); + return y; +} + +VarPtr clone(Var* x) { + return make_clone(x); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/clone_op.h b/python/jittor/src/ops/clone_op.h new file mode 100644 index 00000000..c42a5c66 --- /dev/null +++ b/python/jittor/src/ops/clone_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CloneOp : Op { + Var* x, * y; + CloneOp(Var* x); + + const char* name() const override { return "clone"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; +}; + +VarPtr detach(Var* x); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/code_op.cc b/python/jittor/src/ops/code_op.cc new file mode 100644 index 00000000..021b615e --- /dev/null +++ b/python/jittor/src/ops/code_op.cc @@ -0,0 +1,255 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/code_op.h" +#include "ops/op_register.h" +#include "misc/cuda_flags.h" + +#define __inline_static__ inline static + +#ifndef JIT + +namespace jittor { + +static auto make_code = get_op_info("code") + .get_constructor&&, string&&, vector&&, string&&, string&&, vector&&, string&&, DataMap&&>(); + +static inline void check_vary_shape(NanoVector v) { + ASSERT(v.size()) << "Vary shape should not be zero dimension"; + for (int i=0; i= 0)) + << "Vary shape should only occur in the first dimension:" << v; +} + +CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector&& inputs, + string&& cpu_src, vector&& cpu_grad_src, string&& cpu_header, + string&& cuda_src, vector&& cuda_grad_src, string&& cuda_header, + DataMap&& data) + : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)), + cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header)), + data(move(data)) +{ + flags.set(NodeFlags::_cpu, !!this->cpu_src.size()); + flags.set(NodeFlags::_cuda, !!this->cuda_src.size()); + _outputs.push_back(create_output(shape, dtype)); + + if (_outputs[0]->num < 0) { + check_vary_shape(_outputs[0]->shape); + } + if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0) { + flags.set(NodeFlags::_manual_set_vnbb); + } +} + + +CodeOp::CodeOp( + vector&& shapes, vector&& dtypes, vector&& inputs, + string&& cpu_src, vector&& cpu_grad_src, string&& cpu_header, + string&& cuda_src, vector&& cuda_grad_src, string&& cuda_header, + DataMap&& data) + : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)), + cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header)), + data(move(data)) +{ + flags.set(NodeFlags::_cpu, !!this->cpu_src.size()); + flags.set(NodeFlags::_cuda, !!this->cuda_src.size()); + CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same"; + _outputs.resize(shapes.size()); + CHECKop(_outputs.size(),>,0); + for (int i=0; inum < 0) { + check_vary_shape(_outputs[i]->shape); + } + } + if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0) + flags.set(NodeFlags::_manual_set_vnbb); +} + +CodeOp::CodeOp( + vector&& inputs, vector&& outputs, + string&& cpu_src, vector&& cpu_grad_src, string&& cpu_header, + string&& cuda_src, vector&& cuda_grad_src, string&& cuda_header, + DataMap&& data) + : _inputs(inputs), cpu_src(move(cpu_src)), cpu_grad_src(move(cpu_grad_src)), cpu_header(move(cpu_header)), + cuda_src(move(cuda_src)), cuda_grad_src(move(cuda_grad_src)), cuda_header(move(cuda_header)), + data(move(data)) +{ + flags.set(NodeFlags::_cpu, !!this->cpu_src.size()); + flags.set(NodeFlags::_cuda, !!this->cuda_src.size()); + _outputs.resize(outputs.size()); + CHECKop(_outputs.size(),>,0); + for (int i=0; ishape, o->dtype()); + _outputs[i]->share_with(o); + /* + TODO: vary shape not allowed in direct output + */ + } + if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0) + flags.set(NodeFlags::_manual_set_vnbb); +} + + +VarPtr CodeOp::grad(Var* out, Var* dout, Var* v, int v_index) { + // Do not have grad to extras input + string cpu_src = v_index < cpu_grad_src.size() ? cpu_grad_src[v_index] : ""; + string cuda_src = v_index < cuda_grad_src.size() ? cuda_grad_src[v_index] : ""; + if (!cuda_src.size() && !cpu_src.size()) return nullptr; + auto inputs = clone(_inputs); + // TODO: remove unused deps + // dout -> dout + std::stringstream new_alias; + new_alias << "\n@alias(dout,in" << JK::dec2(inputs.size()) << ")\n"; + inputs.push_back(dout); + // _outputs[i] -> poutj + for (int i=0; i<_outputs.size(); i++) { + new_alias << "\n@alias(pout" << JK::dec2(i) << ",in" << JK::dec2(inputs.size()) << ")\n"; + if (_outputs[i] == out) + new_alias << "\n@alias(pout,in" << JK::dec2(inputs.size()) << ")\n"; + inputs.push_back(_outputs[i]); + } + auto alias = new_alias.str(); + return make_code( + _inputs[v_index]->shape, + _inputs[v_index]->dtype(), + move(inputs), + move(cpu_src), {}, alias+cpu_header, + move(cuda_src), {}, alias+cuda_header, + DataMap(data) + ); +} + +void CodeOp::jit_prepare(JK& jk) { + + // forward: in0 in1 in2 -> out0 out1 + // backward: in0 in1 in2 in3(pout0) in4(pout1) + jk << "«IN_SIZE:" << JK::dec2(_inputs.size()); + for (uint i=0; i<_inputs.size(); i++) { + jk << "«in" << JK::dec2(i) << "_dim:" + << JK::hex1(_inputs[i]->shape.size()); + jk << "«in" << JK::dec2(i) << "_type:" + << _inputs[i]->dtype(); + } + jk << "«OUT_SIZE:" << JK::dec2(_outputs.size()); + for (uint i=0; i<_outputs.size(); i++) { + jk << "«out" << JK::dec2(i) << "_dim:" + << JK::hex1(_outputs[i]->shape.size()); + jk << "«out" << JK::dec2(i) << "_type:" + << _outputs[i]->dtype(); + } + string& header = flags.get(NodeFlags::_cuda) ? + cuda_header : cpu_header; + string& src = flags.get(NodeFlags::_cuda) ? + cuda_src : cpu_src; + + jk << "«HEADER:" << header; + CHECK(src.size()); + jk << "\nnamespace jittor {\n"; + int i=0; + // move cuda kernel function into header + for (; iptr(); + @for(j, 0, in@i@@_dim, index_t in@i@@_shape@j = _inputs[@i]->shape[@j];) + ) + // define outputs + @for(i, 0, OUT_SIZE, + auto out@i = _outputs[@i]; + auto* __restrict__ out@i@@_p = _outputs[@i]->ptr(); + @for(j, 0, out@i@@_dim, index_t out@i@@_shape@j = _outputs[@i]->shape[@j];) + ) + + @PRECALC + + @CODE +} + +} // jittor + +#endif // JIT diff --git a/python/jittor/src/ops/code_op.h b/python/jittor/src/ops/code_op.h new file mode 100644 index 00000000..7c846ddd --- /dev/null +++ b/python/jittor/src/ops/code_op.h @@ -0,0 +1,279 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +typedef unordered_map DataMap; + +struct CodeOp : Op { + vector _inputs; + vector _outputs; + string cpu_src; + vector cpu_grad_src; + string cpu_header; + string cuda_src; + vector cuda_grad_src; + string cuda_header; + DataMap data; + /** + Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] cpu_src: cpu source code string, buildin value: + + * in{x}, in{x}_shape{y}, in{x}_stride{y}, in{x}_type, in{x}_p, @in0(...) + * out{x}, out{x}_shape{y}, out{x}_stride{y}, out{x}_type, out{x}_p, @out0(...) + * out, out_shape{y}, out_stride{y}, out_type, out_p, @out(...) + + * [in] cpu_header: cpu header code string. + + * [in] cuda_src: cuda source code string. + + * [in] cuda_header: cuda header code string. + + ---------------- + + Example-1:: + + from jittor import Function + import jittor as jt + + class Func(Function): + def execute(self, x): + self.save_vars = x + return jt.code(x.shape, x.dtype, [x], + cpu_src=''' + for (int i=0; i + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i + using namespace std; + """, + cpu_src=""" + @alias(a, in0) + @alias(b, out0) + @alias(c, out1) + @b(0) = @c(0) = @a(0); + for (int i=0; i0) + @b(num_b++) = @a(i); + else + @c(num_c++) = @a(i); + } + b->set_shape({num_b}); + c->set_shape({num_c}); + """ + ) + assert (b.data == [5,3,1]).all() + assert (c.data == [-4,-2]).all() + + Example-5:: + + # This example shows how to customize code op + # compilation flags, such as add include search + # path, add definitions, or any command line options + + a = jt.random([10]) + b = jt.code(a.shape, a.dtype, [a], + cpu_src=''' + @out0(0) = HAHAHA; + ''') + # HAHAHA is defined in flags below + # /any/include/path can be change to any path you want to include + b.compile_options = {"FLAGS: -DHAHAHA=233 -I/any/include/path ": 1} + print(b[0]) + # will output 233 + + Example-6:: + + + # This example shows how to pass custom data + # into code op kernel without kernel recompiling. + # In this example, the data {"x":123} canbe vary + # and kernel will not recompile. + # NOTE: the data type pass into kernel is float64 + # cast to int if you want + + a = jt.code([1], "float32", inputs=[], + data = {"x":123}, + cpu_src=''' + @out0(0) = data["x"]; + ''').sync() + assert a.item() == 123 + + CUDA Example-1:: + + #This example shows how to use CUDA in code op. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=''' + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + ''') + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=''' + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + ''') + + a = jt.random([100000]) + b = jt.random([100000]) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b])) + + CUDA Example-2:: + + #This example shows how to use multi dimension data with CUDA. + import jittor as jt + from jittor import Function + jt.flags.use_cuda = 1 + + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=''' + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''') + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=''' + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''') + + a = jt.random((100,100)) + b = jt.random((100,100)) + func = Func() + c = func(a,b) + print(c) + print(jt.grad(c, [a, b])) + */ + CodeOp(NanoVector shape, NanoString dtype, vector&& inputs={}, string&& cpu_src="", vector&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector&& cuda_grad_src={}, string&& cuda_header="", DataMap&& data={}); + + // @attrs(multiple_outputs) + CodeOp(vector&& shapes, vector&& dtypes, vector&& inputs={}, string&& cpu_src="", vector&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector&& cuda_grad_src={}, string&& cuda_header="", DataMap&& data={}); + + // @attrs(multiple_outputs,replace_outputs) + CodeOp(vector&& inputs, vector&& outputs, string&& cpu_src="", vector&& cpu_grad_src={}, string&& cpu_header="", string&& cuda_src="", vector&& cuda_grad_src={}, string&& cuda_header="", DataMap&& data={}); + + + const char* name() const override { return "code"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/copy_op.cc b/python/jittor/src/ops/copy_op.cc new file mode 100644 index 00000000..ebc748cf --- /dev/null +++ b/python/jittor/src/ops/copy_op.cc @@ -0,0 +1,53 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/op_register.h" +#include "ops/copy_op.h" +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "misc/cuda_flags.h" +#endif + +namespace jittor { + +CopyOp::CopyOp(Var* x) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + auto y = create_output(nullptr, x->dtype()); + if (x->name.ptr) + y->name = x->name; +} + +VarPtr CopyOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return dout; +} + +void CopyOp::infer_shape() { + outputs().front()->set_shape(inputs().front()->shape); +} + +void CopyOp::run() { + auto x = inputs().front(); + auto size = x->size; + auto x_ptr = x->mem_ptr; + auto y_ptr = outputs().front()->mem_ptr; + #ifdef HAS_CUDA + if (flags.get(NodeFlags::_cuda)) { + checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0)); + } else + #endif + { + std::memcpy(y_ptr, x_ptr, size); + } +} + + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/copy_op.h b/python/jittor/src/ops/copy_op.h new file mode 100644 index 00000000..56fae6c3 --- /dev/null +++ b/python/jittor/src/ops/copy_op.h @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CopyOp : Op { + CopyOp(Var* x); + + const char* name() const override { return "copy"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + void run() override; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/empty_op.cc b/python/jittor/src/ops/empty_op.cc new file mode 100644 index 00000000..8b84c4dd --- /dev/null +++ b/python/jittor/src/ops/empty_op.cc @@ -0,0 +1,22 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/array_op.h" +#include "ops/op_register.h" +#include "ops/empty_op.h" + +namespace jittor { + +EmptyOp::EmptyOp(NanoVector shape, NanoString dtype) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + create_output(shape, dtype); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/empty_op.h b/python/jittor/src/ops/empty_op.h new file mode 100644 index 00000000..ecb418d0 --- /dev/null +++ b/python/jittor/src/ops/empty_op.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct EmptyOp : Op { + EmptyOp(NanoVector shape, NanoString dtype=ns_float32); + + const char* name() const override { return "empty"; } +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/fetch_op.cc b/python/jittor/src/ops/fetch_op.cc new file mode 100644 index 00000000..634f4f07 --- /dev/null +++ b/python/jittor/src/ops/fetch_op.cc @@ -0,0 +1,163 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. +// All Rights Reserved. +// Maintainers: Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include +#include "misc/cuda_flags.h" +#include "mem/allocator/sfrl_allocator.h" +#include "mem/allocator/cuda_dual_allocator.h" +#include "event_queue.h" +#endif +#include "ops/fetch_op.h" +#include "mem/allocator.h" +#include "executor.h" + +namespace jittor { + +#ifdef HAS_CUDA + +#pragma GCC visibility push(hidden) +namespace fetcher_local { + +cudaStream_t stream; +cudaEvent_t event; + +volatile int64 n_to_fetch; +std::mutex m; +list fetch_tasks; + +static void fetch_caller() { + fetch_tasks.front().call(); + fetch_tasks.pop_front(); +} + +static void to_fetch(CUDA_HOST_FUNC_ARGS) { + event_queue.push(fetch_caller); +} + +struct Init { +Init() { + if (!get_device_count()) return; + checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + checkCudaErrors(cudaEventCreate(&event, cudaEventDisableTiming)); +} +~Init() { + if (!get_device_count()) return; + // do not call deleter on exit + for (auto& f : fetch_tasks) + f.func.deleter = nullptr; + peekCudaErrors(cudaDeviceSynchronize()); + peekCudaErrors(cudaStreamDestroy(stream)); + peekCudaErrors(cudaEventDestroy(event)); +} +} ; + +} +using namespace fetcher_local; + +#endif + +list fetcher; +// this list will be free at each execution +list fetcher_to_free; + +FetchOp::FetchOp(vector&& inputs, FetchFunc&& func) +: fetch_vars(inputs), func(move(func)) { + #ifdef HAS_CUDA + // stream needs to be created after nccl plugin + static Init init_fetch; + #endif + VarPtr vp(0, ns_int32); + outputs_holder.emplace_back(vp); + fetcher.emplace_front(move(vp)); + fetcher_iter = fetcher.begin(); + bool all_finished = true; + for (auto v : fetch_vars) + if (!v->is_finished()) { + all_finished = false; + v->flags.set(NodeFlags::_stop_fuse); + v->flags.set(NodeFlags::_fetch); + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_fetch); + flags.set(NodeFlags::_stop_grad); + fetcher_iter->ptr->flags.set(NodeFlags::_fetch); + // fetcher_to_free.clear(); + if (all_finished) { + // if all finished, run immediately + run(); + } + // if too many fetchers are bufferd, force flush + while (fetcher.size() > 20) { + LOGvvvv << "too many fetchers(">>fetcher.size() >> + ") are bufferd, force flush"; + exe.run_sync({fetcher.back().ptr}, false); + } +} + +void FetchOp::run() { + vector allocations(fetch_vars.size()); + vector arrays(fetch_vars.size()); + #ifdef HAS_CUDA + bool has_cuda_memcpy = false; + event_queue.flush(); + #endif + LOGvvvv << "fetch" << fetch_vars.size() << "vars" << fetch_vars; + int i = 0; + for (auto v : fetch_vars) { + auto& allocation = allocations[i]; + + #ifdef HAS_CUDA + if (v->allocator->is_cuda()) { + checkCudaErrors(cudaEventRecord(event, 0)); + checkCudaErrors(cudaStreamWaitEvent(stream, event, 0)); + new (&allocation) Allocation(&cuda_dual_allocator, v->size); + // mostly device to device + #if IS_CUDA + checkCudaErrors(cudaMemcpyAsync( + allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDefault, stream)); + #else + checkCudaErrors(cudaMemcpyAsync( + allocation.ptr, v->mem_ptr, v->size, cudaMemcpyDeviceToDevice, stream)); + #endif + auto host_ptr = cuda_dual_allocator.get_dual_allocation( + allocation.allocation).host_ptr; + // device to host + checkCudaErrors(cudaMemcpyAsync( + host_ptr, allocation.ptr, v->size, cudaMemcpyDeviceToHost, stream)); + allocation.ptr = host_ptr; + has_cuda_memcpy = true; + } else + #endif + { + new (&allocation) Allocation(cpu_allocator, v->size); + std::memcpy(allocation.ptr, v->mem_ptr, v->size); + } + arrays[i].ptr = allocation.ptr; + arrays[i].shape = v->shape; + arrays[i].dtype = v->dtype(); + i++; + } + #ifdef HAS_CUDA + if (has_cuda_memcpy) { + fetch_tasks.push_back({move(func), move(allocations), move(arrays)}); + checkCudaErrors(_cudaLaunchHostFunc(stream, &to_fetch, 0)); + } else + #endif + { + FetchResult fr{move(func), move(allocations), move(arrays)}; + fr.call(); + } + fetcher_to_free.emplace_front(move(*fetcher_iter)); + fetcher.erase(fetcher_iter); +} + +} // jittor diff --git a/python/jittor/src/ops/fetch_op.h b/python/jittor/src/ops/fetch_op.h new file mode 100644 index 00000000..365e04da --- /dev/null +++ b/python/jittor/src/ops/fetch_op.h @@ -0,0 +1,79 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "op.h" +#include "var.h" +#include "mem/allocator.h" +#include "ops/array_op.h" + +namespace jittor { + +struct FetchResult; + +struct FetchFunc { + typedef FetchResult R; + std::function callback; + std::function deleter; + FetchFunc() = default; + FetchFunc(FetchFunc&& other) : callback(other.callback), deleter(other.deleter) { + other.callback = nullptr; + other.deleter = nullptr; + }; + FetchFunc(const FetchFunc&) = delete; + FetchFunc(std::function&& callback) : callback(move(callback)) {} + FetchFunc(std::function&& callback, std::function&& deleter) + : callback(move(callback)), deleter(move(deleter)) {}; + ~FetchFunc() { + if (deleter) { + deleter(); + } + } + void operator =(FetchFunc&& other) { this->~FetchFunc(); new (this) FetchFunc(move(other)); } +}; + + +struct SimpleFunc { + std::function callback; + std::function deleter; + SimpleFunc() = default; + SimpleFunc(SimpleFunc&& other) : callback(other.callback), deleter(other.deleter) { + other.callback = nullptr; + other.deleter = nullptr; + }; + SimpleFunc(const SimpleFunc&) = delete; + SimpleFunc(std::function&& callback) : callback(move(callback)) {} + SimpleFunc(std::function&& callback, std::function&& deleter) + : callback(move(callback)), deleter(move(deleter)) {}; + ~SimpleFunc() { + if (deleter) { + deleter(); + } + } + void operator =(SimpleFunc&& other) { this->~SimpleFunc(); new (this) SimpleFunc(move(other)); } +}; + +struct FetchResult { + FetchFunc func; + vector allocations; + vector arrays; + + inline void call() { func.callback(this); } +}; + +struct FetchOp final : Op { + vector fetch_vars; + FetchFunc func; + list::iterator fetcher_iter; + + FetchOp(vector&& inputs, FetchFunc&& func); + + const char* name() const override { return "fetch"; } + void run() override; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/fuse_transpose_op.cc b/python/jittor/src/ops/fuse_transpose_op.cc new file mode 100644 index 00000000..b11a292f --- /dev/null +++ b/python/jittor/src/ops/fuse_transpose_op.cc @@ -0,0 +1,114 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "ops/fuse_transpose_op.h" +#include "var.h" +#include "ops/op_register.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +#ifndef JIT +static auto make_transpose = get_op_info("fuse_transpose") + .get_constructor(); + +static inline NanoVector get_reverse(NanoVector axes) { + NanoVector reverse; + reverse.reserve(axes.size(), axes.size()); + for (uint i=0; iis_finished()) { + auto type = x->input()->type(); + if (type==OpType::broadcast || type==OpType::element) + tp = OpType::reduce; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(tp); + flags.set(NodeFlags::_manual_set_vnbb); + int i=0; + for (; ishape.size(); + if (!axes.size()) { + for (int i=0; i<(int)xdim; i++) + axes.push_back(xdim-1-i); + } + y = create_output(nullptr, x->dtype()); +} + +void FuseTransposeOp::infer_shape() { + auto xdim = x->shape.size(); + CHECK(xdim); + if (!axes.size()) { + for (int i=0; i<(int)xdim; i++) + axes.push_back(xdim-1-i); + } else { + CHECKop(axes.size(),==,xdim); + int64_t mask=0; + for (auto i : axes) mask |= 1<shape[axes[i]]); + y->set_shape(shape); +} + +VarPtr FuseTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return make_transpose(dout, get_reverse(axes)); +} + +void FuseTransposeOp::jit_prepare(JK& jk) { + auto bc = type()==OpType::broadcast; + auto ax = bc ? axes : get_reverse(axes); + jk << "«Tx:" << x->dtype(); + jk << "«DIM=" << JK::hex1(axes.size()); + jk << "«BC:" << JK::hex1(bc); + for (uint i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "executor.h" +#include "ops/getitem_op.h" +#include "ops/op_register.h" +#ifdef JIT_cuda +#include +#include "helper_cuda.h" +#endif +#ifndef JIT +#include "misc/stack_vector.h" +#include "opt/kernel_ir.h" +#ifdef HAS_CUDA +#include "misc/cuda_flags.h" +#endif +#endif + +namespace jittor { + +#ifndef JIT + + +static auto make_number = get_op_info("number") + .get_constructor(); +static auto make_empty = get_op_info("empty") + .get_constructor(); +static auto make_setitem = get_op_info("setitem") + .get_constructor(); + +GetitemOp::GetitemOp(Var* x, VarSlices&& slices) + : vs(move(slices)) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_has_gopt); + flags.set(NodeFlags::_manual_set_vnbb); + for (int i=0; iflags.set(NodeFlags::_needed_by_backward); + create_output(nullptr, x->dtype()); +} + +GetitemOp::GetitemOp(Var* x, VarSlices&& slices, int _) + : vs(move(slices)) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_has_gopt); + flags.set(NodeFlags::_custom_flag); + flags.set(NodeFlags::_grads); + flags.set(NodeFlags::_manual_set_vnbb); + for (int i=0; iflags.set(NodeFlags::_needed_by_backward); + create_output(nullptr, x->dtype()); + auto out2 = create_output(nullptr, x->dtype()); + out2->share_with(x); + ns.data = _; +} + +void GetitemOp::infer_slices( + StackVector<>& __restrict__ i_to_vs, + StackVector<>& __restrict__ i_to_o, + StackVector<>& __restrict__ out_shape +) { + auto in = inputs().front(); + auto in_shape = in->shape; + auto nin = in_shape.size(); + i_to_vs.n = i_to_o.n = nin; + out_shape.n = 0; + + int vid = 0; + first_oid_of_var = -1; + var_dim = 0; + for (int i=0; i= vs.n) { + // i i i + // | | | + // v v v --> overflow + // s s + i_to_vs[i] = -1; + i_to_o[i] = out_shape.size(); + out_shape.push_back(in_shape[i]); + } else + if (s.is_var()) { + // i --> s ---> o + // + ---> o + // var maybe multiple dims + if (first_oid_of_var == -1) { + for (int i=0; ishape.size()); + first_oid_of_var = out_shape.size(); + for (int j=0; jshape; + auto niv = iv_shape.size(); + for (int j=0; j> out_shape_j >> "!=" + >> iv_shape_j << "data shape:" << in_shape << + "slice shape:" << iv_shape; + if (out_shape_j == 1) + out_shape_j = iv_shape_j; + } + } else + if (s.is_ellipsis()) { + auto remain_slice = vs.n-vid-1; + for (int i=vid+1; i=0) << "NDims not match"; + for (int j=0; j=0 && v>in_shape_i>>")"; + } else + if (s.is_str()) { + i_to_vs[i] = vid++; + i_to_o[i] = -1; + } else { + // slice + auto& slice = s.slice; + auto in_shape_i = in_shape[i]; + auto out_shape_j = in_shape_i; + if (slice.mask == 7) { + // slice is a[::] + // start, stop, step is not filled + vid++; + i_to_vs[i] = -1; + i_to_o[i] = out_shape.size(); + out_shape.push_back(out_shape_j); + } else { + i_to_vs[i] = vid++; + i_to_o[i] = out_shape.size(); + if (in_shape_i > 0) { + slice.fill(in_shape_i); + if (std::abs(slice.step) <= 1) + out_shape_j = (slice.stop - slice.start) * slice.step; + else if (slice.step>0) + out_shape_j = (slice.stop - slice.start - 1) / slice.step + 1; + else + out_shape_j = (slice.start - slice.stop - 1) / -slice.step + 1; + out_shape_j = std::max((int64)0, out_shape_j); + } + out_shape.push_back(out_shape_j); + } + } + } + while (vid < vs.n) { + auto& s = vs.slices[vid++]; + if (s.is_none()) { + out_shape.push_back(1); + } else + CHECK(s.is_ellipsis()) << "Too many slices" << vs << "shape:" << in->shape; + } +} + +void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims) { + // bz by bx tz ty tx + // 5 4 3 2 1 0 + // LOi: bitmask of used dims of loop i + // LOi bit 6: need for + // if need for, keep for range: for (int i@i=tid; tid tnum, for -> int i@i = tid + int rtnum = 1024; + // int max_tnum = {1024, 1024, 64, (1u<<31)-1, 65535, 65535}; + int loop_id = (int)o_shape.size()-1; + int tid = 0; + int64 block_size = 1; + int thread_size = 1; + for (int i=0; i<6; i++) tdims[i] = 1; + for (; tid<3 && loop_id>=0 && rtnum>1; tid++) { + int64 si = o_shape[loop_id]; + int mask = 1<rtnum*4) { + // need for, use tid(1<rtnum) { + mask |= (1<<6); + thread_size *= rtnum; + tdims[tid] = rtnum; + rtnum = 0; + } else { + rtnum = rtnum / std::max(si, (int64)1); + thread_size *= si; + tdims[tid] = si; + if (si == 0) mask |= 1<<7; + } + masks[loop_id] = mask; + loop_id --; + } + int64 total_size = (int64)block_size*thread_size; + if (tid<3) tid=3; + for (; tid<6 && loop_id>=0 && total_size<(256*1024); tid++) { + int64 si = o_shape[loop_id]; + int mask = 1<=4 ? 65535 : (1u<<31)-1; + if (si > max_thread) { + si = max_thread; + mask |= 1<<6; + } + total_size *= si; + tdims[tid] = si; + masks[loop_id] = mask; + loop_id --; + } + while (loop_id>=0) { + masks[loop_id--] = 0; + } +} + +void GetitemOp::compile_optimize(string& src) { + _compile_optimize(src); +} + +void GetitemOp::_compile_optimize(string& src) { + if (!flags.get(NodeFlags::_cuda)) + return; + + auto jd = get_jit_define(); + map jd_map(jd.begin(), jd.end()); + + KernelIR main(src); + auto& func = main.children.back()->children.back(); + // auto& loop = func->children.back(); + + func->push_back("void slice_func() {}", &func->before); + + auto& new_func = func->before.back(); + // auto new_func = func->before.back()->move_out(); + + new_func->attrs["dtype"] = "static __global__ void"; + // LOGir << main.to_string(); + src = main.to_string(); + string arg_call = ""; + const char* tname[] = {"threadIdx.x", "threadIdx.y", "threadIdx.z", "blockIdx.x", "blockIdx.y", "blockIdx.z"}; + const char* tname2[] = {"blockDim.x", "blockDim.y", "blockDim.z", "gridDim.x", "gridDim.y", "gridDim.z"}; + for (auto& ir : func->children) { + if (ir->type == "define") { + string& rvalue = ir->attrs.at("rvalue"); + string& lvalue = ir->attrs.at("lvalue"); + string& dtype = ir->attrs.at("dtype"); + if (startswith(rvalue, "input") + || startswith(rvalue, "output") + || startswith(rvalue, "vs.") + || rvalue.back() == ')' + || rvalue.back() == ']') + { + if (dtype == "auto") + LOGvvvv << "keep" << rvalue; + else { + LOGvvvv << "args" << rvalue; + if (arg_call.size()) arg_call += ", "; + arg_call += lvalue; + LOGvvvv << dtype+" "+lvalue; + new_func->push_back(dtype+" "+lvalue+";", &new_func->inner); + } + } else { + LOGvvvv << "move" <push_back(ir->clone()); + } + } + } + new_func->push_back(func->children.back()->move_out()); + auto& loop = new_func->children.back(); + int no = o_shape.size(); + STACK_ALLOC(KernelIR*, loops, no); + if (!no) { + func->push_back("slice_func<<<1,1>>>("+arg_call+");"); + } else { + bool has_zero = 0; + loops[0] = loop.get(); + for (int i=1; ichildren.back().get(); + for (int i=0; iinner.size() == 3); + auto lo = l->find_define("LO"+S(i)); + ASSERT(lo); + auto loi = std::stoi(lo->attrs.at("rvalue")); + if (loi>>7) has_zero = 1; + string tid = ""; + string tnum = ""; + for (int j=0; j<6; j++) { + if ((loi>>j)&1) { + if (tid.size()) { + tid += string("+")+tnum+"*"+tname[j]; + tnum += string("*")+tname2[j]; + } else { + tid = tname[j]; + tnum = tname2[j]; + } + } + } + if (!tid.size()) { + continue; + } + if (loi&(1<<6)) { + l->inner.at(0)->attrs.at("rvalue") = tid; + l->inner.at(2)->attrs.at("code") = "i"+S(i)+"+="+tnum+";"; + } else { + // no need for + while (l->inner.size()) + l->inner.at(0)->erase(); + l->push_front("index_t i"+S(i)+" = "+tid+";"); + } + } + if (!has_zero) { + func->push_back("int no = o_shape.size();"); + func->push_back("STACK_ALLOC(int,masks,no);"); + func->push_back("int tdims[6];"); + func->push_back("cuda_loop_schedule(o_shape, masks, tdims);"); + func->push_back("dim3 grid_dim(tdims[3],tdims[4],tdims[5]);"); + func->push_back("dim3 block_dim(tdims[0],tdims[1],tdims[2]);"); + func->push_back("slice_func<<>>("+arg_call+");"); + } + } + src = main.to_string(); +} + +void GetitemOp::infer_shape() { + auto in = inputs().front(); + auto out = outputs().front(); + auto in_shape = in->shape; + auto nin = in_shape.size(); + + StackVector<> i_to_vs(nin); + StackVector<> i_to_o(nin); + // shape return to use + StackVector<> out_shape; + infer_slices(i_to_vs, i_to_o, out_shape); + + // this will cause save checkpoint failed. + // if (out_shape.n == 0) + // out->flags.set(NodeFlags::_is_scalar); + // optimized shape (each dim is a loop var) + StackVector<> o_shape; + int fov = -1; + for (int i=0; i=0) { + if (vid==-1 && i && i_to_vs[i-1]<0) { + vid = -2; + o_shape.back() *= os; + } else + o_shape.push_back(os); + oid = o_shape.size()-1; + } else { + auto& s = vs.slices[vid]; + if (s.is_var() && fov == -1) { + fov = o_shape.size(); + for (int i=0; iset_shape(out_shape.to_nano_vector()); + + this->i_to_vs = i_to_vs.to_nano_vector(); + this->i_to_o = i_to_o.to_nano_vector(); + this->o_shape = o_shape.to_nano_vector(); + if (outputs().size() > 1) { + auto out2 = output(1); + out2->set_shape(in->shape); + } + + LOGV(999) << "\ni_to_vs:" << i_to_vs + << "\ni_to_o:" << i_to_o + << "\no_shape:" << o_shape; +} + +VarPtr GetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) { + if (v_index) + return nullptr; + auto zeros = make_number(0, v); + // TODO: maybe add here? + // need analysis the overlap attr os var slices + for (int i=0; ishape, in->dtype()); + else + x = make_number(0, in); + } + if (!y) { + y = make_number(0, outputs().front()); + } + dins[0] = make_setitem(x, VarSlices(vs, true), y, ns_void); +} + +void GetitemOp::jit_prepare(JK& jk) { + auto in = inputs().front(); + int idim = i_to_vs.size(); + jk << "«Ti:" << in->dtype(); + jk << "«IDIM=" << JK::hex1(i_to_vs.size()); + jk << "«ODIM=" << JK::hex1(o_shape.size()); + if (first_oid_of_var>=0) { + jk << "«FOV=" << JK::hex1(first_oid_of_var); + jk << "«VD=" << JK::hex1(var_dim); + } + for (int i=0; i=0 && io==-1) { + if (v.is_int()) { + jk << "«VS" << JK::hex1(i) << ":-1"; + } else + if (v.is_str()) { + jk << "«VS" << JK::hex1(i) << ":-5"; + jk << "«VSS" << JK::hex1(i) << ":" << v.get_str(); + } else { + ASSERT(v.is_var()); + auto var = v.var; + auto vshape = var->shape; + auto vdim = vshape.size(); + int vsmask = 0; + for (int j=0; jdtype(); + } + } else + if (iv>=0 && io>=0) { + ASSERT(v.is_slice()); + jk << "«VS" << JK::hex1(i) << ':'; + if (std::abs(v.slice.step) <= 1) + jk << JK::shex1(v.slice.step); + else + jk << '0'; + } + } + #ifdef HAS_CUDA + if (use_cuda) { + int no = o_shape.size(); + STACK_ALLOC(int, masks, no); + int tdims[6]; + cuda_loop_schedule(o_shape, masks, tdims); + for (int i=0; i& __restrict__ i_to_vs, + StackVector<>& __restrict__ i_to_o, + StackVector<>& __restrict__ out_shape + ); + void _compile_optimize(string& src); +}; + +void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims); + +} // jittor diff --git a/python/jittor/src/ops/index_op.cc b/python/jittor/src/ops/index_op.cc new file mode 100644 index 00000000..1f27d788 --- /dev/null +++ b/python/jittor/src/ops/index_op.cc @@ -0,0 +1,86 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/index_op.h" + +namespace jittor { + +#ifndef JIT +IndexOp::IndexOp(NanoVector shape, int64 dim, NanoString dtype) : dim(dim) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::element); + x.reset(new Var*[1]); + x[0] = create_output(shape, dtype); +} + +IndexOp::IndexOp(NanoVector shape, NanoString dtype) : dim(shape.size()) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::element); + x.reset(new Var*[dim]); + for (int i=0; ishape.size()) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + set_type(OpType::element); + x.reset(new Var*[dim]); + for (int i=0; iset_shape(a->shape); +} + +void IndexOp::jit_prepare(JK& jk) { + add_jit_define(jk, "T", x[0]->dtype()); + add_jit_define(jk, "DIM", JK::hex1(dim)); + add_jit_define(jk, "XDIM", JK::hex1(x[0]->shape.size())); +} + +#else // JIT +void IndexOp::jit_run() { + @if(DIM==XDIM, + @for(i,0,XDIM, auto* __restrict__ x@i@@p = x[@i]->ptr();) + , + auto* __restrict__ x0p = x[0]->ptr(); + ) + // define x shape + @for(i, 0, XDIM, index_t x0shape@i = x[0]->shape[@i];) + // define x stride + index_t x0stride@{XDIM-1} = 1; + @for(i, XDIM-2, -1, -1, auto x0stride@i = x0stride@{i+1} * x0shape@{i+1};) + + @for(d, 0, XDIM, for (index_t i@d=0; i@d < x0shape@d; i@d++)) { + auto xid = @for(d, 0, XDIM, + i@d * x0stride@d); + @if(DIM==XDIM, + @for(i,0,XDIM, T x@i@@id = i@i; x@i@@p[xid] = x@i@@id;) + , + T x@DIM@@id = i@DIM; x0p[xid] = x@DIM@@id; + ) + } +} +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/index_op.h b/python/jittor/src/ops/index_op.h new file mode 100644 index 00000000..840dd79e --- /dev/null +++ b/python/jittor/src/ops/index_op.h @@ -0,0 +1,60 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct IndexOp : Op { + unique_ptr x; + int64 dim; + /** + Index Operator generate index of shape. + + It performs equivalent Python-pseudo implementation below:: + + n = len(shape)-1 + x = np.zeros(shape, dtype) + for i0 in range(shape[0]): # 1-st loop + for i1 in range(shape[1]): # 2-nd loop + ...... # many loops + for in in range(shape[n]) # n+1 -th loop + x[i0,i1,...,in] = i@dim + + * [in] shape: the output shape, a integer array + * [in] dim: the dim of the index. + * [in] dtype: the data type string, default int32 + + Example:: + + print(jt.index([2,2], 0)) + # output: [[0,0],[1,1]] + print(jt.index([2,2], 1)) + # output: [[0,1],[0,1]] + */ + IndexOp(NanoVector shape, int64 dim, NanoString dtype=ns_int32); + // @attrs(multiple_outputs) + IndexOp(NanoVector shape, NanoString dtype=ns_int32); + /** shape dependency version of index op + jt.index_var(a, 1) similar with jt.index(a.shape, 1) + */ + // @pybind(index,index_var) + IndexOp(Var* a, int64 dim, NanoString dtype=ns_int32); + /** shape dependency version of index op + jt.index_var(a) similar with jt.index(a.shape) + */ + // @pybind(index,index_var) + // @attrs(multiple_outputs) + IndexOp(Var* a, NanoString dtype=ns_int32); + + const char* name() const override { return "index"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/numpy_code_op.cc b/python/jittor/src/ops/numpy_code_op.cc new file mode 100644 index 00000000..9f218440 --- /dev/null +++ b/python/jittor/src/ops/numpy_code_op.cc @@ -0,0 +1,158 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/numpy_code_op.h" +#include "ops/op_register.h" + +#ifndef JIT + +namespace jittor { + +static auto make_numpy_code = get_op_info("numpy_code") + .get_constructor&&, NumpyFunc, NumpyResult&&>(); + +NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector&& inputs, NumpyFunc&& forward, vector&& sbackward) + : _inputs(inputs), forward(move(forward)) +{ + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + _outputs.push_back(create_output(shape, dtype)); + CHECKop(_inputs.size(),<=,10); + ASSERT(_outputs[0]->num >= 0); + for (int i=0; i&& shapes, vector&& dtypes, vector&& inputs, NumpyFunc&& forward, vector&& sbackward) + : _inputs(inputs), forward(move(forward)) +{ + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same"; + _outputs.resize(shapes.size()); + CHECKop(_inputs.size(),<=,10); + CHECKop(_outputs.size(),<=,10); + CHECKop(_outputs.size(),>,0); + for (int i=0; inum >= 0); + } + for (int i=0; i&& inputs, NumpyFunc&& forward) + : _inputs(inputs), forward(move(forward)) +{ + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + _outputs.push_back(create_output(shape, dtype)); + CHECKop(_inputs.size(),<=,10); + ASSERT(_outputs[0]->num >= 0); +} + +NumpyCodeOp::NumpyCodeOp(vector&& shapes, vector&& dtypes, vector&& inputs, NumpyFunc&& forward) + : _inputs(inputs), forward(move(forward)) +{ + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same"; + _outputs.resize(shapes.size()); + CHECKop(_inputs.size(),<=,10); + CHECKop(_outputs.size(),<=,10); + CHECKop(_outputs.size(),>,0); + for (int i=0; inum >= 0); + } +} + +NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector&& inputs, NumpyFunc forward, NumpyResult&& results) + : _inputs(inputs), forward(forward), _results(move(results)) +{ + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + _outputs.push_back(create_output(shape, dtype)); + CHECKop(_inputs.size(),<=,10); + ASSERT(_outputs[0]->num >= 0); +} + +VarPtr NumpyCodeOp::grad(Var* out, Var* dout, Var* v, int v_index) { + NumpyResult result; + + int out_index=-1; + for (int i=0; i<_outputs.size(); i++) { + if (_outputs[i] == out) { + out_index = i; + break; + } + } + ASSERT(out_index!=-1); + result.ints["out_index"] = out_index; + result.arrays["dout"].ptr=dout; + result.arrays["dout"].shape=dout->shape; + result.arrays["dout"].dtype=dout->dtype(); + vector outputs(_outputs.size()); + auto inputs = clone(_inputs); + inputs.push_back(dout); + for (int i=0; ishape; + outputs[i].dtype=_outputs[i]->dtype(); + inputs.push_back(_outputs[i]); + } + result.varrays["f_outputs"] = move(outputs); + + return make_numpy_code( + _inputs[v_index]->shape, + _inputs[v_index]->dtype(), + move(inputs), + backward[v_index], + move(result)); +} + +void NumpyCodeOp::run() { + NumpyResult result; + result.varrays = _results.varrays; + result.ints = _results.ints; + result.arrays = _results.arrays; + + if (result.arrays.count("dout") > 0) { + auto &ptr = result.arrays["dout"].ptr; + ptr = ((Var*)ptr)->mem_ptr; + } + if (result.varrays.count("f_outputs") > 0) { + for (auto& dv : result.varrays["f_outputs"]) { + dv.ptr = ((Var*)dv.ptr)->mem_ptr; + } + } + vector inputs(_inputs.size()); + vector outputs(_outputs.size()); + for (int i=0; iptr(); + inputs[i].shape=_inputs[i]->shape; + inputs[i].dtype=_inputs[i]->dtype(); + } + for (int i=0; iptr(); + outputs[i].shape=_outputs[i]->shape; + outputs[i].dtype=_outputs[i]->dtype(); + } + result.varrays["inputs"] = move(inputs); + result.varrays["outputs"] = move(outputs); + forward.callback(&result); +} + +} // jittor + +#endif // JIT \ No newline at end of file diff --git a/python/jittor/src/ops/numpy_code_op.h b/python/jittor/src/ops/numpy_code_op.h new file mode 100644 index 00000000..17282119 --- /dev/null +++ b/python/jittor/src/ops/numpy_code_op.h @@ -0,0 +1,113 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" +#include "numpy_func.h" + +namespace jittor { + +struct NumpyCodeOp : Op { + vector _inputs; + vector _outputs; + NumpyFunc forward; + vector backward; + NumpyResult _results; + + /** + Numpy Code Operator for easily customized op. + + ---------------- + + * [in] shape: the output shape, a integer array + + * [in] dtype: the output data type + + * [in] inputs: A list of input jittor Vars + + * [in] forward: function, represents forward python function + + * [in] backward: A list of function, represents gradiant for each input + + ---------------- + + Example-1:: + + def forward_code(np, data): + a = data["inputs"][0] + b = data["outputs"][0] + np.add(a,a,out=b) + + def backward_code(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + + a = jt.random((5,1)) + b = jt.numpy_code( + a.shape, + a.dtype, + [a], + forward_code, + [backward_code], + ) + + Example-2:: + + def forward_code(np, data): + a,b = data["inputs"] + c,d = data["outputs"] + np.add(a,b,out=c) + np.subtract(a,b,out=d) + + def backward_code1(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout) + + def backward_code2(np, data): + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + if out_index==0: + np.copyto(out, dout) + else: + np.negative(dout, out) + + a = jt.random((5,1)) + b = jt.random((5,1)) + c, d = jt.numpy_code( + [a.shape, a.shape], + [a.dtype, a.dtype], + [a, b], + forward_code, + [backward_code1,backward_code2], + ) + + */ + NumpyCodeOp(NanoVector shape, NanoString dtype, vector&& inputs, NumpyFunc&& forward, vector&& backward); + + // @attrs(multiple_outputs) + NumpyCodeOp(vector&& shapes, vector&& dtypes, vector&& inputs, NumpyFunc&& forward, vector&& backward); + + NumpyCodeOp(NanoVector shape, NanoString dtype, vector&& inputs, NumpyFunc&& forward); + + // @attrs(multiple_outputs) + NumpyCodeOp(vector&& shapes, vector&& dtypes, vector&& inputs, NumpyFunc&& forward); + + // @pybind(None) + NumpyCodeOp(NanoVector shape, NanoString dtype, vector&& inputs, NumpyFunc forward, NumpyResult&& results); + + const char* name() const override { return "numpy_code"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + + void run() override; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/op_register.cc b/python/jittor/src/ops/op_register.cc new file mode 100644 index 00000000..81d867b4 --- /dev/null +++ b/python/jittor/src/ops/op_register.cc @@ -0,0 +1,44 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "op.h" +#include "ops/op_register.h" + +namespace jittor { + +unordered_map op_info_map; + +void op_registe(const OpInfo& op_info) { + ASSERT(!has_op(op_info.name)) << "Op" << op_info.name << "is already registed, " + << "source_path:" << op_info.source_path << "extra_flags" << op_info.extra_flags; + LOGvv << "registe op" << op_info.name + << "\nsource_path:" << op_info.source_path + << "\nextra_flags:" << op_info.extra_flags + << "\nconstructors:" << op_info.constructors + << "\nvar_members:" << op_info.var_members; + op_info_map[op_info.name] = op_info; +} + +bool has_op(const string& name) { + string op_file_name = Op::op_name_to_file_name(name); + return op_info_map.count(op_file_name); +} + +OpInfo get_op_info(const string& name) { + string op_file_name = Op::op_name_to_file_name(name); + ASSERT(has_op(op_file_name)) << "Op" << name << "not found."; + return op_info_map.at(op_file_name); +} + +vector op_types; + +int registe_op_type(OpByType* op_type) { + op_types.push_back(op_type); + return 0; +} + + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/op_register.h b/python/jittor/src/ops/op_register.h new file mode 100644 index 00000000..d7412215 --- /dev/null +++ b/python/jittor/src/ops/op_register.h @@ -0,0 +1,45 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include "common.h" + +namespace jittor { + +struct OpInfo { + string name, source_path, extra_flags; + vector> constructors; + // string: var member name, uint64: var member offset + vector> var_members; + + template auto get_constructor() { + typedef To (*func_t)(Ts...); + const auto& tid = typeid(func_t); + for (uint i=0; i types; + virtual string expand_op(const vector& args) = 0; + virtual void post_pass(OpCompiler*) = 0; +}; + +extern vector op_types; +int registe_op_type(OpByType*); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/op_utils.cc b/python/jittor/src/ops/op_utils.cc new file mode 100644 index 00000000..1cdeba8d --- /dev/null +++ b/python/jittor/src/ops/op_utils.cc @@ -0,0 +1,43 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "ops/op_register.h" +#include "var.h" + +namespace jittor { + +static auto make_array = get_op_info("array") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); +static auto make_broadcast_to = get_op_info("broadcast_to") + .get_constructor(); + +VarPtr make_number(float number, Var* x) { + union Number { + float32 f32; + float64 f64; + int32 i32; + int64 i64; + } v; + if (x->dtype() == ns_float32) v.f32 = number; else + if (x->dtype() == ns_float64) v.f64 = number; else + if (x->dtype() == ns_int32) v.i32 = number; else + if (x->dtype() == ns_int64) v.i64 = number; else { + VarPtr nums = make_array(&number, 1, ns_float32); + nums = make_broadcast_to(nums, x, {}); + return make_unary(nums, x->dtype()); + } + VarPtr nums = make_array(&v, 1, x->dtype()); + return make_broadcast_to(nums, x, {}); +} + +static void init() { + op_registe({"number", "", "", {{&typeid(&make_number), (void*)&make_number}}}); +} +static int caller = (init(), 0); + +} // jittor diff --git a/python/jittor/src/ops/random_op.cc b/python/jittor/src/ops/random_op.cc new file mode 100644 index 00000000..a7ad79cd --- /dev/null +++ b/python/jittor/src/ops/random_op.cc @@ -0,0 +1,66 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "init.h" +#include "ops/random_op.h" +#include "misc/cuda_flags.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT +RandomOp::RandomOp(NanoVector shape, NanoString dtype, NanoString type) { + // auto curand_random = get_op_info("curand_random") + // .get_constructor(); + // output = curand_random(shape, dtype); + #ifdef HAS_CUDA + if (use_cuda) { + static VarPtr(*curand_random)(NanoVector, NanoString, NanoString) = nullptr; + if (!curand_random && has_op("curand_random")) { + curand_random = get_op_info("curand_random") + .get_constructor(); + } + if (curand_random) { + auto var = curand_random(shape, dtype, type); + forward(var); + return; + } + } + #endif + output = create_output(shape, dtype); + this->type = type; + ASSERT(type == ns_normal || type == ns_uniform); +} + +void RandomOp::jit_prepare(JK& jk) { + jk << "«T:" << output->dtype(); + jk << "«R:" << type; +} + +#else // JIT +#ifdef JIT_cpu +void RandomOp::jit_run() { + auto* generator = get_random_engine(); + @if(@strcmp(@R,uniform)==0, + std::uniform_real_distribution distribution(0.0,1.0);, + std::normal_distribution distribution(0.0,1.0); + ) + auto* __restrict__ x = output->ptr(); + index_t num = output->num; + for (index_t i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct RandomOp : Op { + Var* output; + NanoString type; + RandomOp(NanoVector shape, NanoString dtype=ns_float32, NanoString type=ns_uniform); + + const char* name() const override { return "random"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/reduce_op.cc b/python/jittor/src/ops/reduce_op.cc new file mode 100644 index 00000000..88267b58 --- /dev/null +++ b/python/jittor/src/ops/reduce_op.cc @@ -0,0 +1,402 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "var.h" +#include "ops/reduce_op.h" +#include "ops/op_register.h" +#include "executor.h" + +namespace jittor { + +#ifndef JIT +static auto make_broadcast_to = get_op_info("broadcast_to") + .get_constructor(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); +static auto make_reduce = get_op_info("reduce") + .get_constructor(); +static auto make_reduce2 = get_op_info("reduce") + .get_constructor(); +static auto make_ternary = get_op_info("ternary") + .get_constructor(); +static auto make_number = get_op_info("number") + .get_constructor(); + +unordered_set reduce_ops = { + /** + Returns the maximum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.max(x) + jt.Var([4], dtype=int32) + >>> x.max() + jt.Var([4], dtype=int32) + >>> x.max(dim=1) + jt.Var([4 4], dtype=int32) + >>> x.max(dim=1, keepdims=True) + jt.Var([[4] + [4]], dtype=int32) + */ + // @pybind(max, reduce_maximum) + "maximum", + + /** + Returns the minimum elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.min(x) + jt.Var([0], dtype=int32) + >>> x.min() + jt.Var([0], dtype=int32) + >>> x.min(dim=1) + jt.Var([1 0], dtype=int32) + >>> x.min(dim=1, keepdims=True) + jt.Var([[1] + [0]], dtype=int32) + */ + // @pybind(min, reduce_minimum) + "minimum", + + /** + Returns the sum of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[4 1 2] + [0 2 4]], dtype=int32) + >>> jt.sum(x) + jt.Var([13], dtype=int32) + >>> x.sum() + jt.Var([13], dtype=int32) + >>> x.sum(dim=1) + jt.Var([7 6], dtype=int32) + >>> x.sum(dim=1, keepdims=True) + jt.Var([[7] + [6]], dtype=int32) + */ + // @pybind(sum, reduce_add) + "add", + + /** + Returns the product of all the elements in the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[7 5 5] + [5 7 5]], dtype=int32) + >>> jt.prod(x) + jt.Var([30625], dtype=int32) + >>> x.prod() + jt.Var([30625], dtype=int32) + >>> x.prod(dim=1) + jt.Var([175 175], dtype=int32) + >>> x.prod(dim=1, keepdims=True) + jt.Var([[175] + [175]], dtype=int32) + */ + // @pybind(prod, product, reduce_multiply) + "multiply", + + /** + Tests if all elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 1 1] + [0 1 0]], dtype=int32) + >>> jt.all_(x) + jt.Var([False], dtype=int32) + >>> x.all_() + jt.Var([False], dtype=int32) + >>> x.all_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.all_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32) + */ + // @pybind(reduce_logical_and, all_) + "logical_and", + + /** + Tests if any elements in input evaluate to True. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(2, shape=(2, 3)) + >>> x + jt.Var([[1 0 1] + [0 0 0]], dtype=int32) + >>> jt.any_(x) + jt.Var([True], dtype=int32) + >>> x.any_() + jt.Var([True], dtype=int32) + >>> x.any_(dim=1) + jt.Var([True False], dtype=int32) + >>> x.any_(dim=1, keepdims=True) + jt.Var([[True] + [False]], dtype=int32) + */ + // @pybind(reduce_logical_or, any_) + "logical_or", + "logical_xor", + "bitwise_and", + "bitwise_or", + "bitwise_xor", + + /** + Returns the mean value of the input. + + ---------------- + + * [in] x: the input jt.Var. + + * [in] dim or dims: int or tuples of ints (optional). If specified, reduce along the given the dimension(s). + + * [in] keepdims: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False. + + ---------------- + + Example-1:: + >>> x = jt.randint(10, shape=(2, 3)) + >>> x + jt.Var([[9 4 4] + [1 9 6]], dtype=int32) + >>> jt.mean(x) + jt.Var([5.5000005], dtype=float32) + >>> x.mean() + jt.Var([5.5000005], dtype=float32) + >>> x.mean(dim=1) + jt.Var([5.666667 5.3333335], dtype=float32) + >>> x.mean(dim=1, keepdims=True) + jt.Var([[5.666667 ] + [5.3333335]], dtype=float32) + */ + // @pybind(mean) + "mean", +}; + +EXTERN_LIB int amp_reg; + +ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims) + : x(x) { + // improve float16 mean precision + if (!(amp_reg & 32) && (x->dtype() == ns_float16 || x->dtype() == ns_bfloat16) && (op == ns_mean || op == ns_add)) { + auto x_float32 = make_unary(x, ns_float32); + auto mean = make_reduce(x_float32, op, dims, keepdims); + mean = make_unary(mean, x->dtype()); + forward(mean); + return; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::reduce); + if (op.get(NanoString::_no_need_back_in)) + flags.set(NodeFlags::_manual_set_vnbb); + ns = op; + ASSERT(ns.is_binary()); + auto xdim = x->shape.size(); + keepdims_mask = keepdims ? (int)-1 : (int)0; + if (!dims.size()) { + reduce_mask = (1<=0 && dimdtype() == ns_bool && ns == ns_add) + if (x->dtype() == ns_bool) + y = create_output(nullptr, ns_int32); + else + y = create_output(nullptr, reduce_dtype_infer(ns, x->ns)); +} + +ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask) + : x(x) { + // improve float16 mean precision + if (!(amp_reg & 32) && (x->dtype() == ns_float16 || x->dtype() == ns_bfloat16) && (op == ns_mean || op == ns_add)) { + auto x_float32 = make_unary(x, ns_float32); + auto mean = make_reduce2(x_float32, op, dims_mask, keepdims_mask); + mean = make_unary(mean, x->dtype()); + forward(mean); + return; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::reduce); + if (op.get(NanoString::_no_need_back_in)) + flags.set(NodeFlags::_manual_set_vnbb); + ns = op; + ASSERT(ns.is_binary()); + reduce_mask = dims_mask; + this->keepdims_mask = keepdims_mask; + y = create_output(nullptr, reduce_dtype_infer(ns, x->ns)); +} + +ReduceOp::ReduceOp(Var* x, NanoString op, int dim, bool keepdims) + : ReduceOp(x, op, NanoVector(dim), keepdims) {} + +void ReduceOp::infer_shape() { + auto xdim = x->shape.size(); + NanoVector yshape; + yshape.clear(); + for (int i=0; i>i&1) { + if (keepdims_mask>>i&1) + yshape.push_back(1); + } else + yshape.push_back(x->shape[i]); + } + if (!yshape.size()) { + yshape.push_back(1); + // change last bit to 1, last dim should keep dim + keepdims_mask |= 1; + } + y->set_shape(yshape); + if (yshape.size() == 1 && y->num == 1) + y->flags.set(NodeFlags::_is_scalar); +} + +VarPtr ReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + if (ns == ns_add) { + auto ret = make_broadcast_to(dout, v, reduce_mask, keepdims_mask); + return ret; + } + if (ns == ns_multiply) { + VarPtr a = make_binary(dout, out, ns_multiply); + VarPtr b = make_broadcast_to(a, v, reduce_mask, keepdims_mask); + return make_binary(b, v, ns_divide); + } + if (ns == ns_mean) { + VarPtr a = make_broadcast_to(dout, v, reduce_mask, keepdims_mask); + VarPtr n = make_number(1.0f*out->num / v->num, a); + return make_binary(a, n, ns_multiply); + } + if (ns == ns_maximum || ns == ns_minimum) { + VarPtr zeros = make_number(0, v); + VarPtr a = make_broadcast_to(out, v, reduce_mask, keepdims_mask); + VarPtr cond = make_binary(v, a, ns_equal); + VarPtr dv = make_broadcast_to(dout, v, reduce_mask, keepdims_mask); + return make_ternary(cond, dv, zeros); + } + return nullptr; +} + +void ReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype() + << "«Ty:" << y->dtype() + << "«Tz:" << y->dtype() + << "«OP:" << ns + << "«DIM=" << JK::hex1(x->shape.size()) + << "«REDUCE=" << JK::hex(reduce_mask); +} + +#else // JIT +void ReduceOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + + @for(i, 0, DIM, index_t xshape@i = x->shape[@i];) + @for(i, 0, DIM, index_t yshape@i = @if(REDUCE>>i&1,1,xshape@i);) + index_t ystride@{DIM-1} = 1; + @for(i, DIM-2, -1, -1, auto ystride@i = ystride@{i+1} * yshape@{i+1};) + index_t xstride@{DIM-1} = 1; + @for(i, DIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) + Ty count = x->num*1.0 / y->num; + Ty rcount = y->num*1.0 / x->num; + @for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) { + auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d)); + yp[yid] = @expand_op(init_@OP, @Ty); + } + + @for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) { + @for(d, 0, DIM,@if(REDUCE>>d&1, for (index_t xi@d=0; xi@d < xshape@d; xi@d++),)) { + auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d)); + auto xid = 0 @for(d, 0, DIM, + xi@d * xstride@d); + yp[yid] = @expand_op(@OP, @Ty, yp[yid], @Ty, xp[xid], @Tx); + } + } + (void)count, (void)rcount, (void)yshape0, (void)ystride0; +} +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/reduce_op.h b/python/jittor/src/ops/reduce_op.h new file mode 100644 index 00000000..2259370b --- /dev/null +++ b/python/jittor/src/ops/reduce_op.h @@ -0,0 +1,27 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct ReduceOp : Op { + Var* x, * y; + uint16 reduce_mask; // i-th bit is 1 of dim-i is reduced + uint16 keepdims_mask; + ReduceOp(Var* x, NanoString op, int dim, bool keepdims=false); + ReduceOp(Var* x, NanoString op, NanoVector dims=NanoVector(), bool keepdims=false); + // @pybind(None) + ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask); + + const char* name() const override { return "reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/reindex_op.cc b/python/jittor/src/ops/reindex_op.cc new file mode 100644 index 00000000..5ac1e0f1 --- /dev/null +++ b/python/jittor/src/ops/reindex_op.cc @@ -0,0 +1,142 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/reindex_op.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT +static auto make_reindex_reduce = get_op_info("reindex_reduce") + .get_constructor&&, vector&&, vector&&>(); +static auto make_reindex = get_op_info("reindex") + .get_constructor&&, float64, vector&&, vector&&>(); + +ReindexOp::ReindexOp(Var* x, NanoVector shape, vector&& indexes, float64 overflow_value, vector&& overflow_conditions, vector&& extras) + : x(x), + shape(shape), + indexes(move(indexes)), + overflow_conditions(move(overflow_conditions)), + overflow_value(overflow_value), + extras(extras) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::broadcast); + flags.set(NodeFlags::_manual_set_vnbb); + for (auto& v : extras) v->flags.set(NodeFlags::_needed_by_backward); + y = create_output(nullptr, x->dtype()); +} + +ReindexOp::ReindexOp(Var* x, vector&& indexes, float64 overflow_value, vector&& overflow_conditions) + : x(x), overflow_conditions(move(overflow_conditions)), overflow_value(overflow_value) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::broadcast); + y = create_output(nullptr, x->dtype()); + ASSERTop(indexes.size(),==,x->shape.size()); + auto& shape = indexes[0]->shape; + ASSERT(indexes.size()<=10 && shape.size()<=10); + + string temp; + temp.reserve(6+3*shape.size()); // @e0(i0,i1) + temp += "@e0("; + for (uint i=0; iindexes.reserve(indexes.size()); + for (uint i=0; ishape; + ASSERTop(ns.size(),==,shape.size()); + for (uint j=0; jindexes.emplace_back(temp); + } + // TODO: fix it, we can't move indexes now, + // because we need it to add_inputs outside + // extras = move(indexes); + extras = indexes; + for (uint i = 0; i < indexes.size(); ++i) { + indexes[i]->flags.set(NodeFlags::_force_fuse); + indexes[i]->flags.set(NodeFlags::_needed_by_backward); + } +} + +VarPtr ReindexOp::duplicate() { + return make_reindex(x, shape, clone(indexes), overflow_value, clone(overflow_conditions), clone(extras)); +} + +VarPtr ReindexOp::grad(Var* out, Var* dout, Var* v, int v_index) { + // Do not have grad to extras input + if (v_index) return nullptr; + return make_reindex_reduce(dout, ns_add, x->shape, clone(indexes), clone(overflow_conditions), move(extras)); +} + +void ReindexOp::infer_shape() { + CHECKop(x->shape.size(),==,indexes.size()) << "Number of x's shape and indexes should be the same."; + if (shape.size()) + y->set_shape(shape); + else { + ASSERT(extras.size()); + y->set_shape(extras[0]->shape); + } + CHECK(y->shape.size()) << "Number of shape should greater than 0."; +} + +void ReindexOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype() + << "«XDIM=" << JK::hex1(x->shape.size()) + << "«YDIM=" << JK::hex1(y->shape.size()) + << "«OVERFLOW:" << overflow_value; + for (uint i=0; idtype(); + } +} + +#else // JIT +void ReindexOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + // define extra + @for(i, 0, ESIZE, + auto* __restrict__ extras@i@@p = extras[@i]->ptr(); + @for(j, 0, EDIM@i, index_t extras@i@@shape@j = extras[@i]->shape[@j];) + index_t extras@i@@stride@{EDIM@i-1} = 1; + @for(j, EDIM@i-2, -1, -1, auto extras@i@@stride@j = extras@i@@stride@{j+1} * extras@i@@shape@{j+1};) + ) + auto* __restrict__ yp = y->ptr(); + // define y shape + @for(i, 0, YDIM, index_t yshape@i = y->shape[@i];) + // define y stride + index_t ystride@{YDIM-1} = 1; + @for(i, YDIM-2, -1, -1, auto ystride@i = ystride@{i+1} * yshape@{i+1};) + // define x shape + @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];) + // define x stride + index_t xstride@{XDIM-1} = 1; + @for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) + // generate d-for loop + @for(d, 0, YDIM, for (index_t i@d=0; i@d < yshape@d; i@d++)) { + auto yid = @for(d, 0, YDIM, + i@d * ystride@d); + @for(d, 0, XDIM, index_t xid@d = @expand_macro(INDEX@d);) + auto xid = @for(d, 0, XDIM, + xid@d * xstride@d); + bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d))); + yp[yid] = check_overflow ? Tx(@OVERFLOW) : xp[xid]; + } +} +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/reindex_op.h b/python/jittor/src/ops/reindex_op.h new file mode 100644 index 00000000..9b67ee8b --- /dev/null +++ b/python/jittor/src/ops/reindex_op.h @@ -0,0 +1,106 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct ReindexOp : Op { + Var* x, * y; + NanoVector shape; + vector indexes; + vector overflow_conditions; + float64 overflow_value; + vector extras; + /** + Reindex Operator is a one-to-many map operator. + It performs equivalent Python-pseudo implementation below:: + + # input is x, output is y + n = len(shape)-1 + m = len(x.shape)-1 + k = len(overflow_conditions)-1 + y = np.zeros(shape, x.dtype) + for i0 in range(shape[0]): # 1-st loop + for i1 in range(shape[1]): # 2-nd loop + ...... # many loops + for in in range(shape[n]) # n+1 -th loop + if is_overflow(i0,i1,...,in): + y[i0,i1,...,in] = overflow_value + else: + # indexes[i] is a c++ style integer expression consisting of i0,i1,...,in + y[i0,i1,...,in] = x[indexes[0],indexes[1],...,indexes[m]] + + # is_overflow is defined as following + def is_overflow(i0,i1,...,in): + return ( + indexes[0] < 0 || indexes[0] >= x.shape[0] || + indexes[1] < 0 || indexes[1] >= x.shape[1] || + ...... + indexes[m] < 0 || indexes[m] >= x.shape[m] || + + # overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in + overflow_conditions[0] || + overflow_conditions[1] || + ...... + overflow_conditions[k] + ) + ---------------- + * [in] x: A input jittor Var + + * [in] shape: the output shape, a integer array + + * [in] indexes: array of c++ style integer expression, its length should be the same with the number of dimension of x, some buildin variables it can use are:: + + XDIM, xshape0, ..., xshapen, xstride0, ..., xstriden + YDIM, yshape0, ..., yshapem, ystride0, ..., ystridem + i0, i1, ..., in + @e0(...), @e1(...) for extras input index + e0p, e1p , ... for extras input pointer + + * [in] overflow_value: overflow value + + * [in] overflow_conditions: array of c++ style boolean expression, it length can be vary. the buildin variables it can use are the same with indexes + + * [in] extras: extra var used for index + + ---------------- + Example + Convolution implemented by reindex operation:: + + def conv(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + xx = x.reindex([N,H-Kh+1,W-Kw+1,Kh,Kw,C,Kc], [ + 'i0', # Nid + 'i1+i3', # Hid+Khid + 'i2+i4', # Wid+KWid + 'i5', # Cid + ]) + ww = w.broadcast_var(xx) + yy = xx*ww + y = yy.sum([3,4,5]) # Kh, Kw, C + return y, yy + */ + ReindexOp(Var* x, NanoVector shape, vector&& indexes, float64 overflow_value=0, vector&& overflow_conditions={}, vector&& extras={}); + /** Alias x.reindex([i,j,k]) -> + x.reindex(i.shape, ['@e0(...)','@e1(...)','@e2(...)',], extras=[i,j,k]) + */ + // @pybind(reindex,reindex_var) + ReindexOp(Var* x, vector&& indexes, float64 overflow_value=0, vector&& overflow_conditions={}); + + + const char* name() const override { return "reindex"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + VarPtr duplicate() override; + DECLARE_jit_run; +}; + +} // jittor diff --git a/python/jittor/src/ops/reindex_reduce_op.cc b/python/jittor/src/ops/reindex_reduce_op.cc new file mode 100644 index 00000000..89f4d2c4 --- /dev/null +++ b/python/jittor/src/ops/reindex_reduce_op.cc @@ -0,0 +1,134 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "var.h" +#include "ops/reindex_reduce_op.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT +static auto make_reindex = get_op_info("reindex") + .get_constructor&&, float64, vector&&, vector&&>(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_ternary = get_op_info("ternary") + .get_constructor(); +static auto make_number = get_op_info("number") + .get_constructor(); + + +ReindexReduceOp::ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector&& indexes, vector&& overflow_conditions, vector&& extras) + : y(y), shape(shape), indexes(move(indexes)), overflow_conditions(move(overflow_conditions)), extras(extras) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::reduce); + if (op.get(NanoString::_no_need_back_in)) + flags.set(NodeFlags::_manual_set_vnbb); + ns = op; + ASSERT((ns.is_binary() && ns!=ns_mean) || ns == ns_void); + x = create_output(nullptr, y->dtype()); + for (auto e : extras) { + if (e->shape != y->shape) { + e->flags.set(NodeFlags::_stop_fuse); + } + if (op.get(NanoString::_no_need_back_in)) + e->flags.set(NodeFlags::_needed_by_backward); + } +} + +VarPtr ReindexReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) { + // Do not have grad to extras input + if (v_index) return nullptr; + if (ns == ns_add) + return make_reindex(dout, v->shape, clone(indexes), 0, clone(overflow_conditions), move(extras)); + if (ns == ns_multiply) { + VarPtr a = make_binary(dout, out, ns_multiply); + VarPtr b = make_reindex(a, v->shape, clone(indexes), 0, clone(overflow_conditions), move(extras)); + return make_binary(b, v, ns_divide); + } + if (ns == ns_maximum || ns == ns_minimum) { + VarPtr zeros = make_number(0, v); + VarPtr a = make_reindex(out, v->shape, clone(indexes), 0, clone(overflow_conditions), move(extras)); + VarPtr cond = make_binary(v, a, ns_equal); + VarPtr dv = make_reindex(dout, v->shape, clone(indexes), 0, clone(overflow_conditions), move(extras)); + return make_ternary(cond, dv, zeros); + } + return nullptr; +} + +void ReindexReduceOp::infer_shape() { + CHECKop(shape.size(),==,indexes.size()) << "Number of shape and indexes should be the same."; + CHECK(shape.size()) << "Number of shape should greater than 0."; + for (auto v : shape) + CHECKop(v,>=,0u) << "Shape should greater than 0."; + x->set_shape(shape); + CHECKop(x->size,>=,0u); + CHECKop(y->size,>=,0u); +} + +void ReindexReduceOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype() + << "«OP:" << ns + << "«YDIM=" << JK::hex1(y->shape.size()) + << "«XDIM=" << JK::hex1(x->shape.size()); + for (uint i=0; idtype(); + } +} + +#else // JIT +void ReindexReduceOp::jit_run() { + auto* __restrict__ yp = y->ptr(); + // define extra + @for(i, 0, ESIZE, + auto* __restrict__ extras@i@@p = extras[@i]->ptr(); + @for(j, 0, EDIM@i, index_t extras@i@@shape@j = extras[@i]->shape[@j];) + index_t extras@i@@stride@{EDIM@i-1} = 1; + @for(j, EDIM@i-2, -1, -1, auto extras@i@@stride@j = extras@i@@stride@{j+1} * extras@i@@shape@{j+1};) + ) + auto* __restrict__ xp = x->ptr(); + // define x shape + @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];) + // define x stride + index_t xstride@{XDIM-1} = 1; + @for(i, XDIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};) + // define y shape + @for(i, 0, YDIM, index_t yshape@i = y->shape[@i];) + // define y stride + index_t ystride@{YDIM-1} = 1; + @for(i, YDIM-2, -1, -1, auto ystride@i = ystride@{i+1} * yshape@{i+1};) + // init + + @if(@strcmp(@OP, void)==0,, + @for(d, 0, XDIM, for (index_t i@d=0; i@d < xshape@d; i@d++)) { + auto xid = @for(d, 0, XDIM, + i@d * xstride@d); + xp[xid] = @expand_op(init_@OP, @Tx); + } + ) // end @if + + // generate d-for loop + @for(d, 0, YDIM, for (index_t i@d=0; i@d < yshape@d; i@d++)) { + auto yid = @for(d, 0, YDIM, + i@d * ystride@d); + @for(d, 0, XDIM, index_t xid@d = @expand_macro(INDEX@d);) + auto xid = @for(d, 0, XDIM, + xid@d * xstride@d); + bool check_overflow = 0 @for(d, 0, XDIM, || xid@d<0 || xid@d>=xshape@d) @for(d, 0, OSIZE, || (@expand_macro(OFD@d))); + if (!check_overflow) + xp[xid] = @expand_op(@OP, @Tx, xp[xid], @Tx, yp[yid], @Tx); + } +} +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/reindex_reduce_op.h b/python/jittor/src/ops/reindex_reduce_op.h new file mode 100644 index 00000000..70ec79c0 --- /dev/null +++ b/python/jittor/src/ops/reindex_reduce_op.h @@ -0,0 +1,94 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct ReindexReduceOp : Op { + Var* y, * x; + NanoVector shape; + vector indexes; + vector overflow_conditions; + vector extras; + /** + Reindex Reduce Operator is a many-to-one map operator. + It performs equivalent Python-pseudo implementation below:: + + # input is y, output is x + n = len(y.shape)-1 + m = len(shape)-1 + k = len(overflow_conditions)-1 + x = np.zeros(shape, y.dtype) + x[:] = initial_value(op) + for i0 in range(y.shape[0]): # 1-st loop + for i1 in range(y.shape[1]): # 2-nd loop + ...... # many loops + for in in range(y.shape[n]) # n+1 -th loop + # indexes[i] is a c++ style integer expression consisting of i0,i1,...,in + xi0,xi1,...,xim = indexes[0],indexes[1],...,indexes[m] + if not is_overflow(xi0,xi1,...,xim): + x[xi0,xi1,...,xim] = op(x[xi0,xi1,...,xim], y[i0,i1,...,in]) + + # is_overflow is defined as following + def is_overflow(xi0,xi1,...,xim): + return ( + xi0 < 0 || xi0 >= shape[0] || + xi1 < 0 || xi1 >= shape[1] || + ...... + xim < 0 || xim >= shape[m] || + + # overflow_conditions[i] is a c++ style boolean expression consisting of i0,i1,...,in + overflow_conditions[0] || + overflow_conditions[1] || + ...... + overflow_conditions[k] + ) + + * [in] y: A input jittor Var + + * [in] op: a string represent the reduce operation type + + * [in] shape: the output shape, a integer array + + * [in] indexes: array of c++ style integer expression, its length should be the same with length of output shape, some buildin variables it can use are:: + + XDIM, xshape0, ..., xshapem, xstride0, ..., xstridem + YDIM, yshape0, ..., yshapen, ystride0, ..., ystriden + i0, i1, ..., in + @e0(...), @e1(...) for extras input index + e0p, e1p , ... for extras input pointer + + * [in] overflow_conditions: array of c++ style boolean expression, it length can be vary. the buildin variables it can use are the same with indexes. + + * [in] extras: extra var used for index + + Example + + Pooling implemented by reindex operation:: + + def pool(x, size, op): + N,H,W,C = x.shape + h = (H+size-1)//size + w = (W+size-1)//size + return x.reindex_reduce(op, [N,h,w,C], [ + "i0", # Nid + f"i1/{size}", # Hid + f"i2/{size}", # Wid + "i3", # Cid + ]) + */ + ReindexReduceOp(Var* y, NanoString op, NanoVector shape, vector&& indexes, vector&& overflow_conditions={}, vector&& extras={}); + + const char* name() const override { return "reindex_reduce"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/reshape_op.cc b/python/jittor/src/ops/reshape_op.cc new file mode 100644 index 00000000..dd35f865 --- /dev/null +++ b/python/jittor/src/ops/reshape_op.cc @@ -0,0 +1,60 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/array_op.h" +#include "ops/op_register.h" +#include "ops/reshape_op.h" + +namespace jittor { + +static auto make_reshape = get_op_info("reshape") + .get_constructor(); + +ReshapeOp::ReshapeOp(Var* x, NanoVector shape) : x(x), shape(shape) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + y = create_output(nullptr, x->dtype()); + ASSERT(shape.size() > 0) << "input target shape of reshape can't be empty."; +} + +VarPtr ReshapeOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return make_reshape(dout, x->shape); +} + +void ReshapeOp::infer_shape() { + size_t uncertain_dim = 0; + int64_t y_items = 1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] < 0) { + ++uncertain_dim; + } else + y_items *= shape[i]; + } + CHECK(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << "."; + int64_t x_items = x->num; + auto yshape = shape; + if (uncertain_dim == 0) { + CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size"; + } else { + if (x_items == 0) { + uncertain_dim = 0; + } else { + CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items; + uncertain_dim = x_items / y_items; + } + yshape.clear(); + for (auto a : shape) + yshape.push_back(a<0 ? uncertain_dim : a); + } + y->set_shape(yshape); + y->share_with(x); +} +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/reshape_op.h b/python/jittor/src/ops/reshape_op.h new file mode 100644 index 00000000..eb9181e5 --- /dev/null +++ b/python/jittor/src/ops/reshape_op.h @@ -0,0 +1,50 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct ReshapeOp : Op { + Var* x, * y; + NanoVector shape; + + /** + Returns a tensor with the same data and number of elements as input, but with the specified shape. + + A single dimension may be -1, in which case it's inferred from the remaining dimensions and the number of elements in input. + + ---------------- + + * [in] x: the input jt.Var + + * [in] shape: the output shape, an integer array + + ---------------- + + Example-1:: + >>> a = jt.randint(0, 10, shape=(12,)) + >>> a + jt.Var([4 0 8 4 6 3 1 8 1 1 2 2], dtype=int32) + >>> jt.reshape(a, (3, 4)) + jt.Var([[4 0 8 4] + [6 3 1 8] + [1 1 2 2]], dtype=int32) + >>> jt.reshape(a, (-1, 6)) + jt.Var([[4 0 8 4 6 3] + [1 8 1 1 2 2]], dtype=int32) + */ + ReshapeOp(Var* x, NanoVector shape); + + const char* name() const override { return "reshape"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; +}; +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/safe_clip_op.cc b/python/jittor/src/ops/safe_clip_op.cc new file mode 100644 index 00000000..33086068 --- /dev/null +++ b/python/jittor/src/ops/safe_clip_op.cc @@ -0,0 +1,54 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/safe_clip_op.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT + +SafeClipOp::SafeClipOp(Var* x, float64 left, float64 right) : x(x), left(left), right(right) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + set_type(OpType::element); + y = create_output(nullptr, x->dtype()); +} + +VarPtr SafeClipOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return dout; +} + +void SafeClipOp::infer_shape() { + y->set_shape(x->shape); +} + +void SafeClipOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype() <<"«"; +} + +#else // JIT +void SafeClipOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + Tx left_value = (Tx)std::max((float64) + @if(@strcmp(@Tx,float16)==0,-65500, + @if(@strcmp(@Tx,bfloat16)==0,-1e38, + std::numeric_limits::lowest())), left); + Tx right_value = (Tx)std::min((float64) + @if(@strcmp(@Tx,float16)==0,65500, + @if(@strcmp(@Tx,bfloat16)==0,1e38, + std::numeric_limits::max())), right); + auto* __restrict__ yp = y->ptr(); + index_t num = y->num; + for (index_t i=0; i right_value ? right_value : xp[i]); +} +#endif // JIT + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/safe_clip_op.h b/python/jittor/src/ops/safe_clip_op.h new file mode 100644 index 00000000..4deba0f7 --- /dev/null +++ b/python/jittor/src/ops/safe_clip_op.h @@ -0,0 +1,33 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct SafeClipOp : Op { + Var* x, * y; + float64 left, right; + /** Safe clip value to a range, and keep + the gradient pass thought. + + * [in] x: input value + * [in] left: float64 clip min value. + * [in] right: float64 clip max value. + + */ + // @pybind(safe_clip) + SafeClipOp(Var* x, float64 left=-1e300, float64 right=1e300); + + const char* name() const override { return "safe_clip"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/setitem_op.cc b/python/jittor/src/ops/setitem_op.cc new file mode 100644 index 00000000..b8d8d8e2 --- /dev/null +++ b/python/jittor/src/ops/setitem_op.cc @@ -0,0 +1,372 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/setitem_op.h" +#include "ops/getitem_op.h" +#ifdef JIT +#ifdef JIT_cuda +#include +#include "helper_cuda.h" +#endif +#else +#include "ops/op_register.h" +#ifdef HAS_CUDA +#include "misc/cuda_flags.h" +#endif +#endif + +namespace jittor { + +#ifndef JIT + +static auto make_array = get_op_info("array") + .get_constructor(); +static auto make_getitem = get_op_info("getitem") + .get_constructor(); +static auto make_getitem2 = get_op_info("getitem") + .get_constructor, Var*, VarSlices&&, int>(); +static auto make_setitem = get_op_info("setitem") + .get_constructor(); +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); + +SetitemOp::SetitemOp(Var* x, VarSlices&& slices, Var* y, NanoString op) + : vs(move(slices)), op(op) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_has_gopt); + if (op.get(NanoString::_no_need_back_in)) { + flags.set(NodeFlags::_manual_set_vnbb); + for (int i=0; iflags.set(NodeFlags::_needed_by_backward); + } + ASSERT(op == ns_void || op.is_binary()); + create_output(nullptr, x->dtype()); + if (flags.get(NodeFlags::_custom_flag)) { + flags.set(NodeFlags::_grads); + } +} + +void SetitemOp::infer_shape() { + auto in = inputs().front(); + auto data = input(1); + auto out = outputs().front(); + auto in_shape = in->shape; + auto nin = in_shape.size(); + + StackVector<> i_to_vs(nin); + StackVector<> i_to_o(nin); + // shape return to use + StackVector<> out_shape; + ((GetitemOp*)this)->infer_slices(i_to_vs, i_to_o, out_shape); + if (!out_shape.size()) out_shape.push_back(1); + + // get broadcast mask of set value + auto data_shape = data->shape; + auto data_dim = data_shape.size(); + int bmask = 0; + int bmask2 = 0; + + CHECKop(data_dim,<=,out_shape.size()) << "Data dimension not match"; + for (int i=0; i o_shape; + int fov = -1; + for (int i=0; i=0) { + if (vid==-1 && i && i_to_vs[i-1]<0 + && ((bmask>>oid)&1) == ((bmask>>(oid-1))&1)) + // same broadcast condition with prev dim + { + vid = -2; + o_shape.back() *= os; + } else { + o_shape.push_back(os); + // fix bmask2 offset + bmask2 |= ((bmask>>oid)&1) << (o_shape.size()-1); + } + oid = o_shape.size()-1; + } else { + auto& s = vs.slices[vid]; + if (s.is_var() && fov == -1) { + fov = o_shape.size(); + for (int i=0; i>(first_oid_of_var+i))&1) << (o_shape.size()-1); + } + } + } + } + first_oid_of_var = fov; + this->bmask = bmask2; + + out->set_shape(in_shape); + + this->i_to_vs = i_to_vs.to_nano_vector(); + this->i_to_o = i_to_o.to_nano_vector(); + this->o_shape = o_shape.to_nano_vector(); + + LOGvvvv << "\ni_to_vs:" << i_to_vs + << "\ni_to_o:" << i_to_o + << "\no_shape:" << o_shape; +} + +void SetitemOp::grads(Var** dout, VarPtr* dins) { + if (!dout[0]) return; + auto outs = make_getitem2(dout[0], VarSlices(vs, true), 0); + dins[0] = move(outs[1]); + dins[1] = move(outs[0]); +} + +VarPtr SetitemOp::grad(Var* out, Var* dout, Var* v, int v_index) { + if (v_index >= 2) + return nullptr; + if (op == ns_void) { + if (v_index == 0) { + float32 number = 0; + VarPtr zero = make_array(&number, 1, ns_float32); + return make_setitem(dout, VarSlices(vs, true), zero, ns_void); + } else { + return make_getitem(dout, VarSlices(vs, true)); + } + } + if (op == ns_add) { + if (v_index == 0) { + return dout; + } else { + return make_getitem(dout, VarSlices(vs, true)); + } + } + if (op == ns_subtract) { + if (v_index == 0) { + return dout; + } else { + return make_unary(make_getitem(dout, VarSlices(vs, true)), ns_negative); + } + } + if (op == ns_multiply) { + if (v_index == 0) { + return make_setitem(dout, VarSlices(vs, true), input(1), ns_multiply); + } else { + return make_binary( + make_getitem(inputs().front(), VarSlices(vs, true)), + make_getitem(dout, VarSlices(vs, true)), ns_multiply); + } + } + if (op == ns_divide) { + if (v_index == 0) { + return make_setitem(dout, VarSlices(vs, true), input(1), ns_divide); + } else { + // dy = -dz*x / y^2 + auto dout2 = make_getitem(dout, VarSlices(vs, true)); + auto x = make_getitem(inputs().front(), VarSlices(vs, true)); + auto y = v; + auto ndz = make_unary(dout2, ns_negative); + auto ndzx = make_binary(ndz, x, ns_multiply); + auto y2 = make_binary(y, y, ns_multiply); + return make_binary(ndzx, y2, ns_divide); + } + } + LOGf << "Setitem grad of op" << op << "is not supported yet"; + return nullptr; +} + +void SetitemOp::jit_prepare(JK& jk) { + for (int i=0; idtype() + << "«BMASK=" << JK::hex(bmask); + // TODO: merge code + auto in = inputs().front(); + int idim = i_to_vs.size(); + jk << "«Ti:" << in->dtype(); + jk << "«IDIM=" << JK::hex1(i_to_vs.size()); + jk << "«ODIM=" << JK::hex1(o_shape.size()); + if (first_oid_of_var>=0) { + jk << "«FOV=" << JK::hex1(first_oid_of_var); + jk << "«VD=" << JK::hex1(var_dim); + } + for (int i=0; i=0 && io==-1) { + if (v.is_int()) { + jk << "«VS" << JK::hex1(i) << ":-1"; + } else + if (v.is_str()) { + jk << "«VS" << JK::hex1(i) << ":-5"; + jk << "«VSS" << JK::hex1(i) << ":" << v.get_str(); + } else { + ASSERT(v.is_var()); + auto var = v.var; + auto vshape = var->shape; + auto vdim = vshape.size(); + int vsmask = 0; + for (int j=0; jdtype(); + } + } else + if (iv>=0 && io>=0) { + ASSERT(v.is_slice()); + jk << "«VS" << JK::hex1(i) << ':'; + if (std::abs(v.slice.step) <= 1) + jk << JK::shex1(v.slice.step); + else + jk << '0'; + } + } + #ifdef HAS_CUDA + if (use_cuda) { + int no = o_shape.size(); + STACK_ALLOC(int, masks, no); + int tdims[6]; + cuda_loop_schedule(o_shape, masks, tdims); + for (int i=0; i. +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/array_op.h" +#include "ops/op_register.h" +#include "ops/tape_op.h" + +namespace jittor { + +TapeOp::TapeOp(Var* x) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + create_output(nullptr, x->dtype()); +} + +VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) { + return dout; +} + +void TapeOp::infer_shape() { + auto x = inputs().front(); + auto y = outputs().front(); + y->set_shape(x->shape); + y->share_with(x); +} + +void Tapes::grads(Var** douts, VarPtr* dins) { + CHECK(callback.deleter); + try { + callback.func(_outputs.size(), douts, _inputs.size(), dins); + } catch (...) { + // if error occur in callback, we need to + // free it to prevent memory leak, but this is still + // not enough, error may occur outside. please + // find a better solution + callback.deleter(); + callback.deleter = nullptr; + throw; + } +} + +Tapes::Tapes( + const vector& taped_inputs, + const vector& taped_outputs, + GradCallback&& grad_callback +) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_grads); + flags.set(NodeFlags::_manual_set_vnbb); + callback = move(grad_callback); + + + /* + stop grad stop grad + i --> tape --> t_i ---> .... ---> o --> tape --> t_o + | ^ + +---> tapes ------------------------------+ + */ + // set tape output + for (int i=0; ivar->dtype()); + out->add_inputs({this}); + auto v = taped_outputs[i]->var; + auto op = v->input(); + op->add_inputs(vector{out.ptr}); + } + // set tapes input + vector tin(taped_inputs.size()); + for (int i=0; ivar->input()->inputs().front(); + } + add_inputs(tin); + // stop grad for input and output + for (int i=0; ivar->set_stop_grad(); + } + for (int i=0; ivar->input()->inputs().front()->set_stop_grad(); + } +} + +void tape_together( + const vector& taped_inputs, + const vector& taped_outputs, + GradCallback&& grad_callback +) { + new Tapes(taped_inputs, taped_outputs, move(grad_callback)); +} + +} // jittor diff --git a/python/jittor/src/ops/tape_op.h b/python/jittor/src/ops/tape_op.h new file mode 100644 index 00000000..299e410c --- /dev/null +++ b/python/jittor/src/ops/tape_op.h @@ -0,0 +1,59 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "op.h" +#include "var_holder.h" + +namespace jittor { + +struct Tapes; + +struct GradCallback { + typedef jittor::VarHolder VarHolder; + typedef VarHolder* VarHolderPtr; + typedef jittor::Var Var; + typedef jittor::VarPtr VarPtr; + std::function func; + std::function deleter; + inline ~GradCallback() { if (deleter) deleter(); } + GradCallback(const GradCallback&) = delete; + GradCallback() = default; + GradCallback(GradCallback&& other) : func(other.func), deleter(other.deleter) { + other.func = nullptr; + other.deleter = nullptr; + }; + GradCallback(std::function && func, std::function&& deleter) + : func(move(func)), deleter(move(deleter)) {}; + + void operator =(GradCallback&& other) { this->~GradCallback(); new (this) GradCallback(move(other)); } +}; + +struct TapeOp final : Op { + TapeOp(Var* x); + + const char* name() const override { return "tape"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; +}; + + +struct Tapes final : Op { + GradCallback callback; + Tapes( + const vector& taped_inputs, + const vector& taped_outputs, + GradCallback&& grad_callback + ); + const char* name() const override { return "tapes"; } + void grads(Var** douts, VarPtr* dins) override; +}; + + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/ternary_op.cc b/python/jittor/src/ops/ternary_op.cc new file mode 100644 index 00000000..428281aa --- /dev/null +++ b/python/jittor/src/ops/ternary_op.cc @@ -0,0 +1,98 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/ternary_op.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT +static auto make_ternary = get_op_info("ternary") + .get_constructor(); +static auto make_broadcast = get_op_info("broadcast_to") + .get_constructor(); +static auto make_number = get_op_info("number") + .get_constructor(); + +TernaryOp::TernaryOp(Var* cond, Var* x, Var* y) : cond(cond), x(x), y(y) { + bool bx = cond->shape.size() > x->shape.size() || cond->num > x->num; + bool by = cond->shape.size() > y->shape.size() || cond->num > y->num; + bool bx2 = cond->shape.size() < x->shape.size() || cond->num < x->num; + bool by2 = cond->shape.size() < y->shape.size() || cond->num < y->num; + if (bx || by || bx2 || by2) { + VarPtr xx, yy, cc; + if (bx2) cc = make_broadcast(cond, x, NanoVector()), cond=cc; + if (by2) cc = make_broadcast(cond, y, NanoVector()), cond=cc; + bx = cond->shape.size() > x->shape.size() || cond->num > x->num; + by = cond->shape.size() > y->shape.size() || cond->num > y->num; + if (bx) xx = make_broadcast(x, cond, NanoVector()), x = xx; + if (by) yy = make_broadcast(y, cond, NanoVector()), y = yy; + forward(make_ternary(cond, x, y)); + return; + } + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::element); + flags.set(NodeFlags::_manual_set_vnbb); + cond->flags.set(NodeFlags::_needed_by_backward); + if (x->dtype() == y->dtype()) { + z = create_output(nullptr, x->dtype()); + } else { + z = create_output(nullptr, dtype_infer(x->ns, y->ns, x->flags.get(NodeFlags::_is_scalar), y->flags.get(NodeFlags::_is_scalar))); + } +} + +VarPtr TernaryOp::grad(Var* out, Var* dout, Var* v, int v_index) { + if (v_index==0) return nullptr; + auto zeros = make_number(0, dout); + if (v_index==1) + return make_ternary(cond, dout, zeros); + else + return make_ternary(cond, zeros, dout); +} + +void TernaryOp::infer_shape() { + auto xdim = x->shape.size(); + auto ydim = y->shape.size(); + auto cdim = cond->shape.size(); + CHECK(xdim==ydim && cdim==ydim) << "Number of dims should be the same."; + NanoVector zshape; + for (size_t i=0; ishape[i]; + auto yshape = y->shape[i]; + auto cshape = cond->shape[i]; + auto shape = std::min(xshape, std::min(yshape, cshape)); + auto shape2 = std::max(xshape, std::max(yshape, cshape)); + zshape.push_back(shape2); + CHECK(shape==shape2) << "Shape not match" << x->shape << y->shape << cond->shape; + } + z->set_shape(zshape); +} + +void TernaryOp::jit_prepare(JK& jk) { + jk << "«Tc:" << cond->dtype(); + jk << "«Tx:" << x->dtype(); + jk << "«Ty:" << y->dtype(); + jk << "«Tz:" << z->dtype(); +} + +#else // JIT +void TernaryOp::jit_run() { + auto* __restrict__ condp = cond->ptr(); + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + auto* __restrict__ zp = z->ptr(); + index_t num = z->num; + for (index_t i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct TernaryOp : Op { + Var* cond, * x, * y, * z; + TernaryOp(Var* cond, Var* x, Var* y); + + const char* name() const override { return "ternary"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/transpose_op.cc b/python/jittor/src/ops/transpose_op.cc new file mode 100644 index 00000000..8a1fcb22 --- /dev/null +++ b/python/jittor/src/ops/transpose_op.cc @@ -0,0 +1,116 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "ops/transpose_op.h" +#include "var.h" +#include "ops/op_register.h" +#include "misc/cuda_flags.h" + +namespace jittor { + +#ifndef JIT +static auto make_transpose = get_op_info("transpose") + .get_constructor(); + +TransposeOp::TransposeOp(Var* x, NanoVector axes_) : x(x), axes(axes_) { + int i=0; + for (; ishape.size(); + if (!axes.size()) { + for (int i=0; i<(int)xdim; i++) + axes.push_back(xdim-1-i); + } + if (axes.size() < xdim || (axes.size() == xdim && axes[xdim-1]==xdim-1)) { + static VarPtr(*fuse_transpose)(Var*, NanoVector) = get_op_info("fuse_transpose").get_constructor(); + auto var = fuse_transpose(x, axes); + forward(var); + return; + } + #ifdef HAS_CUDA + if (use_cuda) { + static VarPtr(*cutt_transpose)(Var*, NanoVector) = nullptr; + if (!cutt_transpose && has_op("cutt_transpose")) { + cutt_transpose = get_op_info("cutt_transpose") + .get_constructor(); + } + if (cutt_transpose) { + auto var = cutt_transpose(x, axes); + forward(var); + return; + } + } + #endif + y = create_output(nullptr, x->dtype()); + flags.set(NodeFlags::_manual_set_vnbb); +} + +void TransposeOp::infer_shape() { + auto xdim = x->shape.size(); + CHECK(xdim); + if (!axes.size()) { + for (int i=0; i<(int)xdim; i++) + axes.push_back(xdim-1-i); + } else { + CHECKop(axes.size(),==,xdim); + int64_t mask=0; + for (auto i : axes) mask |= 1<shape[axes[i]]); + y->set_shape(shape); +} + +VarPtr TransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) { + NanoVector reverse; + reverse.reserve(axes.size(), axes.size()); + for (uint i=0; idtype(); + jk << "«DIM=" << JK::hex1(axes.size()); + for (uint i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "misc/cpu_math.h" +#include "var.h" +#include "ops/unary_op.h" +#include "ops/op_register.h" + +namespace jittor { + +#ifndef JIT +static auto make_binary = get_op_info("binary") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); +static auto make_ternary = get_op_info("ternary") + .get_constructor(); +static auto make_number = get_op_info("number") + .get_constructor(); + +static unordered_set unary_ops = { + /** + Returns a copy of the input var, casted to boolean. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.arange(3) + >>> x + jt.Var([0 1 2], dtype=int32) + >>> x.bool() + jt.Var([False True True], dtype=bool) + >>> jt.bool(x) + jt.Var([False True True], dtype=bool) + */ + "bool", + + /** + Returns a copy of the input var, casted to int8. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int8() + jt.Var([4 2 8], dtype=int8) + >>> jt.int8(x) + jt.Var([4 2 8], dtype=int8) + */ + "int8", + + /** + Returns a copy of the input var, casted to int16. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int16() + jt.Var([4 2 8], dtype=int16) + >>> jt.int16(x) + jt.Var([4 2 8], dtype=int16) + */ + "int16", + + /** + Returns a copy of the input var, casted to int32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int() + jt.Var([4 2 8], dtype=int32) + >>> jt.int(x) + jt.Var([4 2 8], dtype=int32) + >>> x.int32() + jt.Var([4 2 8], dtype=int32) + >>> jt.int32(x) + jt.Var([4 2 8], dtype=int32) + >>> x.long() + jt.Var([4 2 8], dtype=int32) + >>> jt.long(x) + jt.Var([4 2 8], dtype=int32) + */ + "int32", + + /** + Returns a copy of the input var, casted to int64. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.int64() + jt.Var([4 2 8], dtype=int64) + >>> jt.int64(x) + jt.Var([4 2 8], dtype=int64) + */ + "int64", + + /** + Returns a copy of the input var, casted to unsigned int8. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint8() + jt.Var([4 2 8], dtype=uint8) + >>> jt.uint8(x) + jt.Var([4 2 8], dtype=uint8) + */ + "uint8", + + /** + Returns a copy of the input var, casted to unsigned int16. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint16() + jt.Var([4 2 8], dtype=uint16) + >>> jt.uint16(x) + jt.Var([4 2 8], dtype=uint16) + */ + "uint16", + + /** + Returns a copy of the input var, casted to unsigned int32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint32() + jt.Var([4 2 8], dtype=uint32) + >>> jt.uint32(x) + jt.Var([4 2 8], dtype=uint32) + */ + "uint32", + + /** + Returns a copy of the input var, casted to unsigned int64. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.uint64() + jt.Var([4 2 8], dtype=uint64) + >>> jt.uint64(x) + jt.Var([4 2 8], dtype=uint64) + */ + "uint64", + + /** + Returns a copy of the input var, casted to float16 (half-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.half() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.half(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> x.float16() + jt.Var([4.094 2.008 8.48 ], dtype=float16) + >>> jt.float16(x) + jt.Var([4.094 2.008 8.48 ], dtype=float16) + */ + "float16", + + /** + Returns a copy of the input var, casted to bfloat16 (brain half-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.rand(3) * 10 + >>> x + jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32) + >>> x.bfloat16() + jt.Var([4.094 2.008 8.48 ], dtype=bfloat16) + >>> jt.bfloat16(x) + jt.Var([4.094 2.008 8.48 ], dtype=bfloat16) + */ + "bfloat16", + + /** + Returns a copy of the input var, casted to float32. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.arange(3) + >>> x + jt.Var([0 1 2], dtype=int32) + >>> x.float() + jt.Var([0. 1. 2.], dtype=float32) + >>> jt.float(x) + jt.Var([0. 1. 2.], dtype=float32) + >>> x.float32() + jt.Var([0. 1. 2.], dtype=float32) + >>> jt.float32(x) + jt.Var([0. 1. 2.], dtype=float32) + */ + "float32", + + /** + Returns a copy of the input var, casted to float64 (double-precision float). + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> x = jt.arange(3) + >>> x + jt.Var([0 1 2], dtype=int32) + >>> x.double() + jt.Var([0. 1. 2.], dtype=float64) + >>> jt.double(x) + jt.Var([0. 1. 2.], dtype=float64) + >>> x.float64() + jt.Var([0. 1. 2.], dtype=float64) + >>> jt.float64(x) + jt.Var([0. 1. 2.], dtype=float64) + */ + "float64", + // please keep float64 the last type + + /** + Returns the absolute value of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var + + ---------------- + + Example-1:: + >>> jt.abs(jt.float32([-1, 0, 1])) + jt.Var([1. 0. 1.], dtype=float32) + */ + // @pybind(abs, __abs__) + "abs", + + /** + Returns the negative value of the input ``x``. + + This operator is equavilant to ``-x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> jt.negative(jt.float32([-1, 0, 1])) + jt.Var([ 1. -0. -1.], dtype=float32) + */ + // @pybind(negative, __neg__) + "negative", + + /** + Returns the logical NOT of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var, integal or boolean. + + ---------------- + + Example-1:: + >>> jt.logical_not(jt.int32([-1, 0, 1])) + jt.Var([False True False], dtype=bool) + */ + "logical_not", + + /** + Returns the bitwise NOT of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var, integal or boolean. + + ---------------- + + Example-1:: + >>> jt.bitwise_not(jt.int32([1, 2, -3])) + jt.Var([-2 -3 2], dtype=int32) + */ + "bitwise_not", + + /** + Returns the natural logarithm of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([0.02863695 1.30122 1.6048753 1.140261 ], dtype=float32) + >>> jt.log(x) + jt.Var([-3.5530574 0.26330233 0.47304606 0.13125724], dtype=float32) + >>> x.log() + jt.Var([-3.5530574 0.26330233 0.47304606 0.13125724], dtype=float32) + */ + "log", + + /** + Returns the exponential of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([1.9841381 1.4103996 0.5855549 1.4212812], dtype=float32) + >>> jt.exp(x) + jt.Var([7.2727766 4.0975924 1.7959872 4.1424246], dtype=float32) + >>> x.exp() + jt.Var([7.2727766 4.0975924 1.7959872 4.1424246], dtype=float32) + */ + "exp", + + /** + Returns the square root of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 + >>> x + jt.Var([0.81957287 0.5609612 0.07435933 1.7571875 ], dtype=float32) + >>> jt.sqrt(x) + jt.Var([0.90530264 0.7489734 0.27268907 1.3255895 ], dtype=float32) + >>> x.sqrt() + jt.Var([0.90530264 0.7489734 0.27268907 1.3255895 ], dtype=float32) + */ + "sqrt", + + /** + Returns the closest integer of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32) + >>> jt.round(x) + jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32) + >>> x.round() + jt.Var([ 2.0 0.0 0.0 -1.0], dtype=float32) + */ + "round", + + /** + Returns the largest integer less than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.floor(x) + jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32) + >>> x.floor + jt.Var([-2.0 -1.0 -1.0 -1.0], dtype=float32) + */ + "floor", + + /** + Returns the smallest integer greater than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.ceil(x) + jt.Var([-1.0 0.0 0.0 0.0], dtype=float32) + >>> x.ceil() + jt.Var([-1.0 0.0 0.0 0.0], dtype=float32) + */ + "ceil", + + /** + Returns the closest integer of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 2.101595 0.33055413 -0.44147047 -0.7720668 ], dtype=float32) + >>> jt.round_int(x) + jt.Var([ 2 0 0 -1], dtype=int32) + >>> x.round_int + jt.Var([ 2 0 0 -1], dtype=int32) + */ + "round_int", + + /** + Returns the largest integer less than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.floor_int(x) + jt.Var([-2 -1 -1 -1], dtype=int32) + >>> x.floor_int + jt.Var([-2 -1 -1 -1], dtype=int32) + */ + "floor_int", + + /** + Returns the smallest integer greater than or equal to the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.0339162 -0.7259972 -0.9220003 -0.8449701], dtype=float32) + >>> jt.ceil_int(x) + jt.Var([-1 0 0 0], dtype=int32) + >>> x.ceil_int() + jt.Var([-1 0 0 0], dtype=int32) + */ + "ceil_int", + + /** + Returns the sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.sin(x) + jt.Var([ 0.32303742 -0.6527857 -0.76586854 0.9738172 ], dtype=float32) + >>> x.sin() + jt.Var([ 0.32303742 -0.6527857 -0.76586854 0.9738172 ], dtype=float32) + */ + "sin", + + /** + Returns the arcsine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.09342023 -0.42522037 0.9264933 -0.785264 ], dtype=float32) + >>> jt.asin(x) + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32) + >>> x.asin() + jt.Var([ 0.09355665 -0.43920535 1.1849847 -0.9031224 ], dtype=float32) + */ + // @pybind(asin, arcsin) + "asin", + + /** + Returns the hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.sinh(x) + jt.Var([ 0.3349012 -0.77276015 -0.9873369 2.9425898 ], dtype=float32) + >>> x.sinh + jt.Var([ 0.3349012 -0.77276015 -0.9873369 2.9425898 ], dtype=float32) + */ + "sinh", + + /** + Returns the inverse hyperbolic sine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-1.9749726 -0.52341473 0.8906148 1.0338128 ], dtype=float32) + >>> jt.asinh(x) + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32) + >>> x.asinh() + jt.Var([-1.4323865 -0.5020559 0.8018747 0.90508187], dtype=float32) + */ + // @pybind(asinh, arcsinh) + "asinh", + + /** + Returns the tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.tan(x) + jt.Var([ 0.34133783 -0.8617148 -1.1910915 -4.283673 ], dtype=float32) + >>> x.tan() + jt.Var([ 0.34133783 -0.8617148 -1.1910915 -4.283673 ], dtype=float32) + */ + "tan", + + /** + Returns the inverse tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.atan(x) + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32) + >>> x.atan() + jt.Var([-0.70961297 0.87102956 0.44140393 0.76464504], dtype=float32) + */ + // @pybind(atan, arctan) + "atan", + + /** + Returns the hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([-0.85885596 1.187804 0.47249675 0.95933187], dtype=float32) + >>> jt.tanh(x) + jt.Var([-0.6956678 0.82989657 0.4402144 0.7439787 ], dtype=float32) + >>> x.tanh() + jt.Var([-0.6956678 0.82989657 0.4402144 0.7439787 ], dtype=float32) + */ + "tanh", + + /** + Returns the inverse hyperbolic tangent of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.9062414 -0.799802 -0.27219176 -0.7274077 ], dtype=float32) + >>> jt.atanh(x) + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32) + >>> x.atanh() + jt.Var([ 1.5060828 -1.0980625 -0.27922946 -0.9231999 ], dtype=float32) + */ + // @pybind(atanh, arctanh) + "atanh", + + /** + Returns the cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.cos(x) + jt.Var([ 0.9463862 0.7575426 0.6429972 -0.2273323], dtype=float32) + >>> x.cos() + jt.Var([ 0.9463862 0.7575426 0.6429972 -0.2273323], dtype=float32) + */ + "cos", + + /** + Returns the inverse cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.5876564 0.740723 -0.667666 0.5371753], dtype=float32) + >>> jt.acos(x) + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32) + >>> x.acos() + jt.Var([0.9426371 0.7366504 2.3018656 1.0037117], dtype=float32) + */ + // @pybind(acos, arccos) + "acos", + + /** + Returns the hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.32893723 -0.7112559 -0.872391 1.8001337 ], dtype=float32) + >>> jt.cosh(x) + jt.Var([1.0545894 1.2637873 1.405288 3.1078668], dtype=float32) + >>> x.cosh() + jt.Var([1.0545894 1.2637873 1.405288 3.1078668], dtype=float32) + */ + "cosh", + + /** + Returns the inverse hyperbolic cosine of the input ``x``. + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) + 1 + >>> x + jt.Var([1.3609099 1.8137748 1.1146184 1.3911307], dtype=float32) + >>> jt.acosh(x) + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32) + >>> x.acosh() + jt.Var([0.8259237 1.2020639 0.47432774 0.8579033 ], dtype=float32) + */ + // @pybind(acosh, arccosh) + "acosh", + + /** + Returns the sigmoid of the input ``x``. + + .. math:: + out_i = \frac{1}{1 + e^{x_i}} + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.49443012 0.4305426 -1.0364404 -1.2628382 ], dtype=float32) + >>> jt.sigmoid(x) + jt.Var([0.62114954 0.6060032 0.2618374 0.2204857 ], dtype=float32) + >>> x.sigmoid() + jt.Var([0.62114954 0.6060032 0.2618374 0.2204857 ], dtype=float32) + */ + "sigmoid", + + /** + Computes the error function of each element. The error function is defined as follows: + + .. math:: + erf(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} dt + + ---------------- + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.randn(4) + >>> x + jt.Var([ 0.49443012 0.4305426 -1.0364404 -1.2628382 ], dtype=float32) + >>> jt.erf(x) + jt.Var([ 0.51559156 0.45739546 -0.85728306 -0.9258883 ], dtype=float32) + >>> x.erf() + jt.Var([ 0.51559156 0.45739546 -0.85728306 -0.9258883 ], dtype=float32) + */ + "erf", + + /** + Computes the inverse error function of each element. + + * [in] x: the input jt.Var. + + ---------------- + + Example-1:: + >>> x = jt.rand(4) * 2 - 1 + >>> x + jt.Var([ 0.00277209 -0.26642472 0.7869792 0.5415418 ], dtype=float32) + >>> jt.erfinv(x) + jt.Var([ 0.00245671 -0.24068035 0.8805613 0.5242405 ], dtype=float32) + >>> x.erfinv() + jt.Var([ 0.00245671 -0.24068035 0.8805613 0.5242405 ], dtype=float32) + */ + "erfinv", +}; + +UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + set_type(OpType::element); + ns = op; + ASSERT(ns.is_unary() | ns.is_dtype()); + NanoString dtype; + if (ns == x->dtype()) { + forward(x); + return; + } + if (ns.is_dtype()) { + dtype = ns; + ns = ns_cast; + } else + dtype = unary_dtype_infer(ns, x->ns); + y = create_output(nullptr, dtype); + y->flags.set(NodeFlags::_is_scalar, x->flags.get(NodeFlags::_is_scalar)); + bool bin = ns.get(NanoString::_no_need_back_in); + bool bout = ns.get(NanoString::_no_need_back_out); + if (bin || bout) { + flags.set(NodeFlags::_manual_set_vnbb); + if (!bin) { + x->flags.set(NodeFlags::_needed_by_backward); + } + if (!bout) { + y->flags.set(NodeFlags::_needed_by_backward); + } + } +} + +VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) { + if (!x->is_float()) return nullptr; + if (ns == ns_cast) return make_unary(dout, x->dtype()); + if (ns == ns_negative) return make_unary(dout, ns); + if (ns == ns_abs) { + auto neg = make_unary(dout, ns_negative); + auto zeros = make_number(0, x); + auto cond = make_binary(x, zeros, ns_greater_equal); + return make_ternary(cond, dout, neg); + } + if (ns == ns_log) + return make_binary(dout, x, ns_divide); + if (ns == ns_exp) + return make_binary(dout, y, ns_multiply); + if (ns == ns_sqrt){ + auto two = make_number(2, x); + auto twoy = make_binary(two, y, ns_multiply); + return make_binary(dout, twoy, ns_divide); + } + // dsin(x) = cos(x) + if (ns == ns_sin) + return make_binary(dout, make_unary(x, ns_cos), ns_multiply); + // dcos(x) = -sin(x) + if (ns == ns_cos) + return make_binary(dout, make_unary(make_unary(x, ns_sin), ns_negative), ns_multiply); + // dtan(x) = 1/cos^2(x) + if (ns == ns_tan) { + auto one = make_number(1, x); + auto cosx = make_unary(x, ns_cos); + auto cos2x = make_binary(cosx, cosx, ns_multiply); + return make_binary(dout, cos2x, ns_divide); + } + // dasin(x) = 1/sqrt(1-x^2) + if (ns == ns_asin) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_subtract); + x2 = make_unary(x2, ns_sqrt); + return make_binary(dout, x2, ns_divide); + } + // dacos(x) = -1/sqrt(1-x^2) + if (ns == ns_acos) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_subtract); + x2 = make_unary(x2, ns_sqrt); + return make_unary(make_binary(dout, x2, ns_divide), ns_negative); + } + // datan(x) = 1/(x^2+1) + if (ns == ns_atan) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_add); + return make_binary(dout, x2, ns_divide); + } + + // dsinh(x) = cosh(x) + if (ns == ns_sinh) + return make_binary(dout, make_unary(x, ns_cosh), ns_multiply); + // dcosh(x) = sinh(x) + if (ns == ns_cosh) + return make_binary(dout, make_unary(x, ns_sinh), ns_multiply); + // dtanh(x) = 1/cosh^2(x) + if (ns == ns_tanh) { + auto cosx = make_unary(x, ns_cosh); + auto cos2x = make_binary(cosx, cosx, ns_multiply); + return make_binary(dout, cos2x, ns_divide); + } + + // dasinh(x) = 1/sqrt(x^2+1) + if (ns == ns_asinh) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(x2, one, ns_add); + x2 = make_unary(x2, ns_sqrt); + return make_binary(dout, x2, ns_divide); + } + // dacosh(x) = 1/sqrt(x^2-1) + if (ns == ns_acosh) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(x2, one, ns_subtract); + x2 = make_unary(x2, ns_sqrt); + return make_binary(dout, x2, ns_divide); + } + // datanh(x) = 1/(1-x^2) + if (ns == ns_atanh) { + auto one = make_number(1, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_binary(one, x2, ns_subtract); + return make_binary(dout, x2, ns_divide); + } + // dsigmoid(x) = sigmoid(x) - sigmoid(x)^2 + if (ns == ns_sigmoid) { + auto r = make_binary(out, out, ns_multiply); + r = make_binary(out, r, ns_subtract); + return make_binary(dout, r, ns_multiply); + } + // derf(x) = e^(-x^2)*2/sqrt(pi) + if (ns == ns_erf) { + auto two_div_sqrt_pi = make_number(2/1.7724538509055159, x); + auto two = make_number(2, x); + auto x2 = make_binary(x, x, ns_multiply); + x2 = make_unary(x2, ns_negative); + auto r = make_unary(x2, ns_exp); + r = make_binary(r, two_div_sqrt_pi, ns_multiply); + return make_binary(dout, r, ns_multiply); + } + // derfinv(x) = sqrt(pi) / 2 * exp(erfinv(x)^2) + if (ns == ns_erfinv) { + auto sqrt_pi_div_two = make_number(1.7724538509055159/2, x); + auto y2 = make_binary(y, y, ns_multiply); + auto r = make_unary(y2, ns_exp); + r = make_binary(r, sqrt_pi_div_two, ns_multiply); + return make_binary(dout, r, ns_multiply); + } + return nullptr; +} + +void UnaryOp::infer_shape() { + y->set_shape(x->shape); +} + +void UnaryOp::jit_prepare(JK& jk) { + jk << "«Tx:" << x->dtype() + << "«Ty:" << y->dtype() + << "«OP:" << ns; +} + +#else // JIT +void UnaryOp::jit_run() { + auto* __restrict__ xp = x->ptr(); + auto* __restrict__ yp = y->ptr(); + index_t num = y->num; + for (index_t i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct UnaryOp : Op { + Var* x, * y; + // @pybind(unary,cast) + UnaryOp(Var* x, NanoString op); + + const char* name() const override { return "unary"; } + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/ops/where_op.cc b/python/jittor/src/ops/where_op.cc new file mode 100644 index 00000000..fceeeef0 --- /dev/null +++ b/python/jittor/src/ops/where_op.cc @@ -0,0 +1,269 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "ops/where_op.h" +#include "misc/cuda_flags.h" +#include "ops/op_register.h" +#ifdef JIT_cuda +#include "executor.h" +#include +#include +#include "helper_cuda.h" +#endif + +namespace jittor { + +#ifndef JIT +WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) { + flags.set(NodeFlags::_cpu); + flags.set(NodeFlags::_cuda); + flags.set(NodeFlags::_manual_set_vnbb); + auto ndim = cond->shape.size(); + #ifdef HAS_CUDA + if (use_cuda) { + static auto cub_where = has_op("cub_where") ? get_op_info("cub_where") + .get_constructor, Var*, NanoString>() : nullptr; + if (cub_where && (ndim>1 || std::abs(cond->num)>4096)) { + auto var = cub_where(cond, dtype); + for(uint i=0;i(); +WhereOp::WhereOp(Var* cond, Var* x, Var* y) { + forward(make_ternary(cond, x, y)); + return; +} + +void WhereOp::infer_shape() { + auto ndim = cond->shape.size(); + auto num = -cond->num; + for (uint i=0; iset_shape({num}); +} + +void WhereOp::jit_prepare(JK& jk) { + jk << "«Ti:" << cond->dtype(); + jk << "«To:" << outs[0]->dtype(); + jk << "«NDIM=" << JK::hex1(cond->shape.size()); +} + +#else // JIT +#ifdef JIT_cuda + +__global__ static void where_kernel( + @for(i, 0, NDIM, 1, index_t condshape@i, ) + Ti* __restrict__ condp, + @for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, ) + int* __restrict__ np +) { + __shared__ uint n; + int tid = threadIdx.x; + int tnum = blockDim.x; + if (tid == 0) + n = 0; + // define cond stride + index_t condstride@{NDIM-1} = 1; + @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};) + __syncthreads(); + + // generate d-for loop + @for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++)) + for (index_t i@{NDIM-1}=tid; i@{NDIM-1} 0; offset /= 2) { + uint x = __shfl_up_sync(FULL_MASK, val, offset); + val += lane_id>=offset? x : 0; + } + return val; +} + +__device__ inline uint bc(uint val, uint lane_id) { + return __shfl_sync(FULL_MASK, val, lane_id); +} + +__global__ static void where_kernel_one_warp( + @for(i, 0, NDIM, 1, index_t condshape@i, ) + Ti* __restrict__ condp, + @for(i, 0, NDIM, 1, To* __restrict__ outs@i@@p, ) + int* __restrict__ np +) { + uint n = 0; + int tid = threadIdx.x; + int tnum = 32; + // define cond stride + index_t condstride@{NDIM-1} = 1; + @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};) + + // generate d-for loop + @for(d, 0, NDIM-1, for (index_t i@d=0; i@d < condshape@d; i@d++)) + for (index_t i=0; iptr(); + // define cond shape + @for(i, 0, NDIM, index_t condshape@i = cond->shape[@i];) + + // define outs + @for(i, 0, NDIM, auto* __restrict__ outs@i@@p = outs[@i]->ptr();) + + size_t n_allocation; + int* np = (int*)exe.temp_allocator->alloc(4, n_allocation); + + // one block kernel, result maybe unstable + // int tnum = condshape@{NDIM-1}; + // tnum = std::max(1, std::min(1024, tnum)); + // where_kernel<<<1,tnum>>>( + // @for(i, 0, NDIM, 1, condshape@i, ) + // condp, + // @for(i, 0, NDIM, 1, outs@i@@p, ) + // np + // ); + + + int tnum = condshape@{NDIM-1}; + if (tnum < 100) { + // one warp kernel, result is stable + where_kernel_one_warp<<<1,32>>>( + @for(i, 0, NDIM, 1, condshape@i, ) + condp, + @for(i, 0, NDIM, 1, outs@i@@p, ) + np + ); + } else { + // one block kernel, result is stable + where_kernel_one_block<<<1,WTN>>>( + @for(i, 0, NDIM, 1, condshape@i, ) + condp, + @for(i, 0, NDIM, 1, outs@i@@p, ) + np + ); + } + + int n=0; + // checkCudaErrors(cudaDeviceSynchronize()); + checkCudaErrors(cudaMemcpy(&n, np, 4, cudaMemcpyDeviceToHost)); + @for(i, 0, NDIM, outs[@i]->set_shape({n});) + exe.temp_allocator->free(np, 4, n_allocation); +} +#else + +void WhereOp::jit_run() { + auto* __restrict__ condp = cond->ptr(); + // define cond shape + @for(i, 0, NDIM, index_t condshape@i = cond->shape[@i];) + // define cond stride + index_t condstride@{NDIM-1} = 1; + @for(i, NDIM-2, -1, -1, auto condstride@i = condstride@{i+1} * condshape@{i+1};) + + // define outs + @for(i, 0, NDIM, auto* __restrict__ outs@i@@p = outs[@i]->ptr();) + int64 n=0; + + // generate d-for loop + @for(d, 0, NDIM, for (index_t i@d=0; i@d < condshape@d; i@d++)) { + auto condid = @for(d, 0, NDIM, + i@d * condstride@d); + if (condp[condid]) { + @for(i, 0, NDIM, outs@i@@p[n] = i@i;) + n++; + } + } + @for(i, 0, NDIM, outs[@i]->set_shape({n});) +} + +#endif +#endif // JIT + +} // jittor diff --git a/python/jittor/src/ops/where_op.h b/python/jittor/src/ops/where_op.h new file mode 100644 index 00000000..10b250e7 --- /dev/null +++ b/python/jittor/src/ops/where_op.h @@ -0,0 +1,42 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + + +namespace jittor { + +struct WhereOp : Op { + Var* cond; + unique_ptr outs; + /** + Where Operator generate index of true condition. + + * [in] cond: condition for index generation + + * [in] dtype: type of return indexes + + * [out] out: return an array of indexes, same length with number of dims of cond + + Example:: + + jt.where([[0,0,1],[1,0,0]]) + # return [jt.Var([0 1], dtype=int32), jt.Var([2 0], dtype=int32)] + */ + // @attrs(multiple_outputs) + WhereOp(Var* cond, NanoString dtype=ns_int32); + /** + * Condition operator, perform cond ? x : y + * */ + WhereOp(Var* cond, Var* x, Var* y); + void infer_shape() override; + + const char* name() const override { return "where"; } + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/expr.cc b/python/jittor/src/opt/expr.cc new file mode 100644 index 00000000..5f5f953f --- /dev/null +++ b/python/jittor/src/opt/expr.cc @@ -0,0 +1,1180 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "opt/expr.h" +#include "utils/str_utils.h" + +namespace jittor { +namespace expr { +// operator precedence and associativity +// equivalence: https://en.cppreference.com/w/cpp/language/operator_precedence +// different from c++, precedence 17(,;) is right associativity + +static const unordered_set is_left_associativity({ + 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, /* 17 */ +}); + +static const unordered_map precedence({ + {"::", 1}, + + {"++", 2}, {"--", 2}, + {"(", 2}, {")", 2}, {"()", 2}, + {"[", 2}, {"]", 2}, {"[]", 2}, + {"{", 2}, {"}", 2}, {"{}", 2}, + {".", 2}, {"->", 2}, {"<>", 2}, + + {"!", 3}, {"~", 3}, + {"*", 5}, {"/", 5}, {"%", 5}, + {"+", 6}, {"-", 6}, + {"<<", 7}, {">>", 7}, + {"<=", 9}, {"<", 9}, {">=", 9}, {">", 9}, + {"!=", 10}, {"==", 10}, + {"&", 11}, + {"^", 12}, + {"|", 13}, + {"&&", 14}, + {"||", 15}, + + // a @> b = !a || b + // a @< b = a || !b + {"@", 16}, {"@>", 16}, {"@<", 16}, + + {"?", 26}, {":", 26}, {"?:", 26}, + {"=", 26}, {"+=", 26}, {"-=", 26}, + {"*=", 26}, {"/=", 26}, {"*=", 26}, + {"<<=", 26}, {">>=", 26}, {"&=", 26}, {"^=", 26}, {"|=", 26}, + + // precedence 27 used for little higher than "," + {",", 28}, {";", 28} +}); + +static const unordered_set is_unary_op({ + "++", "--", "!", "~", "+", "-", "&", "*", "::" +}); + +static const unordered_set is_left_unary_op({ + "!", "~", "+", "-", "&", "*", "::" +}); + + +static const unordered_set is_associative_op({ + "+", "*", "&", "|", "&&", "||" +}); +static const unordered_set is_commutative_op({ + "+", "*", "&", "|", "&&", "||" +}); + +static bool isvar(char x) { return isalnum(x) || x == '_'; } +static bool isempty(char x) { return x==' ' || x=='\t' || x=='\n';} + +static inline int64 ex_stoll(const string& str) { + if (startswith(str, "0x") || startswith(str, "0X")) + return std::stoll(str,0,16); + else if (startswith(str, "0b") || startswith(str, "0B")) + return std::stoll(str.substr(2),0,2); + return std::stoll(str,0,10); +} + +Expr::Expr(size_t flags, const string& str, vector>&& children) + : flags(flags), str(str), father(0), fid(0), children(move(children)) { + for (uint i=0; ichildren.size(); i++) { + this->children[i]->father = this; + this->children[i]->fid = i; + } + if (is(_float)) set_data(std::stof(str)); else + if (is(_int)) set_data(ex_stoll(str)); + maintain(); +} + +unique_ptr make(const string& str, vector>&& children) { + size_t flags = 0; + if (is_associative_op.count(str)) + flags |= (_binary_op | _asso_op); + if (is_commutative_op.count(str)) + flags |= (_binary_op | _comm_op); + if (!flags) { + if (children.size()==1) flags |= (_unary_op); else + if (children.size()==2) flags |= (_binary_op); else + if (children.size()==3) flags |= (_ternary_op); else + LOGf << str << children.size(); + } + if ((flags&_unary_op) && is_left_unary_op.count(str)) + flags |= _left_op; + auto e = std::make_unique(flags, str, move(children)); + return e; +} + +Expr::Expr(const string& src) : flags(0) { + vector> values; + vector> nodes; + vector> ops; + vector op_flags; + + auto comsume_op_and_var = [&](uint op_num, uint var_num) -> bool { + if (!(ops.size() >= op_num)) return false; + if (!(values.size() >= var_num)) return false; + int op_pos = ops.size()>op_num ? ops[ops.size()-op_num-1].first : -1; + if (!(op_pos < values[values.size()-var_num].first)) return false; + int flag = op_flags.back(); + auto expression = make(flag, ""); + // is left op: ++a, --a, !a, &a + if (ops[ops.size()-op_num].first < values[values.size()-var_num].first) + expression->set_is(_left_op); + for (uint i=0; istr += src.substr(p.first, p.second-p.first); + } + for (uint i=0; iadd_child(move(nodes[nodes.size()-var_num+i])); + } + auto l = values[values.size()-var_num].first; + auto r = values[values.size()-1].second; + for (uint i=0; iset_is_not(_op); + if (var_num==1) expression->set_is(_unary_op); + if (var_num==2) { + expression->set_is(_binary_op); + if (is_associative_op.count(expression->str)) + expression->set_is(_asso_op); + if (is_commutative_op.count(expression->str)) + expression->set_is(_comm_op); + } + if (var_num==3) expression->set_is(_ternary_op); + expression->maintain(); + nodes.emplace_back(move(expression)); + values.push_back({l, r}); + + return true; + }; + + auto execute_back = [&](const string& op) { + if (op==":") { + // a?b:c + ASSERT(ops.size()>=2); + ASSERT(src[ops[ops.size()-2].first]=='?'); + ASSERT(values.size()>=3); + comsume_op_and_var(2, 3); + return; + } + if (comsume_op_and_var(1, 2)) return; + ASSERT(is_unary_op.count(op)) << op << "is not unary op"; + ASSERT(comsume_op_and_var(1, 1)) << ops.size() << values.size(); + }; + + auto substr = [&](const pair& p) -> string { + return src.substr(p.first, p.second-p.first); + }; + vector> tokens; + vector flags; + get_tokens(src, tokens, flags); + + for (uint x=0; x") { + // parse template a(); + auto i = tokens.at(x+1).first; + if (src.at(i)=='(' && src.at(i+1)==')') + target = "<"; + } + if (target.size()) { + int tid = ops.size()-1; + while (tid>=0 && target != substr(ops[tid])) + tid--; + ASSERT(tid>=0) << "braces not match" << src; + // a(...) + // ^ tpos + // ^ bpos + int tpos = ops[tid].first; + int bpos = values.size()-1; + while (bpos>=0 && values[bpos].first>=tpos) bpos--; + bpos = bpos >= 0 ? values[bpos].first : -1; + // find first outside braces op pos + // +a(...) or a+(...) + // ^ ^ opos + int opos = tid>0 ? ops[tid-1].first : -1; + if (bpos > opos) { + // +a(args) + // ^ bpos + // ^ opos + vector> args; + while (1) { + auto prev_op = substr(ops.back()); + if (prev_op == "," || prev_op == target) { + if (ops.back().first < values.back().first) { + args.push_back(move(nodes.back())); + nodes.pop_back(); + values.pop_back(); + } else + break; + if (prev_op == ",") { + ops.pop_back(); + op_flags.pop_back(); + } else + break; + } else + execute_back(prev_op); + } + ops.push_back(cp); + op_flags.push_back(flag); + ASSERT(comsume_op_and_var(2, 1)) << ops << op << values << nodes; + for (uint i=0; iadd_child(move(args.rbegin()[i])); + // is a function call: a(b), a[b], a{b} + nodes.back()->set_is(_call); + } else { + // not a function call + // a+(...) + while ((int)ops.size()>tid+1) { + auto prev_op = substr(ops.back()); + execute_back(prev_op); + } + // pop left braces + ops.pop_back(); + op_flags.pop_back(); + } + continue; + } + int pd = precedence.at(op); + bool is_left = is_left_associativity.count(pd); + while (ops.size()) { + auto prev_op = src.substr(ops.back().first, ops.back().second - ops.back().first); + if (prev_op == "(" || prev_op == "[" || prev_op == "{") break; + auto ppd = precedence.at(prev_op); + if (ppd < pd || (ppd==pd && is_left)) { + execute_back(prev_op); + } else { + break; + } + } + ops.push_back(cp); + op_flags.push_back(flag); + } + auto check_cast_func = [&]() { + while (nodes.size() > 1) { + int op_pos = ops.size() ? ops.back().first : -1; + if (op_pos>=0 && src[op_pos] == '?') return; + auto vpos = values[values.size()-2]; + if (op_pos >= vpos.first) return; + + // something like + // xxx = (float) yyy + if (vpos.first && src[vpos.first-1] == '(' && src[vpos.second] == ')') { + auto v1 = move(nodes[nodes.size()-2]); + auto v2 = move(nodes[nodes.size()-1]); + auto r = values.back().second; + nodes.pop_back();nodes.pop_back(); + values.pop_back();values.pop_back(); + auto v3 = expr::make(expr::Flags::_call | expr::Flags::_binary_op, "()"); + v3->add_child(move(v1)); + v3->add_child(move(v2)); + nodes.emplace_back(move(v3)); + values.emplace_back(std::make_pair(vpos.first, r)); + } else + return; + } + }; + while (ops.size()) { + check_cast_func(); + auto prev_op = substr(ops.back()); + execute_back(prev_op); + } + check_cast_func(); + + ASSERT(nodes.size() == 1) << "Left multiple nodes:" << nodes; + move_from(nodes[0]); +} + +string Expr::to_string(int try_reduce_braces, int debug) const { + std::stringstream ss; + int pd = try_reduce_braces?100:-1; + to_string(ss, pd, pd, debug); + return ss.str(); +} + +int64 Expr::as_int() const { + if (is(_float)) + return int64(as_float()); + ASSERT(is(_int)); + return data.i; +} + +float64 Expr::as_float() const { + if (is(_int)) + return float64(as_int()); + ASSERT(is(_float)); + return data.f; +} + +void get_tokens( + const string& src, + vector>& tokens, + vector& flags +) { + int end = src.size(); + while (end && (src[end-1]==' ' || src[end-1]=='\n' || src[end-1]==';')) + end--; + for (uint i=0; i=src.size()) break; + size_t flag=0; + uint j = i+1; + if (src[i]=='\'' || src[i]=='\"') { + while (j') j--; + if (src[i]=='[' && src[j-1]==']') j--; + if (src[i]=='{' && src[j-1]=='}') j--; + if (src[j-1]=='@' && j>) m(<=) m(<) m(>=) + m(>) m(!=) m(==) + m(&) m(^) m(|) m(&&) m(||) +#undef m + if (op==",") return b; + LOGf << "Op" << op << a << b << "not support"; + return 0; +} + +static float64 eval_binary_float(const string& op, float64 a, float64 b) { +#define m(o) if (op == #o) return float64(a o b); + m(+) m(-) m(*) m(/) + m(<=) m(<) m(>=) + m(>) m(!=) m(==) + m(&&) m(||) +#undef m + if (op==",") return b; + LOGf << "Op" << op << a << b << "not support"; + return 0; +} + +static int64 eval_unary_left_int(const string& op, int64 a) { +#define m(o) if (op == #o) return int64(o a); + m(+) m(-) m(!) m(~) +#undef m + LOGf << "Op" << op << a << "not support"; + return 0; +} + +static float64 eval_unary_left_float(const string& op, float64 a) { +#define m(o) if (op == #o) return float64(o a); + m(+) m(-) m(!) +#undef m + LOGf << "Op" << op << a << "not support"; + return 0; +} + +static void _eval(Expr* e) { + auto& c = e->children; + string op = move(e->str); + if (e->is(_binary_op)) { + ASSERT(c.size()==2); + if (c[0]->is(_float) | c[1]->is(_float)) { + e->set_is_only(_float); + e->set_data(eval_binary_float(op, c[0]->as_float(), c[1]->as_float())); + } else { + e->set_is_only(_int); + e->set_data(eval_binary_int(op, c[0]->as_int(), c[1]->as_int())); + } + } else + if (e->is(_ternary_op)) { + ASSERTop(op,==,"?:"); + if (c[1]->is(_float) | c[2]->is(_float)) { + e->set_is_only(_float); + e->set_data(c[0]->as_int() ? c[1]->as_float() : c[2]->as_float()); + } else { + e->set_is_only(_int); + e->set_data(c[0]->as_int() ? c[1]->as_int() : c[2]->as_int()); + } + } else { + ASSERTop(c.size(),==,1); + ASSERT(e->is(_left_op)); + if (c[0]->is(_float)) + e->set_data(eval_unary_left_float(op, c[0]->as_float())); + else + e->set_data(eval_unary_left_int(op, c[0]->as_int())); + e->set_is_only(c[0]->flags); + } +} + +static void eval_asso_binary(Expr* e) { + vector> nc; + nc.reserve(e->children.size()); + for (uint i=0; ichildren.size(); i++) { + auto& c = e->children[i]; + if (!nc.size() || nc.back()->is_not(_number) || c->is_not(_number)) { + nc.push_back(move(c)); + continue; + } + auto& b = nc.back(); + if (b->is(_float) | c->is(_float)) { + b->set_data(eval_binary_float(e->str, b->as_float(), c->as_float())); + b->set_is_only(_float); + } else { + b->set_data(eval_binary_int(e->str, b->as_int(), c->as_int())); + b->set_is_only(_int); + } + } + e->children.clear(); + if (nc.size()==1) { + e->move_from(nc.back()); + return; + } + e->insert(0, nc); + + // eval x*0 -> 0 + if (e->str=="*") { + for (auto& c : e->children) { + if (c->is(_number) && c->data.i==0) { + e->swap(c->clone().get()); + return; + } + } + } +} + +pair get_zero_elem(const string& op) { + if (op=="+") return {1,0}; + if (op=="-") return {1,0}; + if (op=="*") return {1,1}; + if (op=="/") return {1,1}; + return {0,0}; +} + +unique_ptr Expr::eval() { + auto a = make(flags, str); + if (is(_op)) { + a->children.reserve(children.size()); + auto p = get_zero_elem(str); + bool can_eval = true; + for (auto& c : children) { + a->children.push_back(c->eval()); + auto& x = a->children.back(); + if (!x->is(_number)) + can_eval = false; + if (x->is(_int)) { + if (p.first && p.second == x->as_int()) { + if (a->children.size()>1 || a->is(_asso_op)) + a->children.pop_back(); + } + } + } + if (a->is(_asso_op)) { + if (a->children.size()==0) { + a->children.push_back(make(S(p.second))); + } + eval_asso_binary(a.get()); + if (a->children.size()==1) { + return move(a->children.back()); + } + return a; + } + if (can_eval) { + _eval(a.get()); + a->children.clear(); + return a; + } + if (a->children.size()==1 && p.first) { + return move(a->children.back()); + } + } else { + a->data.i = data.i; + } + return a; +} + + +unique_ptr Expr::assign_symbol(const unordered_map& symbols) { + auto a = clone(); + a->dfs([&](Expr* e) { + if (!e->is_sym()) return; + auto iter = symbols.find(e->str); + if (iter == symbols.end()) return; + e->swap(make(iter->second).get()); + }); + return a; +} + +unique_ptr Expr::simplify() { + auto e = eval(); + return e; +} + + +std::ostream& operator<<(std::ostream& os, const Flags& f) { + #define m(x) if (f & x) os << "is" #x << ","; + m(_unary_op); + m(_binary_op); + m(_ternary_op); + m(_op); + m(_call); + m(_left_op); + m(_char); + m(_string); + m(_int); + m(_float); + m(_number); + #undef m + return os; +} + +void Expr::move_from(unique_ptr& e) { + flags=e->flags; + str=move(e->str); + children = move(e->children); + for (uint i=0; ifather = this; + data.i = e->data.i; + e = nullptr; +} + +void Expr::to_string(std::ostream& os, int olp, int orp, int debug) const { + if (is_not(_op)) { + // TODO: negtive value need braces to protect - + bool need_bc = is(_number) && as_float()<0; + if (need_bc) os << '('; + if (is(_int)) os << as_int(); else + if (is(_float)) os << as_float(); else + os << str; + if (need_bc) os << ')'; + return; + } + string s; + bool need_bc = 1; + int pd = olp; + if (olp>=0) { + pd = precedence.at(str); + bool is_left = is_left_associativity.count(pd); + bool check_left = pd < olp || (pd==olp && !is_left); + bool check_right = pd < orp || (pd==orp && is_left); + need_bc = !(check_left && check_right); + } + if (need_bc) + os << "("; + if (debug) { + os << "/*f:"; + os << (Flags)flags; + os << ";s:"; + os << str; + os << ";c:"; + os << children.size(); + os << "*/"; + } + if (is(_ternary_op)) { + // a?b:c + ASSERT(children.size()==3 && str=="?:"); + children[0]->to_string(os, olp, pd, debug); + os << "?"; + children[1]->to_string(os, pd, pd, debug); + os << ":"; + children[2]->to_string(os, pd, orp, debug); + } else if (is(_call)) { + // a(b,c,d) + ASSERT(children.size() && str.size()==2); + children[0]->to_string(os, olp, pd, debug); + os << str[0]; + for (uint i=1; ito_string(os, npd, npd, debug); + if (i+1 == children.size()) + os << str[1]; + else + os << ","; + } + if (children.size()==1) os << str[1]; + } else if (is(_left_op)) { + // ++a, --a + os << str; + if (children.size() != 1) { + os << " !!ERR "; + for (auto &c : children) + os << c->str.size() << " " << *c; + os << " ERR!! "; + } else + children[0]->to_string(os, pd, orp, debug); + } else { + // a--, a+b + ASSERT(children.size()); + ASSERT(children.size()>=2 || is(_unary_op) || is(_asso_op)) << str << children; + children[0]->to_string(os, olp, pd, debug); + if (is(_unary_op)) os << str; + else { + for (uint i=1; ito_string(os, pd, pd, debug); + } + if (children.size()>1) { + os << str; + children.back()->to_string(os, pd, orp, debug); + } + } + } + if (need_bc) + os << ")"; +} + +std::ostream& operator<<(std::ostream& os, const Expr& expression) { + return os << expression.to_string(); +} + +unique_ptr make(size_t flags, const string& str, vector>&& children) { + unique_ptr e(new Expr(flags, str, move(children))); + return e; +} + +void Expr::add_child(unique_ptr&& c) { + c->father = this; + c->fid = children.size(); + children.push_back(move(c)); +} + +unique_ptr Expr::move_out() { + ASSERT(father); + auto& fc = father->children; + unique_ptr e = move(fc[fid]); + fc.erase(fc.begin()+fid); + for (uint i=fid; ifid = i; + father = nullptr; + fid = 0; + return e; +} + +void Expr::swap(Expr* e) { + std::swap(flags, e->flags); + std::swap(str, e->str); + std::swap(father, e->father); + std::swap(fid, e->fid); + std::swap(data, e->data); + std::swap(children, e->children); +} + +void Expr::erase() { + move_out(); +} + +unique_ptr Expr::clone() { + auto e = make(flags, str); + e->data.i = data.i; + e->children.reserve(children.size()); + for (auto& c : children) + e->add_child(c->clone()); + return e; +} + +void Expr::insert(int pos, vector>& v) { + children.insert( + children.begin()+pos, + make_move_iterator(v.begin()), + make_move_iterator(v.end()) + ); + for (uint i=pos; ifather = this; + children[i]->fid = i; + } +} + +vector> Expr::move_out(int start, int end) { + if (end<=0) end += children.size(); + vector> v; + v.reserve(end-start); + for (int i=start; ifather = nullptr; + v.back()->fid = 0; + } + children.erase(children.begin()+start, children.begin()+end); + for (uint i=end; ifid = i; + return v; +} + +void Expr::collapse_children(uint& cid) { + auto c = children[cid].get(); + auto v = c->move_out(0); + auto ncid = cid + v.size() - 1; + children.erase(children.begin()+cid); + insert(cid, v); + cid = ncid; +} + +void Expr::maintain() { + if (is(_asso_op)) { + // a+(b+c) -> a+b+c + for (uint i=0; iis(_asso_op) && children[i]->str==str) { + collapse_children(i); + } + } + } +} + +static void rule_minus(Expr* e) { + if (e->is(_unary_op)) { + // -a -> (-1)*a + e->move_from(make_op("*", + make(_int, "-1"), + e->children[0] + )); + } else { + // a-b -> a+(-1)*b + auto c = e->move_out(0); + e->move_from(make_op("+", + c[0], + make_op("*", + make(_int, "-1"), + c[1] + ) + )); + } +} + +static bool rule_not(Expr* e) { + ASSERT(e->children.size()==1); + auto& c = e->children[0]; + // !var not change + if (c->is_var()) return false; + if (c->str == "&&" || c->str=="||") { + // !(a&&b) -> !a || !b + // !(a||b) -> !a && !b + vector> cc(c->children.size()); + for (uint i=0; ichildren.size(); i++) { + cc[i] = make_op("!", move(c->children[i])); + } + e->move_from(make(c->str[0]=='&'?"||":"&&", move(cc))); + return true; + } + if (c->str == "!") { + // !!a -> a + ASSERT(c->children.size()==1); + e->move_from(c->children[0]->move_out()); + if (e->str == "!") + rule_not(e); + return true; + } + static const unordered_map nmap = { + {"<",">="}, {"<=",">"}, {">","<="}, {">=","<"}, + {"==","!="}, {"!=","=="} + }; + auto iter = nmap.find(c->str); + if (iter != nmap.end()) { + // !(a a>=b + e->move_from(c->move_out()); + e->str = iter->second; + return true; + } + return false; +} + +static void rule_mul(Expr* e, const string& add="+") { + string mul = e->str; + if (e->is(_binary_op)) { + // (a+b)*(c+d) -> a*c + a*d + b*c + b*d + vector add_index, add_range, add_cid; + for (uint i=0; ichildren.size(); i++) { + auto c = e->children[i].get(); + if (c->str==add && c->is(_binary_op)) { + add_cid.push_back(add_range.size()); + add_range.push_back(c->children.size()); + add_index.push_back(0); + } else + add_cid.push_back(-1); + } + if (!add_range.size()) return; + vector> nc; + int n = add_index.size(); + while (1) { + vector> nm; + for (uint i=0; ichildren.size(); i++) { + auto c = e->children[i].get(); + if (add_cid[i] == -1) + nm.emplace_back(c->clone()); + else + nm.emplace_back(c->children[add_index[add_cid[i]]]->clone()); + } + nc.emplace_back(make(mul, move(nm))); + int p = n-1; + add_index[p]++; + while (add_index[p] >= add_range[p]) { + add_index[p] = 0; + p--; + if (p<0) break; + add_index[p]++; + } + if (p<0) break; + } + e->move_from(make(add, move(nc))); + } +} + +static void rule_at(Expr* e) { + if (e->str == "@>") { + // a @> b = !a || b + e->move_from(make_op("||", + make_op("!", e->children[0]), + e->children[1] + )); + } else + if (e->str == "@<") { + // a @< b = !a || b + e->move_from(make_op("||", + make_op("!", e->children[1]), + e->children[0] + )); + } +} + +static void rule_cmp(Expr* e) { + auto b = e->children[1]->move_out(); + auto a = e->children[0]->move_out(); + auto a2 = a->clone(); + auto b2 = b->clone(); + if (e->str == "==") { + // a==b -> a>=b&&a<=b + e->move_from(make_op("&&", + make_op(">=", a, b), + make_op("<=", a2, b2) + )); + } else { + // a!=b -> ab + e->move_from(make_op("||", + make_op("<", a, b), + make_op(">", a2, b2) + )); + } +} + +unique_ptr expand(Expr* e) { + auto h = e->clone(); + e = h.get(); + uint cid = 0; + // while loop dfs + while (1) { + if (cid==0) { + // first enter + if (e->str == "-") { + rule_minus(e); + } else + if (e->str == "!") { + if (rule_not(e)) continue; + } else + if (e->str.size() && e->str[0] == '@') { + rule_at(e); + } else + if (e->str=="==" || e->str=="!=") { + rule_cmp(e); + } + } + if (cid>=e->children.size()) { + // before return + e->maintain(); + if (e->str == "*") { + rule_mul(e); + } else + if (e->str == "&&") { + rule_mul(e, "||"); + } + // return to father + cid = e->fid; + // auto c = e; + e = e->father; + if (!e) break; + // back from child + cid ++; + continue; + } + // recursive to child + e = e->children[cid].get(); + cid = 0; + } + return h; +} + +bool match(Expr* src, Expr* target) { + vector> results; + return match(src, target, {}, {}, results); +} + +bool match( + Expr* src, Expr* target, + const vector& solve_symbols, + const vector& exclude_symbols, + vector>& results +) { + auto s = src->expand()->simplify(); + auto t = target->expand()->simplify(); + int n = solve_symbols.size(); + unordered_map ss; + for (int i=0; i es(exclude_symbols.begin(), exclude_symbols.end()); + + auto solve_id = [&](Expr* e) -> int { + if (!e->is_sym()) return -1; + auto iter = ss.find(e->str); + if (iter == ss.end()) return -1; + return iter->second; + }; + + std::function has_exclude = [&](Expr* e) -> bool { + if (e->is_sym()) return es.count(e->str); + if (e->is(_op)) { + for (auto& c : e->children) + if (has_exclude(c.get())) + return true; + } + return false; + }; + + std::function>&)> log_do_match; + + std::function>&)> do_match = + [&](Expr* s, Expr* t, vector>& results) -> bool { + if (t->is_not(_op)) { + int tid = solve_id(t); + // if is a symbol need to solve + if (tid>=0) { + if (has_exclude(s)) + return false; + results[tid] = s->clone(); + return true; + } else { + return s->flags == t->flags && s->to_string()==t->to_string(); + } + } else { + auto ze = get_zero_elem(t->str); + if (s->is_not(_op)) { + // if op don't have zero element + if (!ze.first) + return false; + } else + if (s->flags != t->flags || s->str != t->str) + return false; + if (!ze.first && s->children.size() != t->children.size()) + return false; + int n = s->is(_op) ? s->children.size() : 1; + int m = t->children.size(); + unique_ptr zep; + if (ze.first) { + zep = make(S(ze.second)); + n++; + } + std::function check_match = + [&](int is, int it) -> bool { + vector> ns(results.size()); + Expr* sp = zep.get(); + if (s->is_not(_op) && is==0) { + sp = s; + } else + if (is<(int)s->children.size()) + sp = s->children[is].get(); + if (!log_do_match(sp, t->children[it].get(), ns)) { + return false; + } + for (uint j=0; jis_not(_op) && t->str == "*" && s->str == "0") { + // 0 match 0*a*b + for (int i=0; iis(_asso_op)) { + // wildcard assosiative id + // a*b+c <---- + for (int i=m-1; i>=0; i--) { + auto tid = solve_id(t->children.at(i).get()); + asso_wildcard_tid = tid; + if (tid>=0) { + // matched by other expr + if (results[tid]) + continue; + LOGvvvv << "check asso wild" << t->children.at(i) << *t; + bool in_other_target = 0; + for (int j=0; jchildren.at(j)->dfs([&](Expr* e) { + if (!e->is_sym()) return; + if (e->str == t->children.at(i)->str) + in_other_target = 1; + }); + if (in_other_target) break; + } + if (!in_other_target) { + asso_wildcard_id = i; + break; + } + } + } + LOGvvvv << "asso_wildcard_id" << asso_wildcard_id; + // asso_wildcard_id = -1; + } + if (t->is(_comm_op)) { + // is commutative op, children can be matched in any order + vector is_matched(m); + for (int i=0; i> solve_symbols[_] >> "]=" >> + (results[_]?results[_]->to_string(1):"null"); + if (!matched && asso_wildcard_id>=0) { + auto j = asso_wildcard_id; + auto& res = results[asso_wildcard_tid]; + // if zero elem and results already matched + if (i==(int)s->children.size() && res) + continue; + // match a+b -> c + auto bk = move(res); + if (!check_match(i, j)) { + return false; + } + is_matched[j] = true; + if (bk) + res = make_op(s->str, move(bk), move(res)); + continue; + } + // if not matched and not zero elem + if (!matched && i<(int)s->children.size()) { + return false; + } + } + for (int j=0; j>& results) -> bool { + LOGvvvv >> string(depth*4, ' ') >> + "match" << s->to_string(1) << t->to_string(1); + depth++; + auto res = do_match(s, t, results); + depth--; + LOGvvvv >> string(depth*4, ' ') >> + "return=" >> res << s->to_string(1) << t->to_string(1); + return res; + }; + + results.clear(); + results.resize(n); + if (!log_do_match(s.get(), t.get(), results)) + return false; + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { +namespace expr { +// Expression +enum Flags { + // is op or is a symbol + _unary_op = 1<<7, + _binary_op = 1, + _ternary_op = 1<<8, + _op = _unary_op | _binary_op | _ternary_op, + // is a function call: a(b), a[b], a{b} + _call = 1<<1, + // is left op: ++a, --a, !a, &a + _left_op = 1<<2, + // is associative op: (a+b)+c -> a+(b+c) + _asso_op = 1<<9, + // is commutative op: a*b -> b*a + _comm_op = 1<<10, + // 'a' + _char = 1<<3, + // "string" + _string = 1<<4, + // int: 1, 0x1a, 0b1, 1u, 1ull + _int = 1<<5, + // float: 1. 1.0f 1e3 + _float = 1<<6, + _number = _int | _float, +}; + +struct Expr { + size_t flags; + string str; + Expr* father; + // index in father's children + size_t fid; + // data for number + union Data { + int64 i; + float64 f; + } data; + vector> children; + + Expr(const string& src); + Expr(size_t flags, const string& str, vector>&& children); + + void add_child(unique_ptr&& c); + unique_ptr move_out(); + vector> move_out(int start, int end=0); + void insert(int pos, vector>& v); + void erase(); + void swap(Expr* e); + + template + void dfs(Func&& func) { + func(this); + for (auto& c : children) + c->dfs(func); + } + + int64 as_int() const; + float64 as_float() const; + inline void set_data(int64 x) { data.i = x; str=S(x);} + inline void set_data(float64 x) { data.f = x; str=S(x); } + + unique_ptr assign_symbol(const unordered_map& symbols); + + // to_string: return expression string of this expression + // args: + // try_reduce_braces: try to reduce brances if precedence correct + // for example: + // a+(b*c) -> a+b*c + // (a,(b,c)) -> a,b,c + // debug: output debug info in comment + // example: /*f:{flags};s:{str};c:{children.size()}*/ + // return: expression string of this expression + string to_string(int try_reduce_braces=false, int debug=false) const; + // args: + // olp: outside left precedence, -1 for force add braces + // orp: outside right precedence, -1 for force add braces + void to_string(std::ostream& os, int olp, int orp, int debug=false) const; + + // collapse children of cid-th child into father's children + // a+(b+c) -> a+b+c + void collapse_children(uint& cid); + + inline unique_ptr expand(); + unique_ptr eval(); + unique_ptr simplify(); + unique_ptr clone(); + void move_from(unique_ptr& e); + inline void move_from(unique_ptr&& e) { move_from(e); }; + + void maintain(); + + inline bool is(size_t f) const { return flags & f;} + inline bool is_not(size_t f) const { return !(flags & f);} + inline void set_is(size_t f) { flags |= f;} + inline void set_is_not(size_t f) { flags &= ~f;} + inline void set_is_only(size_t f) { flags = f;} + inline bool is_sym() const { return is_not(_op | _char | _string | _number); } + inline bool is_var() const { return is_not(_op); } +}; + +std::ostream& operator<<(std::ostream& os, const Expr& expression); +std::ostream& operator<<(std::ostream& os, const Flags& f); + + +inline unique_ptr make(const string& str) { return std::make_unique(str); }; +unique_ptr make(size_t flags, const string& str, vector>&& children={}); +unique_ptr make(const string& str, vector>&& children); +template +unique_ptr make_op(const string& str, Args&&... args); + +/* Match between source expression and target expression, try to solve symbols +arguments: + src: source expression + target: target expression + solve_symbols: symbols in target expression which need to be solved + exclude_symbols: symbols that should not occur in results + results: same length with solve_symbols, return the solved symbols + return: return true if solved success +example: + auto src = make("3*i+j-1"); + auto target = make("i*stride+pad+j"); + vector> results; + match(src.get(), target.get(), {"stride", "pad"}, {"i", "j"}, results); + LOGv << results; + // print [3,-1] + */ +bool match( + Expr* src, Expr* target, + const vector& solve_symbols, + const vector& exclude_symbols, + vector>& results +); + +bool match(Expr* src, Expr* target); + +void get_tokens( + const string& src, + vector>& tokens, + vector& flags +); + +unique_ptr expand(Expr* e); +inline unique_ptr Expr::expand() { return expr::expand(this); } + +template +unique_ptr make_op(const string& str, Args&&... args) { + vector> children; + children.reserve(sizeof...(args)); + auto f = [&](unique_ptr& c) { children.emplace_back(move(c)); }; + // Brace-enclosed initializers + int dummy[] = {(f(args), 0)...}; + (void)dummy; + return make(str, move(children)); +} + +template +void dfs(Expr* e, Func&& f) { + f(e); + for (auto& c : e->children) + dfs(c.get(), f); +} + +} // expr +} // jittor + diff --git a/python/jittor/src/opt/gopt/setitem_gopt.cc b/python/jittor/src/opt/gopt/setitem_gopt.cc new file mode 100644 index 00000000..2fc68bda --- /dev/null +++ b/python/jittor/src/opt/gopt/setitem_gopt.cc @@ -0,0 +1,175 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "ops/setitem_op.h" +#include "ops/getitem_op.h" + +namespace jittor { + +inline static bool fast_strcmp(const char* a, const char* b) { + return ((const uint64*)a)[0] == ((const uint64*)b)[0]; + // while (*b && *a == *b) a++, b++; + // return !*b; +} + +// add dependency b -> a +static inline void add_dependency(Node* a, Node* b) { + // check dependency is not exist + for (auto na : a->inputs()) { + if (na == b) return; + } + a->add_inputs({b}); + auto edge = a->_inputs.end(); + edge = std::prev(edge); + // set -1 mean this is a control dependency edge + edge->back->index = -1; +} + +static void setitem_inplace(SetitemOp* op) { + // LOGir << "in setitem_inplace"; + auto input = op->inputs().front(); + if (!(input->outputs().size() == 1 && + input->forward_liveness<=1 && + (op->op == ns_void || op->op == ns_add || op->op == ns_subtract))) { + return; + } + auto input_op = input->input(); + if (input_op) { + // make sure input op will not use input + auto input_name = input_op->name(); + if (!(input_op->type() == OpType::broadcast || + input_op->inputs().size() == 0 || + fast_strcmp(input_name, "setitem") || + fast_strcmp(input_name, "getitem"))) + // TODO: inplace getitem maybe risky, getitem maybe inplace too + return; + } + auto output = op->outputs().front(); + // return if output is all ready shared + if (output->allocator) return; + output->share_with(input); + + auto data = op->input(1); + // if setitem requires type conversion, don't inplace + if (data->dtype() != input->dtype()) + return; + + input_op = input->input(); + + if (input_op && input_op->inputs().size() == 1) { + input_op = input_op->inputs().front()->input(); + } + if (input_op && input_op->inputs().size() == 1) { + input_op = input_op->inputs().front()->input(); + } + + VarSlices vs = op->vs; + if (!(data->is_finished() == 0 && + (data->outputs().size() == 1 || + (!input_op + || input_op->inputs().size() == 0)))) + return; + if (data->allocator) + return; + auto data_op = data->input(); + if (data_op->flags.get(NodeFlags::_custom_flag)) + return; + + auto in_shape = input->shape; + int64 inplace_size = 1; + for (int i = vs.n - 1; i > 0; --i) { + VarSlice s = vs.slices[i]; + if (!(s.is_slice())) return; + Slice ss = s.slice; + if (!(ss.start == 0 && (ss.mask&2) && ss.step == 1)) + return; + inplace_size *= in_shape[i]; + } + + VarSlice s = vs.slices[0]; + if (s.is_var() || s.is_str()) return; + + int64 size = 0; + if (s.is_int()) + size = in_shape[0] == 0 ? 0 : s.i * (input->size / in_shape[0]); + else if (s.is_slice()) { + Slice ss = s.slice; + // we also need to check the first dim is continuous + if (ss.step != 1) + return; + size = in_shape[0] == 0 ? 0 : ss.start * (input->size / in_shape[0]); + inplace_size *= ss.stop - ss.start; + } + if (inplace_size > data->num) { + // if data has been broadcast into input, don't + // inplace data, because their shapes are not match + // This would lead partial setitem + return; + } + add_dependency(data->input(), input->node()); + data->share_with(input, size); + op->ns.set(GetitemOp::_inplace); + // LOGir << input->shape << input->dtype() << data->shape << data->dtype() << vs << data->input(); + // LOGir << output; +} + +static void getitem_inplace(GetitemOp* op) { + // LOGir << "in getitem_inplace"; + + auto in = op->inputs().front(); + auto ou = op->outputs().front(); + + // return if out is all ready inplaced + if (ou->allocator) + return; + + VarSlices vs = op->vs; + auto in_shape = in->shape; + + for (int i = vs.n - 1; i > 0; --i) { + VarSlice s = vs.slices[i]; + if (!(s.is_slice())) return; + Slice ss = s.slice; + if (!(ss.start == 0 && (ss.mask&2) && ss.step == 1)) + return; + } + + VarSlice s = vs.slices[0]; + if (s.is_var() || s.is_str()) return; + + int64 size = 0; + if (s.is_int()) + size = in_shape[0] == 0 ? 0 : s.i * (in->size / in_shape[0]); + else if (s.is_slice()) { + size = in_shape[0] == 0 ? 0 : s.slice.start * (in->size / in_shape[0]); + if (s.slice.step != 1) return; + } + ASSERT(size>=0 && size<=in->size); + ou->share_with(in, size); + op->ns.set(GetitemOp::_inplace); + // LOGir << "pass getitem_inplace"; + // LOGir << "inplace getitem" << vs << in->shape << ou->shape; +} + +void SetitemOp::graph_optimize() { + // LOGir << "hello graph_optimize"; + setitem_inplace(this); + (void*)setitem_inplace; +} + +void GetitemOp::graph_optimize() { + // This optimize is still WIP + // LOGir << "hello getitem graph_optimize"; + // setitem_grad_opt(this); + // (void)getitem_inplace; + getitem_inplace(this); + (void*)getitem_inplace; +} + +} + diff --git a/python/jittor/src/opt/jit_searcher.cc b/python/jittor/src/opt/jit_searcher.cc new file mode 100644 index 00000000..84320e4b --- /dev/null +++ b/python/jittor/src/opt/jit_searcher.cc @@ -0,0 +1,92 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include "opt/jit_searcher.h" +#include "opt/pass_manager.h" +#include "jit_compiler.h" +#include "fused_op.h" + +namespace jittor { + +DEFINE_FLAG(int, jit_search_kernel, 0, "Jit search for the fastest kernel."); +DEFINE_FLAG(int, jit_search_warmup, 2, ""); +DEFINE_FLAG(int, jit_search_rerun, 10, ""); + +Searcher::Searcher(OpCompiler* oc) : oc(oc) { + reset(); +} + +int64_t Searcher::get_time_of_current_choices() { + JK& jk = get_jk(); + auto* op = oc->op; + // generate jit_key + op->update_jit_key(); + string jit_key = jk.to_cstring(); + // generate src + PassManager pm(oc); + pm.run_passes(); + string src = pm.all.to_string(); + // compile + auto jit_entry = oc->compile(jit_key, src); + for (int i=0; i(finish-start).count(); + // 25ns function call overhead + total_ns -= jit_search_rerun * 25ll; + return std::max((int64_t)1, total_ns); +} + +void Searcher::reset() { + // TODO: setup timeout + timeout = 1ll<<62; + best_time = 1ll<<62; +} + +void Searcher::search(const loop_option_candidates_t& candidates) { + FusedOp* op = oc->op; + auto& choices = op->get_loop_options_tuned(); + + LOGvv << "Available candidates:" << candidates; + + // search best choices + vector names; + for (auto& kv : candidates) { + if (op->loop_options_origin->count(kv.first)) continue; + names.push_back(kv.first); + } + std::sort(names.begin(), names.end()); + std::function dfs = [&](int i) { + if (i == (int)names.size()) { + auto time = get_time_of_current_choices(); + if (time < best_time) { + best_time = time; + best_choices = choices; + } + LOGvvv << "Choices(">> time/1.0e6/jit_search_rerun >> "ms, best " >> best_time/1.0e6/jit_search_rerun >> ")" << choices; + return; + } + for (int j : candidates.at(names[i])) { + choices[names[i]] = j; + dfs(i+1); + } + }; + if (names.size()) { + LOGvv << "DFS search names:" << names; + dfs(0); + } + + if (best_time == (1ll<<62)) return; + LOGvv << "Best choices(" >> best_time/1.0e6/jit_search_rerun >> "ms" >>"):" << best_choices; + choices = best_choices; + op->update_jit_key(); +} + +} \ No newline at end of file diff --git a/python/jittor/src/opt/jit_searcher.h b/python/jittor/src/opt/jit_searcher.h new file mode 100644 index 00000000..db092ac6 --- /dev/null +++ b/python/jittor/src/opt/jit_searcher.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +DECLARE_FLAG(int, jit_search_kernel); + +struct Searcher { + OpCompiler* oc; + int64_t timeout, best_time; + loop_options_t best_choices; + + Searcher(OpCompiler* oc); + void reset(); + int64_t get_time_of_current_choices(); + void search(const loop_option_candidates_t& candidates); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/kernel_ir.cc b/python/jittor/src/opt/kernel_ir.cc new file mode 100644 index 00000000..a76d849e --- /dev/null +++ b/python/jittor/src/opt/kernel_ir.cc @@ -0,0 +1,1072 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "opt/kernel_ir.h" + +namespace jittor { + +template +vector::iterator> sort(unordered_map& m) { + vector::iterator> v; + v.reserve(m.size()); + for (auto i=m.begin(); i!=m.end(); ++i) + v.push_back(i); + auto cmp = [](const auto& a, const auto& b) -> bool { + return a->first < b->first; + }; + sort(v.begin(), v.end(), cmp); + return v; +} + +bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; } + +std::ostream& operator<<(std::ostream& os, KernelIR& ir) { + return os << ir.to_string(); +} + +void KernelIR::del_scope() { + if (father && (type=="define" || type=="func" || type=="macro")) { + father->scope[attrs["lvalue"]].remove(this); + } +} + +void KernelIR::add_scope() { + if (father && (type=="define" || type=="func" || type=="macro")) + father->scope[get_attr("lvalue")].push_back(this); +} + +void KernelIR::clear() { + del_scope(); + type.clear(); + attrs.clear(); + for (int i=(int)inner.size()-1; i>=0; i--) + inner[i]->erase(); +} + +string& KernelIR::get_attr(const string& s) { + return attrs[s]; +} + +bool KernelIR::has_attr(const string& s) { + auto iter = attrs.find(s); + if (iter == attrs.end() || iter->second.size()==0) + return false; + return true; +} + +void KernelIR::try_parse_define(const string& s) { + // dtype lvalue = rvalue; + clear(); + string& dtype = get_attr("dtype"); + string& lvalue = get_attr("lvalue"); + string& rvalue = get_attr("rvalue"); + int count=0; + uint end=s.size(); + bool find_eq = 0; + while (end && (s[end-1]==' ' || s[end-1]=='\n' || s[end-1]==';')) end--; + for (uint i=0; i=0 && s[r]!=')') r--; + ASSERT(l>* ls, bool raw) { + if (!ls) ls = &children; + ASSERT(ls>=&before && ls<=&after); + ls->emplace_back(std::make_unique(src, raw)); + auto& ir = *ls->back(); + ir.father = this; + ir.flist = ls; + ir.add_scope(); +} + +void KernelIR::push_front(const string& src, vector>* ls, bool raw) { + if (!ls) ls = &children; + ASSERT(ls>=&before && ls<=&after); + ls->insert(ls->begin(), std::make_unique(src, raw)); + auto& ir = *ls->front(); + ir.father = this; + ir.flist = ls; + ir.add_scope(); +} + +void KernelIR::push_back(unique_ptr&& irp, vector>* ls) { + ASSERT(irp->father==nullptr); + if (!ls) ls = &children; + ASSERT(ls>=&before && ls<=&after); + ls->emplace_back(move(irp)); + auto& ir = *ls->back(); + ir.father = this; + ir.flist = ls; + ir.add_scope(); +} + +void KernelIR::push_front(unique_ptr&& irp, vector>* ls) { + ASSERT(irp->father==nullptr); + if (!ls) ls = &children; + ASSERT(ls>=&before && ls<=&after); + ls->insert(ls->begin(), move(irp)); + auto& ir = *ls->front(); + ir.father = this; + ir.flist = ls; + ir.add_scope(); +} + +void remove_func_call_arg(string& src, int arg_i) { + int presum=0, aid=-1, prev=0; + for (int i=0; i<(int)src.size(); i++) { + if (src[i]=='(') presum++; + if (presum==1 && (src[i]=='(' || src[i]==',' || src[i]==')')) { + if (arg_i == aid) { + if (src[i]==',' && arg_i==0) i++; + src.erase(prev, i-prev); + return; + } + aid++; + prev = i+(src[i]=='('); + } + if (src[i]==')') presum--; + } + LOGf << "Func call do not have enough argument" << arg_i << src; +} + +void KernelIR::erase() { + ASSERT(father && flist); + // if is a function argument + if (father->type=="func" && flist==&father->inner) { + string& func_name = father->get_attr("lvalue"); + uint i=0; + while (isize() && flist->at(i).get() != this) i++; + ASSERT(i < flist->size()); + auto used = father->find_used(); + for (auto c : used) { + string& code = c->get_attr("code"); + if (c->type=="" && startswith(code, func_name)) + remove_func_call_arg(code, i); + } + } + del_scope(); + for (uint i=0; isize(); i++) + if ((*flist)[i].get() == this) { + flist->erase(flist->begin()+i); + return; + } + ASSERT(0); +} + +template +void KernelIR::for_each_rev(Func&& func) { + vector>* ls[] = {&before, &inner, &children, &after}; + for (auto& l : ls) { + for (int i=(int)l->size()-1; i>=0; i--) + func((*l)[i]); + } +} + +KernelIR* KernelIR::find_define(const string& name) { + auto iter = scope.find(name); + if (iter == scope.end() || iter->second.size()==0) { + if (father) + return father->find_define(name); + return nullptr; + } + ASSERT(iter->second.size()==1) << + "Name" << name << (iter->second.size()?"duplicate":"not found") + << this->to_string(0,1) << scope; + return iter->second.back(); +} + +unique_ptr KernelIR::clone(bool with_children) { + auto ir = std::make_unique(); + ir->type = type; + ir->attrs = attrs; + for (auto& c : before) + ir->push_back(c->clone(), &ir->before); + for (auto& c : inner) + ir->push_back(c->clone(), &ir->inner); + if (with_children) + for (auto& c : children) + ir->push_back(c->clone(), &ir->children); + for (auto& c : after) + ir->push_back(c->clone(), &ir->after); + return ir; +} + +void KernelIR::rebuild_scope() { + scope.clear(); + for_each([&](unique_ptr& c) { + c->add_scope(); + }); +} + +void KernelIR::update_father() { + auto update = [&](vector>* flist) { + for (auto& c : *flist) + c->flist = flist, c->father = this; + }; + update(&before); + update(&inner); + update(&children); + update(&after); +} + +void KernelIR::swap(KernelIR& other, bool with_children) { + std::swap(type, other.type); + std::swap(attrs, other.attrs); + std::swap(before, other.before); + std::swap(inner, other.inner); + if (with_children) std::swap(children, other.children); + std::swap(after, other.after); + update_father(); + other.update_father(); + rebuild_scope(); + other.rebuild_scope(); +} + +unique_ptr KernelIR::move_out() { + ASSERT(father && flist); + del_scope(); + int i=(int)flist->size()-1; + for (; i>=0; i--) + if ((*flist)[i].get() == this) + break; + ASSERT(i < (int)flist->size()); + unique_ptr ir = move((*flist)[i]); + flist->erase(flist->begin()+i); + flist = nullptr; + father = nullptr; + return ir; +} + +vector> KernelIR::move_out_children() { + vector> cs(children.size()); + int i=(int)children.size()-1; + for (; i>=0; i--) cs[i] = children[i]->move_out(); + return cs; +} + +bool KernelIR::check_attr(const string& k, const string& v) { + auto iter = attrs.find(k); + return iter!= attrs.end() && iter->second==v; +} + +vector KernelIR::find_loops(string lid) { + vector q({this}), loops; + for (uint i=0; icheck_attr("loop_id", lid)) { + loops.push_back(ir); + } + ir->for_each([&](unique_ptr& c) { + q.push_back(c.get()); + }); + } + return loops; +} + +string KernelIR::to_string(int level, bool debug) { + if (level==0 && debug) { + check_father(); + } + if (level==0 && type=="" && children.size()) { + level--; + } + std::stringstream s; + //TODO: no level up for before & after + //bool level_up = (before.size() || after.size()) && level>0; + bool level_up = (before.size() || after.size()) && level>0 && (type != "define" && type != ""); + if (level_up) { + for (int i=0; ito_string(level, debug); + if (debug) { + for (int i=0; ichildren) s << "C"; + if (flist == &father->before) s << "B"; + if (flist == &father->inner) s << "I"; + if (flist == &father->after) s << "A"; + s << " "; + } + s << type; + for (auto kv : sort(attrs)) + if (kv->second.size()) s << " " << kv->first << ":\"" << kv->second << '"'; + s << "\n"; + if (scope.size()) { + for (int i=0; ifirst << '(' << kv->second.size() << "), "; + s << "\n"; + } + } + for (int i=0; i=3); + s << "for ("; + for (int i=0; i<3; i++) { + auto c = inner[i]->to_string(); + c = c.substr(0, c.size()-2); // remove ;\n + s << c << (i==2?"":"; "); + } + s << ") "; + inner_left = 3; + has_bc = true; + } else { + // empty loop + has_bc = true; + } + } else if (type == "if") { + ASSERT(inner.size()>=1); + auto src = inner[0]->to_string(); + s << "if (" << src.substr(0, src.size()-2) << ") "; + inner_left = 1; + has_bc = true; + } else if (type == "define") { + s << attrs["dtype"] << " " << attrs["lvalue"]; + if (has_attr("rvalue")) + s << " = " << attrs["rvalue"]; + s << ";\n"; + } else if (type == "func") { + s << attrs["dtype"] << ' ' << attrs["lvalue"] << '('; + for (uint i=0; ito_string(); + s << arg.substr(0, arg.size()-2); + inner_left = inner.size(); + } + s << ") "; + has_bc = true; + } else if (father) { + auto iter = attrs.find("code"); + ASSERT(iter != attrs.end()) << attrs << type << father; + s << iter->second << "\n"; + has_bc = attrs.count("has_bc"); + } else { + s << "\n"; + } + if (has_bc) s << "{\n"; + for (uint i=inner_left; ito_string(level+1, debug); + for (auto& c : children) + s << c->to_string(level+1, debug); + if (has_bc) { + for (int i=0; ito_string(level, debug); + if (level_up) { + for (int i=0; i(src, raw)); + auto& ir = *children[pos]; + ir.father = this; + ir.flist = &children; + ir.add_scope(); +} + +void KernelIR::insert(uint pos, vector>& irs) { + vector> irs2(irs.size()); + for (int i=(int)irs.size()-1; i>=0; i--) { + if (irs[i]->father) + irs2[i] = irs[i]->move_out(); + else + irs2[i] = move(irs[i]); + } + children.insert( + children.begin()+pos, + make_move_iterator(irs2.begin()), + make_move_iterator(irs2.end()) + ); + for (uint i=0; ifather = this; + c->flist = &children; + c->add_scope(); + } +} + +void KernelIR::check_father() { + for_each([&](unique_ptr& c) { + ASSERTop(c->father,==,this) << "father attrs:" << attrs << "attrs:" << c->attrs; + c->check_father(); + }); +} + +bool KernelIR::get_number(const string& name, int& num) { + auto iter = scope.find(name); + if (iter == scope.end()) { + if (father) + return father->get_number(name, num); + num = -1; + return false; + } + ASSERT(iter->second.size()==1); + auto snum = iter->second.back()->attrs["rvalue"]; + if (snum.size() && isdigit(snum[0])) { + num = std::stoi(snum); + return true; + } + num = -2; + return false; +} + +KernelIR::KernelIR(const string& src, bool raw) { + uint end = src.size(); + uint start = 0; + while (end && (src[end-1] == ' ' || src[end-1] == '\n')) end--; + while (start1); + attrs["lvalue"] = v.at(1); + attrs["rvalue"] = v.size()>2 ? v.at(2) : ""; + return; + } else { + push_back(src.substr(j, k-j), nullptr, raw); + i = k; + continue; + } + } + if (j==end) return; + uint k=j; + while (k=2 && s[0]=='{' && s[s.size()-1]=='}') { + // empty loop + type = "loop"; + end--; + continue; + } + // func define + if (s.size()>=2 && s.back()=='}') { + int l = s.find("{"); + ASSERT(l != string::npos); + if (startswith(s, "namespace ")) { + // namespace xxx {...} + // l + attrs["code"] = s.substr(0, l); + attrs["has_bc"] = "1"; + type = ""; + i = j + l; + end--; + continue; + } + int ll = s.rfind("(", l); + int rr = s.rfind(")", l); + // if () not found, maybe src like this: + // vector a = {...}; + if (ll<0 && rr<0) { + type = ""; + attrs["code"] = src + ";"; + return; + } + ASSERT(l>=0 && ll>=0 && rr>=0 && ll0 && s[x-1]!=' ') x--; + int y = x-1; + while (y>0 && s[y]==' ') y--; + ASSERT(0 cid(children.size()); + uint num=0; + for (uint i=0; itype != "loop" || children[i]->has_attr("raw")) + num++; + for (uint i=0,j=0,k=0; itype != "loop" || children[i]->has_attr("raw")) + cid[i] = j++; + else + cid[i] = num + k++; + } + vector> cb(children.size()); + for (uint i=0; i>& replace_vars, bool equal, bool remove_define) { + string& lvalue = get_attr("lvalue"); + string& code = get_attr("code"); + string& rvalue = get_attr("rvalue"); + + int replace_time = 0; + int max_time = 1; + while (replace_timetype=="loop" || flist==&father->inner); + for (auto& p : replace_vars) + if (p.first != p.second) + if (startswith(lvalue, p.first, 0, equal)) { + // remove this define if matched + if (remove_define && !inside_loop) { + erase(); + return; + } else { + del_scope(); + lvalue = p.second + lvalue.substr(p.first.size()); + add_scope(); + replaced = true; + } + break; + } + } + } + string* ss[2] = {&code, &rvalue}; + for (int p=0; p<2; p++) { + auto& code = *ss[p]; + for (uint i=0; i= code.size()) break; + uint j=i+1; + while (j=p.first.size() && startswith(code, p.first, i, equal, j)) { + code.erase(i, p.first.size()); + code.insert(i, p.second); + j = j-p.first.size()+p.second.size(); + replaced = true; + break; + } + } + i = j; + } + } + + if (!replaced) break; + } + + ASSERT(max_time==1 || replace_time& c) { + c->replace(replace_vars, equal); + }); +} + + +void KernelIR::rename_loop_index() { + vector irs(1, this); + for (uint i=0; iget_attr("rvalue"); + auto& lvalue = ir->get_attr("lvalue"); + if (ir->type == "loop" && rvalue.size()) { + if (startswith(rvalue, "range")) { + auto& loop_id = ir->get_attr("loop_id"); + loop_id = rvalue.substr(5); + ir->replace({{lvalue, "id"+rvalue.substr(5)}}, true); + } else { + // TODO + LOGvvvv << "Unhandled loop var" << rvalue; + } + } + for (auto& c : ir->children) + if (c->type == "loop") + irs.push_back(c.get()); + } +} + + +void KernelIR::merge_loop() { + unordered_map loops; + for (int i=(int)children.size()-1; i>=0; i--) { + auto& loop = children[i]; + if (loop->type != "loop") + continue; + auto& loop_id = loop->get_attr("loop_id"); + if (!loop_id.size()) continue; + auto iter = loops.find(loop_id); + if (iter == loops.end()) { + loops[loop_id] = loop.get(); + continue; + } + auto* mloop = iter->second; + ASSERT(mloop->check_attr("loop_id", loop_id)); + mloop->insert(0, loop->children); + children.erase(children.begin()+i); + } + for (auto& kv : loops) + kv.second->merge_loop(); +} + +void KernelIR::solve_conflict_define() { + unordered_set defs; + for (size_t i=0; itype == "define") { + auto lvalue = c->get_attr("lvalue"); + if (lvalue.size()==0) + continue; + if (defs.count(lvalue)) { + // add _ to conflict definitions + string new_def = lvalue + '_'; + while (defs.count(new_def)) + new_def += '_'; + LOGvvv << "Conflict define" << c->to_string() << "change to" << new_def; + for (size_t j=i; jreplace({{lvalue, new_def}}, true, false); + defs.insert(new_def); + } else + defs.insert(lvalue); + } else + if (c->type == "loop") + c->solve_conflict_define(); + } +} + +void KernelIR::expand_empty_block() { + for (uint i=0; itype != "loop") continue; + if (loop->has_attr("loop_id")) { + loop->expand_empty_block(); + continue; + } + if (loop->has_attr("rvalue")) + continue; + insert(i+1, loop->children); + // use children[i] instead of loop + children[i]->erase(); + i--; + } +} + +void KernelIR::check_unused() { + if (has_attr("raw")) return; + attrs["used"] = ""; + const char* ss[] = {"code", "rvalue", "rvalue2"}; + for (const char* s : ss) { + auto& code = get_attr(s); + for (uint i=0; i= code.size()) break; + uint j=i+1; + while (jattrs["used"] = "1"; + } + i = j; + } + } + for_each([&](unique_ptr& c) { + c->check_unused(); + }); +} + +void KernelIR::find_used(KernelIR* def, vector& used) { + if (has_attr("raw")) return; + const char* ss[] = {"code", "rvalue", "rvalue2"}; + for (const char* s : ss) { + auto& code = get_attr(s); + for (uint i=0; i= code.size()) break; + uint j=i+1; + while (j& c) { + c->find_used(def, used); + }); +} + +vector KernelIR::find_used() { + vector used; + if (father) father->find_used(this, used); + return used; +} + + +bool KernelIR::remove_unused() { + bool has_unused = false; + for_each_rev([&](unique_ptr& c) { + has_unused |= c->remove_unused(); + }); + if (type=="define" && check_attr("used", "")) { + LOGvvvv << "Remove unused value:" << attrs["lvalue"]; + erase(); + return true; + } + return has_unused; +} + +void KernelIR::remove_all_unused() { + while (1) { + check_unused(); + if (!remove_unused()) + break; + } +} + +void KernelIR::remove_intermediate(const unordered_set& names) { + const char* ss[] = {"code", "lvalue", "rvalue"}; + bool need_re_parse = false; + for (const char* s : ss) { + auto& code = get_attr(s); + for (uint i=0; i= code.size()) break; + uint j=i+1; + while (j xxxd + if (k>=i && code[k]=='p' && names.count(code.substr(i,k-i))) { + code[k] = 'd'; + for (uint l=k+1; l auto xxxd = xxx + code = "auto " + code; + need_re_parse = true; + j += 5; + } + } + i = k+1; + continue; + } else + if (code[k] == 'p' && string(s)=="lvalue" && type=="define") { + if (names.count(code.substr(i,k-i))) { + erase(); + return; + } + } else + if (code[k] == 'p' && string(s)=="code" && type=="") { + if (names.count(code.substr(i,k-i))) { + // xxxp -> 0 + for (uint l=i; l& c) { + c->remove_intermediate(names); + }); +} + + +void KernelIR::split_loop(int i, int j) { + if (type=="loop" && check_attr("loop_id", S(i))) { + auto sj = S(j); + auto si = S(i); + auto& dtype = get_attr("dtype"); + auto& rvalue2 = get_attr("rvalue2"); + auto& lvalue = get_attr("lvalue"); + auto c = move_out_children(); + // set stride of loop i + rvalue2 = "stride" + si; + // change lvalue++ -> lvalue+=rvalue2 + ASSERT(inner.size()>=3); + ASSERT(lvalue == "id"+si); + inner[2]->attrs["code"] = lvalue+"+="+rvalue2+";"; + push_back("for ("+dtype+" id"+sj+"=0; id"+sj+"attrs["loop_id"] = sj; + sloop->attrs["split_id"] = si; + sloop->insert(0, c); + sloop->replace({{"id"+si, "(id"+si+"+id"+sj+")"}}, true); + return; + } + for_each([&](unique_ptr& c) { + c->split_loop(i,j); + }); +} + +void KernelIR::resplit() { + if (has_attr("resplited")) return; + ASSERT(type=="loop"); + attrs["resplited"] = "1"; + auto& rvalue2 = get_attr("rvalue2"); + auto& lvalue = get_attr("lvalue"); + auto& rvalue = get_attr("rvalue"); + auto& dtype = get_attr("dtype"); + ASSERT(rvalue2.size()); + ASSERT(inner.size()>3 && startswith(inner[3]->get_attr("lvalue"), "range")); + ASSERT(startswith(inner[3]->get_attr("rvalue"), "::min")) << + "No need to resplit"; + + // delete prev inner code(init and condition) + inner[0]->erase(); + inner[0]->erase(); + // condition + push_front(lvalue+"+"+rvalue2+"<="+rvalue+";", &inner); + // init + push_front(lvalue+"=0;", &inner); + // define + push_front(dtype+" "+lvalue+" = 0;", &before); + + int num=0; + if (get_number(rvalue2, num)) { + // range = number; + inner[3]->attrs["rvalue"] = S(num); + } else { + ASSERT(num == -2); + // range = rvalue2; + inner[3]->attrs["rvalue"] = rvalue2; + } + // add if and clone children + push_back("if ("+lvalue+"<"+rvalue+") {}", &after); + after.back()->push_back(dtype+" "+inner[3]->get_attr("lvalue")+" = "+rvalue+"-"+lvalue+";"); + for (auto &c : children) + after.back()->push_back(c->clone()); +} + +} // jittor diff --git a/python/jittor/src/opt/kernel_ir.h b/python/jittor/src/opt/kernel_ir.h new file mode 100644 index 00000000..ea5a7fe5 --- /dev/null +++ b/python/jittor/src/opt/kernel_ir.h @@ -0,0 +1,232 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "utils/str_utils.h" + +namespace jittor { + +struct KernelIR { + // if type == define + // src is dtype lvalue = rvalue; + // if type == loop + // src is for (inner[0]; inner[1]; inner[2]) + // if type == if + // src is if (inner[0]) + // if type == func + // src is dtype lvalue(*inner) + // else + // src is code + string type; + KernelIR* father=nullptr; + vector>* flist; + // available attrs: + // * code: used in macro, comment type + // * dtype: used in define, func type + // * lvalue: used in define, func type + // * rvalue: used in define type + // * loop_id: generate by rename_loop_index + // * split_id: generate by split_loop + // * used: generate by check_unused + // * rvalue2: generate by split_loop, used in loop type, represent stride + unordered_map attrs; + // before... + // src { + // rest of inner... + // children... + // } + // after... + vector> before, inner, children, after; + unordered_map> scope; + + KernelIR() {} + KernelIR(const string& src, bool raw=false); + string& get_attr(const string& s); + bool has_attr(const string& s); + bool check_attr(const string& k, const string& v); + + // src: source for kernel ir + // irp: kernel ir + // raw: raw code or not + // ls: can be [before, inner, children, after], default is children + void push_back(const string& src, vector>* ls=nullptr, bool raw=false); + void push_front(const string& src, vector>* ls=nullptr, bool raw=false); + void push_back(unique_ptr&& irp, vector>* ls=nullptr); + void push_front(unique_ptr&& irp, vector>* ls=nullptr); + + // insert kernel ir (src) to pos + void insert(uint pos, const string& src, bool raw=false); + // insert kernel irs to pos + void insert(uint pos, vector>& irs); + // recursive clone kernel ir(type, attrs, before, inner, after) + // with_children: recursively clone children or not + unique_ptr clone(bool with_children=true); + // swap two loop + // with_children: swap children or not + void swap(KernelIR& other, bool with_children=false); + + // add into parent scope + void add_scope(); + // delete in parent scope + void del_scope(); + // clear and reconstruct scope + void rebuild_scope(); + // self destroy + void erase(); + // clear self(attr, inner), preserve children, before, after + void clear(); + // move out self from father + unique_ptr move_out(); + // move out all children + vector> move_out_children(); + + // try pase define statement from string s + // if failed, fail back to a normal code + void try_parse_define(const string& s); + + // parse syntax "for (dtype lvalue = 0; lvalue find_loops(string lid); + // find definition from ancestors + KernelIR* find_define(const string& name); + // for each sub ir, include before, inner, children, after + template + void for_each(Func&& func) { + vector>* ls[] = {&before, &inner, &children, &after}; + for (auto& l : ls) { + for (auto& c : (*l)) + func(c); + } + } + template + void dfs(Func&& func) { + vector>* ls[] = {&before, &inner, &children, &after}; + for (auto& l : ls) { + for (auto& c : (*l)) { + func(c); + c->dfs(func); + } + } + } + // for each sub ir backward, include before, inner, children, after + template void for_each_rev(Func&& func); + // update sub irs' father to itself + void update_father(); + + // recursively to_string + // level: indent level + // debug: output type, attrs, scope in comments, check father + string to_string(int level=0, bool debug=false); + + // move all loop back(exclude raw loop) + void move_loop_back(); + + // replace vars + // replace_vars: pairs of string, e.g. [(a,b), (x,y)] replace a->b, x->y + // equal: if true, replace vars need to match completely, if false, replace vars can match the prefix + // remove_define: if a definition statement matched, remove or not + // would't remove if inside a loop or is in inner list + void replace(const vector>& replace_vars, bool equal=false, bool remove_define=true); + + // recursively rename loop var by loop id, loop_id is parsed from rvalue + // for (dtype i=lvalue; i + // for (dtype id{loop_id}; id{loop_id} + // for (...) { s1; s2; } + void merge_loop(); + + // recursively expand block if no attr[loop_id] and attr[rvalue] + // { s1; s2; } + // --> + // s1; s2; + void expand_empty_block(); + + // recursively resolve conlict definitions + // T a = 1; a++; T a = 2; a++; + // --> + // T a = 1; a++; T _a = 2; _a++; + void solve_conflict_define(); + + // TODO: move to pass + // remove intermediate variables in names + // xxxp[...] -> xxxd + // xxxd = xxx -> auto xxxd = xxx + // xxxp -> 0 + void remove_intermediate(const unordered_set& names); + + // remove definitions which attr[used]==1, return remove or not + bool remove_unused(); + + // remove all unused definitions, until no unused definition occurs. + void remove_all_unused(); + + // recursively generate attr[used] + void check_unused(); + // recursively find used + void find_used(KernelIR* def, vector& used); + vector find_used(); + + // split loop(loop_id=i) into two loop + // for (T id{i}; id{i} + // for (T id{i}; id{i}father == self + void check_father(); +}; + +std::ostream& operator<<(std::ostream& os, KernelIR& ir); + +// match aaa::bbb +bool isvar(char x); + +// match x[y] +bool isvarp(char x); + +// remove arg_i-th arguments of func_call +// src: func_call source +// arg_i: arguments id +void remove_func_call_arg(string& src, int arg_i); + +} // jittor diff --git a/python/jittor/src/opt/pass/assume_aligned_pass.cc b/python/jittor/src/opt/pass/assume_aligned_pass.cc new file mode 100644 index 00000000..5cde114d --- /dev/null +++ b/python/jittor/src/opt/pass/assume_aligned_pass.cc @@ -0,0 +1,52 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "op_compiler.h" +#include "mem/allocator.h" +#include "opt/pass_manager.h" +#include "opt/pass/assume_aligned_pass.h" +#include "executor.h" + +namespace jittor { + +void AssumeAlignedPass::run() { + if (!op->get_loop_option("compile_shapes")) return; + ir->push_front("#define assume_aligned(ptr) (void)(__builtin_assume_aligned(ptr, alignment))", &ir->before); + auto check = [&](KernelIR* func) { + if (func->type != "func") + return; + vector>* ls[] = {&func->inner, &func->children}; + for (auto& l : ls) + for (auto& c : (*l)) { + if (c->type != "define") continue; + auto& lvalue = c->get_attr("lvalue"); + // if is a var pointer + if (startswith(lvalue, "op") && endswith(lvalue, "p")) { + string name = lvalue.substr(0, lvalue.size()-1); + uint op_id, opvar_id; + Op* op; + Var* var; + pm->oc->get_op_var_by_name(name, op_id, opvar_id, op, var); + // add assume_aligned if is aligned_allocator + if (exe.allocator->is_aligned()) { + // if is a function arguments + if (l == ls[0]) + func->push_front("assume_aligned("+lvalue+");"); + else + c->push_back("assume_aligned("+lvalue+");", &c->after); + } + } + } + + }; + check(ir); + for (auto& c : ir->before) + check(c.get()); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/assume_aligned_pass.h b/python/jittor/src/opt/pass/assume_aligned_pass.h new file mode 100644 index 00000000..9794460b --- /dev/null +++ b/python/jittor/src/opt/pass/assume_aligned_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct AssumeAlignedPass : Pass { + AssumeAlignedPass() : Pass("assume_aligned") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/atomic_tuner_pass.h b/python/jittor/src/opt/pass/atomic_tuner_pass.h new file mode 100644 index 00000000..456a8389 --- /dev/null +++ b/python/jittor/src/opt/pass/atomic_tuner_pass.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct AtomicTunerPass : Pass { + AtomicTunerPass() : Pass("atomic") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/check_cache_pass.cc b/python/jittor/src/opt/pass/check_cache_pass.cc new file mode 100644 index 00000000..80181575 --- /dev/null +++ b/python/jittor/src/opt/pass/check_cache_pass.cc @@ -0,0 +1,152 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/check_cache_pass.h" +#include +#include +#include "profiler/memory_checker.h" +using namespace std; + +namespace jittor { + +void CheckCachePass::run() { + auto choice = op->get_loop_option("check_cache"); + + if (!choice) return; + + /* + input: + single simple assignment like: + a[x] = b[y] + c[z]; + or read only like: + f(a[x]); + output: (read_addr_list, write_addr_list) + */ + auto get_read_write_address = [&](string code) -> pair, vector> { + vector assignment_list({"[^><=!]=[^=]", "\\+=", "-=", "\\*=", "/=", "%=", ">>=", "<<=", "&=", "\\|=", "\\^="}); + vector read_addr, write_addr; + string pa; + for (int i = 0; i < (int)assignment_list.size(); ++i) { + if (i > 0) { + pa += "|"; + } + pa += assignment_list[i]; + } + regex pattern(pa); + int assignment_cnt = 0, assignment_pos = -1; + string temp_code = code; + smatch m; + while (regex_search(temp_code, m, pattern)) { + assignment_pos = m.position(0); + ++assignment_cnt; + temp_code = m.suffix().str(); + } + ASSERT(assignment_cnt <= 1); // for simple assignment only + vector address_pos; + for (int i = 0; i < (int)code.length(); ++i) { + if (code[i] == '[') { + address_pos.push_back(i); + } + if (code[i] == ']') { + int sp = address_pos.back() - 1; + // don't check cache of shape[...] + if (sp>=4 && code.substr(sp-4, 5) == "shape") { + address_pos.pop_back(); + continue; + } + while (sp >= 0 && ((code[sp] >= 'A' && code[sp] <= 'Z') || (code[sp] >= 'a' && code[sp] <= 'z') || + (code[sp] >= '0' && code[sp] <= '9') || code[sp] == '_' || code[sp] == '.' || (sp > 0 && code[sp] == '>' && code[sp - 1] == '-'))) { + if (sp > 0 && code[sp] == '>' && code[sp - 1] == '-') + sp -= 2; + else + --sp; + } + ++sp; + string s = "(size_t)&(" + code.substr(sp, i - sp + 1) + ")"; + if (i <= assignment_pos) + write_addr.push_back(s); + else + read_addr.push_back(s); + address_pos.pop_back(); + } + } + return make_pair(read_addr, write_addr); + }; + size_t page_size = op->get_loop_option("page_size"), vtop = op->get_loop_option("vtop"), + tlb_size = op->get_loop_option("tlb_size"), tlb_ways = op->get_loop_option("tlb_ways"), tlb_line_size = op->get_loop_option("tlb_line_size"), + L1_size = op->get_loop_option("L1_size"), L1_ways = op->get_loop_option("L1_ways"), L1_line_size = op->get_loop_option("L1_line_size"), + L2_size = op->get_loop_option("L2_size"), L2_ways = op->get_loop_option("L2_ways"), L2_line_size = op->get_loop_option("L2_line_size"), + L3_size = op->get_loop_option("L3_size"), L3_ways = op->get_loop_option("L3_ways"), L3_line_size = op->get_loop_option("L3_line_size"); + + ir->push_back("#include \"profiler/memory_checker.h\"", &ir->before); + ir->push_back("using namespace jittor;", &ir->before); + // declaration + ir->push_back("EXTERN_LIB \"C\" std::unique_ptr memory_checker;", &ir->before); + // definition + ir->push_back("std::unique_ptr memory_checker;", &ir->before); + vector commands; + stringstream command; + string replace_strategy = MemoryChecker::get_replace_strategy(op->get_loop_option("replace_strategy")); + + command << "Cache* tlb = new " << replace_strategy << "(CacheConfig(" << tlb_size << "," << tlb_ways << "," << tlb_line_size << "));"; + commands.push_back(command.str()); + command.str(""); + command << "Cache* L1 = new " << replace_strategy << "(CacheConfig(" << L1_size << "," << L1_ways << "," << L1_line_size << "));"; + commands.push_back(command.str()); + command.str(""); + command << "Cache* L2 = new " << replace_strategy << "(CacheConfig(" << L2_size << "," << L2_ways << "," << L2_line_size << "));"; + commands.push_back(command.str()); + command.str(""); + command << "Cache* L3 = new " << replace_strategy << "(CacheConfig(" << L3_size << "," << L3_ways << "," << L3_line_size << "));"; + commands.push_back(command.str()); + command.str(""); + commands.push_back("vector caches({L1, L2, L3});"); + commands.push_back("memory_checker.reset(new MemoryChecker(tlb, caches, "+S(page_size)+","+S(vtop)+"));"); + + while (commands.size()) { + ir->push_front(commands.back(), &ir->children, true); + commands.pop_back(); + } + vector q({ir}); + vector attrs_to_check{"code", "rvalue"}; + for (uint i=0; ifor_each([&](unique_ptr& c) { + q.push_back(c.get()); + }); + + vector codes_to_check; + for (auto& attr : attrs_to_check) { + if (!ir->has_attr(attr)) continue; + auto& code = ir->attrs[attr]; + codes_to_check.push_back(code); + } + for (int j = 0; j < (int)ir->inner.size(); ++j) { + codes_to_check.push_back(ir->inner[j]->to_string()); + } + for (int j = 0; j < (int)codes_to_check.size(); ++j) { + string code = codes_to_check[j]; + pair, vector> rw_list = get_read_write_address(code); + for (int k = 0; k < (int)rw_list.first.size(); ++k) { + string addr = rw_list.first[k]; + ir->push_back("memory_checker->check_hit(" + addr + ");", &ir->before); + } + for (int k = 0; k < (int)rw_list.second.size(); ++k) { + string addr = rw_list.second[k]; + ir->push_back("memory_checker->check_hit(" + addr + ");", &ir->before); + } + } + } + //ir->push_back("memory_checker->print_miss();", &ir->children); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/check_cache_pass.h b/python/jittor/src/opt/pass/check_cache_pass.h new file mode 100644 index 00000000..f3fc364d --- /dev/null +++ b/python/jittor/src/opt/pass/check_cache_pass.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct CheckCachePass : Pass { + CheckCachePass() : Pass("check_cache") {}; + void run() override; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/compile_shapes_pass.cc b/python/jittor/src/opt/pass/compile_shapes_pass.cc new file mode 100644 index 00000000..82634d36 --- /dev/null +++ b/python/jittor/src/opt/pass/compile_shapes_pass.cc @@ -0,0 +1,40 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "op_compiler.h" +#include "opt/pass_manager.h" +#include "opt/pass/compile_shapes_pass.h" + +namespace jittor { + +void CompileShapesPass::run() { + if (!op->get_loop_option("compile_shapes")) return; + for (auto& c : ir->children) { + if (c->type != "define") continue; + auto& rvalue = c->get_attr("rvalue"); + // T range = op{i}_{vnamr}->shape[j]; + // j i + if (!startswith(rvalue, "op") || rvalue.back() != ']') + continue; + uint i=rvalue.size()-2; + while (i && isdigit(rvalue[i])) i--; + ASSERT(rvalue[i] == '[' && i>7); + uint j = i-7; + ASSERT(startswith(rvalue, "->shape[", j)); + string name = rvalue.substr(0, j); + uint op_id, opvar_id; + Op* op; + Var* var; + pm->oc->get_op_var_by_name(name, op_id, opvar_id, op, var); + int shapeid = std::stoi(rvalue.substr(i+1, rvalue.size()-i-2)); + ASSERT(shapeid < (int)var->shape.size()); + rvalue = S(var->shape[shapeid]); + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/compile_shapes_pass.h b/python/jittor/src/opt/pass/compile_shapes_pass.h new file mode 100644 index 00000000..974902b1 --- /dev/null +++ b/python/jittor/src/opt/pass/compile_shapes_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct CompileShapesPass : Pass { + CompileShapesPass() : Pass("compile_shapes") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/const_var_pass.cc b/python/jittor/src/opt/pass/const_var_pass.cc new file mode 100644 index 00000000..11bb2d52 --- /dev/null +++ b/python/jittor/src/opt/pass/const_var_pass.cc @@ -0,0 +1,54 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "opt/expr.h" +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/const_var_pass.h" +#include "ops/array_op.h" +#include "jit_key.h" + +namespace jittor { + +using namespace expr; + +void ConstVarPass::run() { + JK& jk = get_jk(); + int changed = 0; + for (int i=0; iops.size(); i++) { + auto opi = op->ops[i]; + if (opi->name() != string("array")) + continue; + string s; + auto* v = opi->output(0); + if (v->num != 1) + continue; + auto array_op = (ArrayOp*)opi; + jk.clear(); + array_op->jit_prepare(jk); + if (jk.to_string().find("[o:") == string::npos) + continue; + if (v->dtype() == ns_int32) { + s = S(array_op->ptr()[0]); + } else + if (v->dtype() == ns_float32) { + s = S(array_op->ptr()[0]); + } else + continue; + auto def = ir->find_define("op"+S(i)+"_outputd"); + ASSERT(def); + def->attrs["dtype"] = v->dtype().to_cstring(); + def->attrs["rvalue"] = s; + changed ++; + LOGvvvv << def->to_string(); + } + if (changed) { + ir->remove_all_unused(); + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/const_var_pass.h b/python/jittor/src/opt/pass/const_var_pass.h new file mode 100644 index 00000000..536787aa --- /dev/null +++ b/python/jittor/src/opt/pass/const_var_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct ConstVarPass : Pass { + ConstVarPass() : Pass("const_var_pass") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/expand_empty_block_pass.cc b/python/jittor/src/opt/pass/expand_empty_block_pass.cc new file mode 100644 index 00000000..3070a879 --- /dev/null +++ b/python/jittor/src/opt/pass/expand_empty_block_pass.cc @@ -0,0 +1,44 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/expand_empty_block_pass.h" + +namespace jittor { + +void check_empty_block(KernelIR* ir) { + for (uint i=0; ichildren.size(); i++) { + auto loop = ir->children[i].get(); + if (loop->type != "loop") continue; + if (loop->has_attr("loop_id")) { + continue; + } + if (loop->has_attr("rvalue")) + continue; + ir->insert(i+1, "for (int _=0; _<1; _++) {}"); + ir->children[i+1]->insert(0, loop->children); + // use children[i] instead of loop + ir->children[i]->erase(); + i--; + } +} + +void ExpandEmptyBlockPass::run() { + check_empty_block(ir); + ir->expand_empty_block(); +} + +JIT_TEST(check_empty_block) { + KernelIR ir("x=1;{a=1;}y=1;"); + check_empty_block(&ir); + ASSERT(ir.children[1]->attrs.at("lvalue")=="_"); + ir.move_loop_back(); + ASSERT(ir.children[2]->attrs.at("lvalue")=="_"); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/expand_empty_block_pass.h b/python/jittor/src/opt/pass/expand_empty_block_pass.h new file mode 100644 index 00000000..76c9a499 --- /dev/null +++ b/python/jittor/src/opt/pass/expand_empty_block_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct ExpandEmptyBlockPass : Pass { + ExpandEmptyBlockPass() : Pass("expand_empty_block") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/fake_main_pass.cc b/python/jittor/src/opt/pass/fake_main_pass.cc new file mode 100644 index 00000000..847ccd69 --- /dev/null +++ b/python/jittor/src/opt/pass/fake_main_pass.cc @@ -0,0 +1,165 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/fake_main_pass.h" + +namespace jittor { + +void FakeMainPass::run() { + // if this op is relayed, we don't run fake main pass + for (auto& o : op->ops) + if (!pm->oc->op_exist(o)) + return; + + // TODO: fake_main only supported when compile_shapes + if (!op->get_loop_option("jtune")) + return; + all->push_back("#include "); + all->push_back("#include "); + all->push_back("using namespace std;"); + all->push_back("using namespace jittor;"); + if (op->flags.get(NodeFlags::_cpu)) { + all->push_back("void* _fake_alloc(size_t size) {\n" + "return aligned_alloc(alignment, size);\n" + "}", nullptr, true); + } else { + all->push_back("void* _fake_alloc(size_t size) {\n" + "char* ptr;\n" + "checkCudaErrors(cudaMallocManaged((void**)&ptr, sizeof(char)*size));\n" + "return (void*) ptr;\n" + "}", nullptr, true); + } + all->push_back("int64_t _getenv(const char* name, int64_t _default) {\n" + "auto* v = getenv(name);\n" + "return v?stoll(v):_default;\n" + "}", nullptr, true); + all->push_back("void output_float(const string& scale, int base, const string& suffix, double k) {\n" + "uint w=10, p=3;\n" + "cout << ' ' << std::setw(w-2-suffix.size());\n" + "cout << std::setprecision(p);\n" + "uint i=0;\n" + "for (; i+1push_back("extern \"C\" void fake_main() {\n" + "cout << \"Enter fake_main entry.\" << endl;\n" + "#define fake_new(T) ((T*)(new char[sizeof(T)]()))\n" + "auto* op = fake_new(FusedOp);\n" + "auto& ops = op->ops;\n" + "Var* var;" + "}", nullptr, true); + auto& main = all->children.back(); + // fake ops + for (uint i=0; iops.size(); i++) { + auto* opi = op->ops[i]; + string name = opi->name(); + string name2 = Op::op_name_to_file_name(name); + string name3 = Op::file_name_to_class_name(name2); + main->push_back( + "ops.push_back(fake_new("+name3+"Op));" + ); + if (name3=="Array") { + main->push_back( + "{\n" + "auto ptr = new double[1];\n" + "((ArrayOp*)(ops["+S(i)+"]))->allocation.ptr = (void*)ptr;\n" + "ptr[0] = 0;\n" + "}\n" + ); + main->push_back("ArrayOp* op"+S(i)+"=((ArrayOp*)(ops["+S(i)+"]));"); + } + } + // fake vars + map var_map; + for (auto& c : ir->children) { + if (c->type != "define") continue; + auto& name = c->attrs["lvalue"]; + auto& rvalue = c->attrs["rvalue"]; + uint op_id, var_id; + Op* op; + Var* var; + try { + pm->oc->get_op_var_by_name(name, op_id, var_id, op, var); + } catch (...) { + continue; + } + // build fake var + auto vec_to_str = [](const NanoVector& v) -> string { + std::stringstream ss; + ss << '{'; + for (uint i=0; ipush_back(rvalue+" = "+var_map[(size_t)var]+";", nullptr, true); + continue; + } + var_map[(size_t)var] = rvalue; + main->push_back("{\n" + +rvalue+"= var = fake_new(Var);\n" + "var->flags.flags = "+S(var->flags.flags)+";\n" + "var->shape = "+vec_to_str(var->shape)+";\n" + "var->size = "+S(var->size)+";\n" + "var->num = "+S(var->num)+";\n" + "var->mem_ptr = _fake_alloc(var->size);\n" + "}", nullptr, true); + } + uint64_t in, out, compute; + op->statistics(in, out, compute); + string need_sync = op->flags.get(NodeFlags::_cuda) ? "checkCudaErrors(cudaDeviceSynchronize());\n" : ""; + main->push_back("{\n" + "auto warmup = _getenv(\"warmup\", 2);\n" + "auto rerun = _getenv(\"rerun\", 10);\n" + "int loop = "+S(op->get_loop_option("insert_profile_loop")?10:0)+";\n" + "warmup = warmup ? std::max(warmup>>loop, (int64_t)1) : 0;\n" + "rerun = std::max((rerun+1)>>loop, (int64_t)1);\n" + "int64_t in = "+S(in)+";\n" + "int64_t out = "+S(out)+";\n" + "int64_t compute = "+S(compute)+";\n" + "int64_t count=0, time_max=0, time_min=1ll<<62, time_total=0;\n" + "int64_t in_total=0, out_total=0, compute_total=0;\n" + "for (int64_t i=0; ijit_run();\n" + +need_sync+ + "for (int64_t i=0; ijit_run();\n" + +need_sync+ + "auto finish = std::chrono::high_resolution_clock::now();\n" + "auto total_ns = (int64_t)std::chrono::duration_cast(finish-start).count();\n" + "// 24ns function call overhead\n" + "total_ns = std::max((int64_t)1, total_ns-24);\n" + "count += 1<>loop);\n" + "time_min = std::min(time_min, total_ns>>loop);\n" + "time_total += total_ns;\n" + "in_total += in<. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct FakeMainPass : Pass { + FakeMainPass() : Pass("fake_main") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/float_atomic_fix_pass.cc b/python/jittor/src/opt/pass/float_atomic_fix_pass.cc new file mode 100644 index 00000000..fa823225 --- /dev/null +++ b/python/jittor/src/opt/pass/float_atomic_fix_pass.cc @@ -0,0 +1,90 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/expr.h" +#include "opt/pass_manager.h" +#include "opt/pass/float_atomic_fix_pass.h" +#include "utils/str_utils.h" + +namespace jittor { + +void FloatAtomicFixPass::run() { + auto choice = op->get_loop_option("parallel"); + bool is_cuda = op->flags.get(NodeFlags::_cuda); + if (is_cuda) choice=1; + if (!choice) return; + + unordered_map fixed; + auto fix_float_atomic = [&](string name, Var* v) { + if (fixed.count(name)) return; + fixed[name] = 1; + string namep = name+"p"; + ir->dfs([&](unique_ptr& i) { + if (!i->has_attr("code")) return; + auto& code = i->attrs["code"]; + if (!startswith(code, namep)) return; + LOGvvvv << "find code" << code; + auto src = expr::make(code); + auto target = expr::make(namep+"[b]=c"); + vector> results; + if (!expr::match(src.get(), target.get(), {"b","c"}, {}, results)) + return; + // fix code a[b] = c --> + // a[b] = __int_as_float(floatToOrderedInt(c)) + string new_code; + if (v->dtype() == ns_float32) + new_code = namep+'['+results.at(0)->to_string(true)+ + "] = __int_as_float(floatToOrderedInt(" + + results.at(1)->to_string(true) + "));"; + else + new_code = namep+'['+results.at(0)->to_string(true)+ + "] = __longlong_as_double(floatToOrderedInt(" + + results.at(1)->to_string(true) + "));"; + LOGvvvv << "prev code" << code >> "\nreplace:" << new_code; + code = new_code; + }); + ir->push_back("fix_float("+namep+", "+name+"->num);"); + }; + + ir->dfs([&](unique_ptr& i) { + if (!i->has_attr("code")) return; + auto& code = i->attrs["code"]; + const char* m = nullptr; + if (startswith(code, "cuda_atomic_min")) + m = "cuda_atomic_min"; + else if (startswith(code, "cuda_atomic_max")) + m = "cuda_atomic_max"; + if (!m) return; + LOGvvvv << "find match" << m << i; + vector> results; + auto target = expr::make(string(m)+"(&x[y], z)"); + auto src = expr::make(code); + if (!expr::match(src.get(), target.get(), {"x","y","z"}, {}, results)) + return; + LOGvvvv << "match results" << results; + uint op_id; uint opvar_id; Op* op; Var* var; + string s = results.at(0)->to_string(); + if (s.rbegin()[0] != 'p') return; + s = s.substr(0, s.size()-1); + try { + pm->oc->get_op_var_by_name(s, op_id, opvar_id, op, var); + } catch (...) { + return; + } + if (!var->dtype().is_float()) return; + if (var->dtype() == ns_float16 || var->dtype() == ns_bfloat16) + // float16 use atomicCAS, because no float16 atomicMax + return; + LOGvvvv << "find var" << var << "op" << op; + fix_float_atomic(s, var); + }); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/float_atomic_fix_pass.h b/python/jittor/src/opt/pass/float_atomic_fix_pass.h new file mode 100644 index 00000000..b82419bf --- /dev/null +++ b/python/jittor/src/opt/pass/float_atomic_fix_pass.h @@ -0,0 +1,19 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct FloatAtomicFixPass : Pass { + FloatAtomicFixPass() : Pass("float_atomic_fix") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/insert_profile_loop_pass.cc b/python/jittor/src/opt/pass/insert_profile_loop_pass.cc new file mode 100644 index 00000000..bfa50896 --- /dev/null +++ b/python/jittor/src/opt/pass/insert_profile_loop_pass.cc @@ -0,0 +1,39 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/insert_profile_loop_pass.h" + +namespace jittor { + +void InsertProfileLoopPass::run() { + if (!op->get_loop_option("insert_profile_loop")) return; + int loopend = ir->children.size()-1; + auto check_loop = [](unique_ptr& c) -> bool { + return c->type == "loop" || c->has_attr("loop_func"); + }; + while (loopend>=0 && !check_loop(ir->children[loopend])) + loopend--; + if (loopend<0) { + LOGw << "Loop body not found, profile loop cannot insert."; + return; + } + int loopid = loopend; + while (loopid>0 && check_loop(ir->children[loopid-1])) + loopid--; + vector> loops(loopend-loopid+1); + for (int i=loopend, j=loops.size()-1; i>=loopid; i--, j--) + loops[j] = ir->children[i]->move_out(); + + ir->insert(loopid, "for (int _=0; _<1024; _++) {}"); + auto& loop = ir->children[loopid]; + loop->push_back("__asm__ __volatile__ (\"\": : : \"memory\");"); + loop->insert(1, loops); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/insert_profile_loop_pass.h b/python/jittor/src/opt/pass/insert_profile_loop_pass.h new file mode 100644 index 00000000..70e7694c --- /dev/null +++ b/python/jittor/src/opt/pass/insert_profile_loop_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct InsertProfileLoopPass : Pass { + InsertProfileLoopPass() : Pass("insert_profile_loop") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/loop_to_func_pass.cc b/python/jittor/src/opt/pass/loop_to_func_pass.cc new file mode 100644 index 00000000..3cf6d263 --- /dev/null +++ b/python/jittor/src/opt/pass/loop_to_func_pass.cc @@ -0,0 +1,112 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/loop_to_func_pass.h" + +namespace jittor { + +DECLARE_FLAG(string, cc_type); + +void LoopToFuncPass::run() { + auto choice = op->get_loop_option("parallel"); + bool is_cuda = op->flags.get(NodeFlags::_cuda); + if (is_cuda) choice=1; + if (cc_type=="clang") choice=1; + if (!choice) return; + int func_num=0; + string hash_name = op->get_hash_name(); + + ir->push_back("using namespace jittor;", &ir->before); + if ((cc_type=="icc" || cc_type=="g++") && choice) + // icc will failed if not inline when parallel + ir->push_back("#define INLINE_FUNC inline static void ", &ir->before); + else + ir->push_back("#define INLINE_FUNC __attribute__((always_inline)) static void ", &ir->before); + for (uint i=0; ichildren.size(); i++) { + auto& c = ir->children[i]; + if (c->type != "loop") continue; + if (c->has_attr("vectorized") || c->has_attr("unrolled") || c->has_attr("resplited")) + continue; + if (c->before.size()) + continue; + if (c->inner.size() < 3) + continue; + if (!c->has_attr("lvalue")) + continue; + if (c->has_attr("raw")) + continue; + + // func definition + ir->push_back("INLINE_FUNC func_"+hash_name+"_"+S(func_num++)+"() {}", &ir->before); + auto& func = ir->before.back(); + + // generate function arguments + vector args; + for (auto& d : ir->children) { + if (d->has_attr("raw")) continue; + if (d->type == "loop") break; + if (d->has_attr("code") && startswith(d->attrs["code"], "func")) break; + if (d->type == "define") { + if (d->has_attr("rvalue")) { + auto& rvalue = d->attrs["rvalue"]; + auto& dtype = d->attrs["dtype"]; + if (endswith(d->attrs["lvalue"], "_value") || + endswith(d->attrs["lvalue"], "_outputv")) { + args.push_back(d.get()); + continue; + } + if (rvalue.find("ops") != string::npos) + continue; + if (dtype=="Var*") + continue; + if (dtype=="Op*") + continue; + if (rvalue.find("->") != string::npos || + dtype.find("*") != string::npos) { + args.push_back(d.get()); + continue; + } + } + } + func->push_back(d->clone()); + } + func->push_back(c->clone()); + string func_call = func->attrs["lvalue"]+"("; + for (auto arg : args) { + if (arg != args.front()) + func_call += ','; + auto dtype = arg->attrs["dtype"]; + auto& lvalue = arg->attrs["lvalue"]; + auto& rvalue = arg->attrs["rvalue"]; + if (startswith(dtype, "auto")) { + // resolve auto + if (rvalue.find("<") == -1 || rvalue.find(">") == -1) { + //resolve auto xxx = ((T*)xxx)[0]; + std::vector temp = split(split(rvalue, "*)", 2).at(0), "(", 0); + dtype = temp[temp.size() - 1] + dtype.substr(4); + } else { + dtype = split(split(rvalue, "<", 2).at(1), ">", 2).at(0) + dtype.substr(4); + } + } + func_call += arg->attrs["lvalue"]; + func->push_back(dtype+" "+lvalue+";", &func->inner); + } + func_call += ");"; + c->erase(); + ir->insert(i, func_call); + + auto& fc = ir->children[i]; + fc->attrs["loop_func"] = func->attrs["lvalue"]; + } + #ifdef __APPLE__ + ir->remove_all_unused(); + #endif +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/loop_to_func_pass.h b/python/jittor/src/opt/pass/loop_to_func_pass.h new file mode 100644 index 00000000..2200f997 --- /dev/null +++ b/python/jittor/src/opt/pass/loop_to_func_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct LoopToFuncPass : Pass { + LoopToFuncPass() : Pass("loop_to_func") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/loop_var_analyze_pass.cc b/python/jittor/src/opt/pass/loop_var_analyze_pass.cc new file mode 100644 index 00000000..ad9b0e8e --- /dev/null +++ b/python/jittor/src/opt/pass/loop_var_analyze_pass.cc @@ -0,0 +1,313 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/loop_var_analyze_pass.h" +#include "ops/reduce_op.h" +#include "ops/broadcast_to_op.h" +#include "ops/reindex_op.h" + +namespace jittor { + +DEFINE_FLAG(int, para_opt_level, 3, "para_opt_level"); + +void LoopVarAnalyzePass::run() { + // loop_vars: opi_xx->shape[j] + vector loop_vars; + // we use input var of reduce op as the loop var + // TODO: consider reshape op + auto& vars = op->vars; + bool has_reduce = false, has_element = false; + bool has_op = false; + for (Op* op : this->op->ops) { + auto& op_members = this->pm->oc->op_members; + // TODO: fix it + // ugly temp fix for index_var + auto opid = this->op->get_node_id(op); + if (op->name()==string("index") && + op->inputs().size()+op->outputs().size() != op_members[opid].size()) { + op_members[opid].insert(op_members[opid].begin(), "wtf"); + } + } + // LoopVarAnalyzePass has three steps: + // 1. Find the appropriate variable and use its shape as loop variables. + // Those loop vars are store in "vector loop_vars", + // e.g. op0_x->shape[0], op0_x->shape[1], op0_x->shape[2], ... + // 2. Replace the loop variable of the fused op with those loop variable. + // + // For example, previous code is : + // + // index_t op0_xshape0 = op0_x->shape[0]; + // index_t op0_xshape1 = op0_x->shape[1]; + // index_t op0_xshape2 = op0_x->shape[2]; + // for (index_t op0_i0=0; op0_i0shape[0]; + // index_t op1_yshape1 = op1_y->shape[1]; + // index_t op1_yshape2 = op1_y->shape[2]; + // for (index_t op1_y0=0; op1_y0shape[0]; + // index_t range1 = op0_x->shape[1]; + // index_t range2 = op0_x->shape[2]; + // for (index_t op0_i0=0; op0_i0 op2, op2's input is an alias of op1's output, + // Suppose the input of op2 is op2_x, the output of op1 is op1_y + // we replace op2_x with op1_y + + // TODO: find loop range in better way + // we pick loop var from below priority: + // 1. reduce input + // 2. element input + // 3. broadcast output + + // ugly fix multi different dim element input + // (caused by force fused array op) + int max_elm_dim = 0; + int64 max_elm_size = 0; + for (uint i=0; iinput(); + if (!pm->oc->op_exist(op)) + continue; + has_op = true; + if (op->type() == OpType::reduce) + has_reduce = true; + if (op->type() == OpType::element) { + has_element = true; + max_elm_dim = std::max(max_elm_dim, op->outputs().front()->shape.size()); + if (max_elm_dim == op->outputs().front()->shape.size()) + max_elm_size = std::max(max_elm_size, std::abs(op->outputs().front()->num)); + } + } + } + for (uint i=0; iinput(); + // input var as loop var + // TODO: consider only broadcast + var = op->inputs().front(); + if (!pm->oc->op_exist(op)) + continue; + if (has_reduce && op->type() != OpType::reduce) + continue; + if (has_element && !has_reduce && op->type() != OpType::element) + continue; + if (op->type() == OpType::element + && (op->outputs().front()->shape.size() != max_elm_dim || + std::abs(op->outputs().front()->num) != max_elm_size)) + continue; + if (op->name_ex() == "array") + // array op should not be loop var + continue; + Var* loop_var; + if (op->type() == OpType::broadcast || op->name_ex() == "index") { + loop_var = op->output(0); + } else { + loop_var = op->inputs().front(); + } + loop_vars.reserve(loop_var->shape.size()); + string vname = pm->oc->get_name_by_op_var(op, loop_var); + ASSERT(vname!="__fill__"); + for (uint j=0; jshape.size(); j++) + loop_vars.emplace_back(vname+"->shape["+S(j)+"]"); + break; + } + } + ASSERT(!has_op || loop_vars.size()) << "Loop var not found." << op->ops; + // if (loop_vars.size()==0) { + // LOGw << "TODO: loop var not found."; + // // return; + // } + vector> loop_var_defines; + vector loop_var_names; + vector unused; + for (uint k=0; kshape[{j}] + loop_var_define << opi << "_index_t range" << k << + " = " << loop_vars[k] << ";"; + loop_var_defines.emplace_back( + std::make_unique(loop_var_define.str())); + loop_var_names.emplace_back(string("range")+S(k)); + unused.emplace_back(loop_var_names.back()); + } + number_of_ranges = loop_var_names.size(); + int member_count=pm->oc->total_member_count(); + + ir->insert(member_count, loop_var_defines); + // replace loop var + vector> replace_vars; + for (uint i=0; iops.size(); i++) { + Op* opi = op->ops[i]; + uint ndim=0; + uint64_t mask=0; + vector vnames; + // loop var may not exist(relayed) + if (!pm->oc->op_exist(opi)) + continue; + if (opi->name()==string("array")) + continue; + if (opi->type() == OpType::reduce) { + ndim = ((ReduceOp*)opi)->inputs().front()->shape.size(); + for (uint i=0; iinputs().size(); i++) + vnames.push_back(pm->oc->get_name_by_op_input(opi, i)); + } else + if (opi->type() == OpType::broadcast) { + ndim = ((BroadcastToOp*)opi)->outputs().front()->shape.size(); + for (uint o=0; ooutputs().size(); o++) + vnames.push_back(pm->oc->get_name_by_op_output(opi, o)); + } else { + ndim = opi->outputs().front()->shape.size(); + for (uint o=0; ooutputs().size(); o++) + vnames.push_back(pm->oc->get_name_by_op_output(opi, o)); + } + for (uint j=0; j>j&1) && j {loop_var_names[j]} + std::stringstream name1; + name1 << vname<<"shape"< same_inputs; + for (auto o : op->ops) { + if (!pm->oc->op_exist(o)) + continue; + int i_id = 0; + for (auto i : o->inputs()) { + i_id ++; + auto fi_id = op->get_node_id(i); + if (op->vars.at(fi_id).type != 0) + continue; + if (same_inputs.count(i)) { + auto j = same_inputs[i]; + auto name1 = pm->oc->get_name_by_op_input(o, i_id-1); + auto name2 = pm->oc->get_name_by_op_var(j, i); + if (name1[0] == '_' || name2[0] == '_') + continue; + // replace name1 -> name2 + replace_vars.emplace_back(name1+'p', name2+'p'); + } else { + auto name2 = pm->oc->get_name_by_op_var(o, i); + if (name2[0] == '_') + continue; + same_inputs[i] = o; + } + } + } + } + + for (auto& t : op->edges) { + uint i,j,k,l; + std::tie(i,j,k,l) = t; + // virtual op holds all inputs + if (i>=op->ops.size()) + continue; + // loop var may not exist(relayed) + auto opa = op->ops.at(i); + auto opb = op->ops.at(k); + if (!pm->oc->op_exist(opa) || !pm->oc->op_exist(opb)) + continue; + // replace op{j}_{kname}* -> op{i}_{oname}* + auto name1 = pm->oc->get_name_by_op_input(opb, l); + auto name2 = pm->oc->get_name_by_op_output(opa, j); + replace_vars.emplace_back(name1, name2); + } + + // dirty fix wrong array fuse + if (max_elm_size>1) + for (int i=0; iop->ops.size(); i++) { + auto op = this->op->ops[i]; + if (op->type() == OpType::element && + op->name() != string("array") && + op->outputs().front()->num == 1) { + replace_vars.emplace_back("op"+S(i)+"_xstride0", "0"); + replace_vars.emplace_back("op"+S(i)+"_ystride0", "0"); + replace_vars.emplace_back("op"+S(i)+"_zstride0", "0"); + } + } + + + LOGvvv << "replace_vars" << replace_vars; + ir->replace(replace_vars); + + for (int i=0; iop->ops.size(); i++) { + auto op = this->op->ops[i]; + if (op->type() == OpType::element && + op->name() == string("array") && + op->outputs().front()->num == 1) { + ir->replace({{"op"+S(i)+"_outputshape0", "1"}}); + } + } + + // fix index op stride not found + replace_vars.clear(); + for (int i=0; iop->ops.size(); i++) { + auto op = this->op->ops[i]; + if (op->type() == OpType::element && + op->name() == string("index")) { + for (int j=1; joutputs().size(); j++) + replace_vars.push_back({"op"+S(i)+"_x"+S(j)+"stride", "op"+S(i)+"_x0stride"}); + } + } + if (replace_vars.size()) + ir->replace(replace_vars); + LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true); + // move define + ir->move_loop_back(); + LOGvvvv << "KernelIR after move_loop_back\n" >> ir->to_string(0, true); + + // check reindex run arguments op + for (Op* op : this->op->ops) { + string op_name = op->name(); + if (op_name == "reindex" || op_name == "reindex_reduce") { + ReindexOp* rop = (ReindexOp*)op; + vector ss = rop->indexes; + for (auto& s : rop->overflow_conditions) ss.push_back(s); + for (auto& s : ss) { + if (s.find("//") != string::npos) { + LOGf << "Arguments of reindex op should not contain '//' operation, please replace 'a//b' to 'int(a/b)', Arguments of reindex op: " << s << ss; + } + } + } + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/loop_var_analyze_pass.h b/python/jittor/src/opt/pass/loop_var_analyze_pass.h new file mode 100644 index 00000000..d418c9d7 --- /dev/null +++ b/python/jittor/src/opt/pass/loop_var_analyze_pass.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct LoopVarAnalyzePass : Pass { + // total number of loop ranges + int number_of_ranges; + + LoopVarAnalyzePass() : Pass("loop_var_analyze"), number_of_ranges(0) {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/mark_raw_pass.cc b/python/jittor/src/opt/pass/mark_raw_pass.cc new file mode 100644 index 00000000..704acbba --- /dev/null +++ b/python/jittor/src/opt/pass/mark_raw_pass.cc @@ -0,0 +1,38 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/mark_raw_pass.h" + +namespace jittor { + +void MarkRawPass::run() { + vector raws = {"relay_groups"}; + for (auto& c : ir->children) { + string* check = nullptr; + bool found = false; + if (c->type == "define") { + check = &c->get_attr("rvalue"); + } else if (c->has_attr("code")) + check = &c->get_attr("code"); + if (check) { + for (auto& s : raws) + if (check->find(s) != string::npos) { + found = true; + break; + } + if (found) { + c->attrs["raw"] = "1"; + if (c->type=="define") + raws.push_back(c->get_attr("lvalue")); + } + } + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/mark_raw_pass.h b/python/jittor/src/opt/pass/mark_raw_pass.h new file mode 100644 index 00000000..4a4c2c91 --- /dev/null +++ b/python/jittor/src/opt/pass/mark_raw_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct MarkRawPass : Pass { + MarkRawPass() : Pass("mark_raw") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/merge_loop_pass.cc b/python/jittor/src/opt/pass/merge_loop_pass.cc new file mode 100644 index 00000000..ba7ed592 --- /dev/null +++ b/python/jittor/src/opt/pass/merge_loop_pass.cc @@ -0,0 +1,62 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/merge_loop_pass.h" + +namespace jittor { + + +void MergeLoopPass::run() { + auto choice = op->get_loop_option("merge", 1); + if (!choice) return; + bool is_cuda = op->flags.get(NodeFlags::_cuda); + if (is_cuda) { + vector loops; + vector loop_keys; + for (auto& c : ir->children) { + if (c->type != "loop") + continue; + if (!c->has_attr("loop_id")) + continue; + if (c->has_attr("raw")) + continue; + auto* cc = c.get(); + string key = cc->get_attr("loop_id"); + while (cc->children.size()==1 && cc->children[0]->has_attr("loop_id")) { + cc = cc->children[0].get(); + key += cc->get_attr("loop_id"); + } + loops.push_back(c.get()); + loop_keys.push_back(key); + } + LOGvvvv << "loop keys" << loop_keys; + for (int i=(int)loops.size()-1; i>=0; i--) { + if (!loops[i]) continue; + for (int j=i-1; j>=0; j--) { + if (!loops[j]) continue; + int cpx=0; // commen prefix + auto& ki = loop_keys[i]; + auto& kj = loop_keys[j]; + while (cpx < ki.size() && cpx=1 || cpx==0) + continue; + loops[i]->insert(0, loops[j]->children); + loops[i]->merge_loop(); + loops[j]->erase(); + loops[j] = nullptr; + } + } + } else { + ir->merge_loop(); + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/merge_loop_pass.h b/python/jittor/src/opt/pass/merge_loop_pass.h new file mode 100644 index 00000000..91130fb3 --- /dev/null +++ b/python/jittor/src/opt/pass/merge_loop_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct MergeLoopPass : Pass { + MergeLoopPass() : Pass("merge_loop") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/merge_loop_var_pass.cc b/python/jittor/src/opt/pass/merge_loop_var_pass.cc new file mode 100644 index 00000000..defd17b6 --- /dev/null +++ b/python/jittor/src/opt/pass/merge_loop_var_pass.cc @@ -0,0 +1,155 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "opt/expr.h" +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/merge_loop_var_pass.h" + +namespace jittor { + +using namespace expr; + +static unique_ptr trace_and_expand(KernelIR* ir, expr::Expr* e) { + auto a = e->clone(); + std::function func = + [&](expr::Expr* c) { + if (!c->is_sym()) return; + if (startswith(c->str, "range") && c->str.size() == 6) + // dont expand range + return; + if (endswith(c->str, "outputd")) + return; + auto def = ir->find_define(c->str); + if (!def) return; + if (def->type!="define") + return; + if (!def->has_attr("rvalue")) return; + auto& rvalue = def->attrs["rvalue"]; + LOGvvvv << *c << "->" << rvalue; + if (def->father && def->flist==&def->father->inner) { + // dont expand loop or func + return; + } + c->swap(expr::make(rvalue).get()); + if (!c->children.size()) func(c); + }; + a->dfs(func); + return a; +} + +void MergeLoopVarPass::run() { + // LOGir << ir->to_string(); + auto choice = op->get_loop_option("merge_loop_var", 1); + if (!choice) return; + for (int ci=0; cichildren.size(); ci++) { + auto& c = ir->children[ci]; + if (c->type != "loop") + continue; + vector to_opt; + c->dfs([&](unique_ptr& i) { + if (i->type == "loop" && i->father && i->father->type == "loop" + && i->father->children.size() == 1 && + i->before.size() == 0 && i->after.size() == 0) { + to_opt.push_back(i.get()); + } + }); + for (int ii=0; iifather; + LOGvvvv << "check opt" << i->attrs["rvalue"] << fa->attrs["rvalue"]; + auto range_b = i->attrs["rvalue"]; + auto id_b = i->attrs["lvalue"]; + auto range_a = fa->attrs["rvalue"]; + auto id_a = fa->attrs["lvalue"]; + if (!(i->type == "loop" && i->father && i->father->type == "loop" + && i->father->children.size() == 1 && i->father->inner.size() == 3 && + i->before.size() == 0 && i->after.size() == 0)) { + continue; + } + if (range_b.size() > 6) { + // range23 -> range2*range3 + string tmp = range_b.substr(0, 6); + for (int i=6; i> results; + vector solve_symbols = {"d", "c"}; + vector exclude_symbols = {id_a, id_b}; + + bool can_opt = true; + i->dfs([&](unique_ptr& c) { + if (!can_opt) return; + if (c->type == "if") { + // don't optimize reindex like op yet + can_opt = false; + return; + } + if (c->type == "define" && c->has_attr("rvalue")) { + auto& s = c->attrs["rvalue"]; + auto& lv = c->attrs["lvalue"]; + if (!(endswith(lv, "id") || endswith(lv, "_i"))) + return; + auto se = expr::make(s); + se = trace_and_expand(c.get(), se.get())->simplify(); + LOGvvvv << "expand" << s << "->" << se; + // LOGir << "expand" << s << "->" << se; + results.clear(); + auto ret = expr::match(se.get(), te.get(), solve_symbols, exclude_symbols, results); + if (ret) { + LOGvvvv << "check rvalue" << se << '\n' << + te << '\n' << + ret << results; + } else { + can_opt = false; + LOGvvvv << "cannot match" << se << '\n' << + te; + } + } + }); + if (!can_opt) + continue; + auto ni = i->clone(); + auto aid = fa->attrs["loop_id"]; + auto bid = i->attrs["loop_id"]; + auto newid = aid+bid; + auto new_range = "range" + newid; + auto x = i->find_define(new_range); + if (!x) { + ir->push_back(i->attrs["dtype"]+" "+new_range+" = "+range_b+" * "+range_a+";"); + } + ni->replace({{"range"+bid, new_range}, {"id"+aid, "0"}}, true, true); + ni->attrs["loop_id"] = newid; + ni->attrs["rvalue"] = new_range; + // simplify 0 * x -> 0 + // ni->dfs([&](unique_ptr& c) { + // if (!can_opt) return; + // if (c->type == "define" && c->has_attr("rvalue")) { + // auto& s = c->attrs["rvalue"]; + // auto se = expr::make(s)->simplify(); + // s = se->to_string(); + // } + // }); + LOGvvvv << "new merged loop" << ni; + ni->swap(*fa, true); + } + } + ir->move_loop_back(); + ir->remove_all_unused(); + // LOGir << ir->to_string(); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/merge_loop_var_pass.h b/python/jittor/src/opt/pass/merge_loop_var_pass.h new file mode 100644 index 00000000..47227355 --- /dev/null +++ b/python/jittor/src/opt/pass/merge_loop_var_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct MergeLoopVarPass : Pass { + MergeLoopVarPass() : Pass("merge_loop_var") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/parallel_pass.h b/python/jittor/src/opt/pass/parallel_pass.h new file mode 100644 index 00000000..e2152f3b --- /dev/null +++ b/python/jittor/src/opt/pass/parallel_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct ParallelPass : Pass { + ParallelPass() : Pass("parallel") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/pass.cc b/python/jittor/src/opt/pass/pass.cc new file mode 100644 index 00000000..6876e318 --- /dev/null +++ b/python/jittor/src/opt/pass/pass.cc @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "opt/pass/pass.h" +#include "opt/pass_manager.h" + +namespace jittor { + +Pass::Pass(const string& name): name(name) {} +Pass::~Pass() {} + +void Pass::init(PassManager* pm) { + this->pm = pm; + op = pm->oc->op; + all = &pm->all; + ir = pm->main_ir; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/pass.h b/python/jittor/src/opt/pass/pass.h new file mode 100644 index 00000000..0fba2f13 --- /dev/null +++ b/python/jittor/src/opt/pass/pass.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "fused_op.h" +#include "opt/kernel_ir.h" + +namespace jittor { + +struct Pass { + FusedOp* op; + KernelIR* all; + KernelIR* ir; + PassManager* pm; + string name; + + Pass(const string& name); + virtual ~Pass(); + + void init(PassManager* pm); + virtual void run() = 0; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/remove_intermediate_pass.cc b/python/jittor/src/opt/pass/remove_intermediate_pass.cc new file mode 100644 index 00000000..c1a84e94 --- /dev/null +++ b/python/jittor/src/opt/pass/remove_intermediate_pass.cc @@ -0,0 +1,49 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/remove_intermediate_pass.h" + +namespace jittor { + +static bool remove_empty_loop(KernelIR* i) { + for (int j=0; jchildren.size(); j++) { + if (remove_empty_loop(i->children[j].get())) + j--; + } + if (i->type == "loop" && i->children.size() == 0) { + i->erase(); + return true; + } + return false; +} + +void RemoveIntermediatePass::run() { + unordered_set names; + for (auto& vi : op->vars) { + // intermediate + if (vi.type != 1) continue; + Op* op = vi.var->input(); + if (!pm->oc->op_exist(op)) continue; + for (uint i=0; ioutputs().size(); i++) + if (op->output(i)==vi.var) + names.insert(pm->oc->get_name_by_op_output(op, i)); + } + LOGvvvv << "Remove intermediate:" << names; + ir->remove_intermediate(names); + ir->remove_all_unused(); + ir->solve_conflict_define(); + + // remove empty loop + remove_empty_loop(ir); + + ir->remove_all_unused(); + +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/remove_intermediate_pass.h b/python/jittor/src/opt/pass/remove_intermediate_pass.h new file mode 100644 index 00000000..8cc19f9f --- /dev/null +++ b/python/jittor/src/opt/pass/remove_intermediate_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct RemoveIntermediatePass : Pass { + RemoveIntermediatePass() : Pass("remove_intermediate") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/remove_loop_pass.cc b/python/jittor/src/opt/pass/remove_loop_pass.cc new file mode 100644 index 00000000..5e8d3430 --- /dev/null +++ b/python/jittor/src/opt/pass/remove_loop_pass.cc @@ -0,0 +1,29 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/remove_loop_pass.h" + +namespace jittor { + +void RemoveLoopPass::run() { + int loop_id=0; + for (size_t i=0; ichildren.size(); i++) { + auto& c = ir->children[i]; + if (c->type == "loop") { + auto choice = op->get_loop_option("remove"+S(loop_id)); + if (choice) { + c->erase(); + i--; + } + loop_id++; + } + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/remove_loop_pass.h b/python/jittor/src/opt/pass/remove_loop_pass.h new file mode 100644 index 00000000..f4577251 --- /dev/null +++ b/python/jittor/src/opt/pass/remove_loop_pass.h @@ -0,0 +1,18 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +// this is a debug pass, remove i-th loop, key: removei +struct RemoveLoopPass : Pass { + RemoveLoopPass() : Pass("remove_loop") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/rename_loop_index_pass.cc b/python/jittor/src/opt/pass/rename_loop_index_pass.cc new file mode 100644 index 00000000..830697d7 --- /dev/null +++ b/python/jittor/src/opt/pass/rename_loop_index_pass.cc @@ -0,0 +1,19 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/rename_loop_index_pass.h" + +namespace jittor { + +void RenameLoopIndexPass::run() { + // TODO: move out rename_loop_index + ir->rename_loop_index(); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/rename_loop_index_pass.h b/python/jittor/src/opt/pass/rename_loop_index_pass.h new file mode 100644 index 00000000..2ba1d41c --- /dev/null +++ b/python/jittor/src/opt/pass/rename_loop_index_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct RenameLoopIndexPass : Pass { + RenameLoopIndexPass() : Pass("rename_loop_index") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/reorder_loop_pass.cc b/python/jittor/src/opt/pass/reorder_loop_pass.cc new file mode 100644 index 00000000..1d740f82 --- /dev/null +++ b/python/jittor/src/opt/pass/reorder_loop_pass.cc @@ -0,0 +1,66 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/reorder_loop_pass.h" +#include "opt/pass/loop_var_analyze_pass.h" +#include "opt/pass/split_loop_pass.h" + +namespace jittor { + +vector ReorderLoopPass::search_parse_loop_order() { + vector order; + auto* sl_pass = pm->get_pass("split_loop"); + ASSERT(sl_pass); + auto number_of_ranges_after_split = sl_pass->number_of_ranges_after_split; + if (!number_of_ranges_after_split) return order; + for (int i=0; iget_loop_option("order"+S(i)); + ASSERT(choice<=i); + order.insert(order.end()-choice, i); + } + ASSERT(order.size() == (uint)number_of_ranges_after_split); + return order; +} + +void ReorderLoopPass::run() { + vector order = search_parse_loop_order(); + vector loops; + for (uint i=0; ichildren.size(); i++) { + KernelIR* loop = ir->children[i].get(); + if (loop->type != "loop") + continue; + loops.clear(); + loops.push_back(loop); + while (1) { + loop = loops.back(); + KernelIR* loop2 = nullptr; + for (auto& c : loop->children) { + if (c->type != "loop") + continue; + ASSERT(loop2 == nullptr); + loop2 = c.get(); + } + if (loop2 == nullptr) break; + ASSERT(loop->children.size()==1); + loops.push_back(loop2); + } + // sort loop with order + int count=0; + for (auto j : order) { + uint k; + for (k=count; kcheck_attr("loop_id", S(j))) + break; + if (kswap(*loops[count++]); + } + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/reorder_loop_pass.h b/python/jittor/src/opt/pass/reorder_loop_pass.h new file mode 100644 index 00000000..d3a81532 --- /dev/null +++ b/python/jittor/src/opt/pass/reorder_loop_pass.h @@ -0,0 +1,18 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct ReorderLoopPass : Pass { + ReorderLoopPass() : Pass("reorder_loop") {}; + void run() override; + vector search_parse_loop_order(); +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/replace_for_num_pass.cc b/python/jittor/src/opt/pass/replace_for_num_pass.cc new file mode 100644 index 00000000..2df14cea --- /dev/null +++ b/python/jittor/src/opt/pass/replace_for_num_pass.cc @@ -0,0 +1,84 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "op_compiler.h" +#include "opt/pass_manager.h" +#include "opt/pass/replace_for_num_pass.h" + +namespace jittor { + +void ReplaceForNumPass::run() { + for (uint fid=0; fidchildren.size(); fid++) { + auto& loop_ir = ir->children[fid]; + if (loop_ir->type != "loop") + continue; + auto& rvalue = loop_ir->get_attr("rvalue"); + auto j=rvalue.find("num"); + if (j == string::npos) continue; + auto& loop_num = rvalue; + auto& loop_index = loop_ir->get_attr("lvalue"); + LOGvvvv << "Find for_num" << loop_num << loop_index; + uint sid=fid-1; + bool found = false; + // find definition of loop range + for (;sid>0; sid--) { + if (ir->children[sid]->type != "define") + continue; + if (!ir->children[sid]->check_attr("lvalue", loop_num)) + continue; + found = true; + break; + } + // T xx_num = xxx->num + // def = xxx + ASSERT(found); + auto& code2 = ir->children[sid]->get_attr("rvalue"); + ASSERT(endswith(code2, "->num")) << ir->children[sid]->attrs; + string def = code2.substr(0, code2.size()-5); + uint op_id, opvar_id; + Op* op; + Var* var; + pm->oc->get_op_var_by_name(def, op_id, opvar_id, op, var); + auto new_code = OpCompiler::precompile( + { + {"DIM", S(var->shape.size())}, + {"op_id", S(op_id)}, + {"def", def}, + {"loop_index", loop_index}, + } , + "@for(di,0,DIM, op@op_id@@_index_t @def@@shape@di = @def->shape[@di];)\n" + "op@op_id@@_index_t @def@@stride@{DIM-1} = 1;\n" + "@for(di,DIM-2,-1,-1, auto @def@@stride@di = @def@@stride@{di+1} * @def@@shape@{di+1};)\n" + "@for(di,0,DIM, for (op@op_id@@_index_t @loop_index@di=0; @loop_index@di<@def@@shape@di; @loop_index@di++))\n" + "{ op@op_id@@_index_t @loop_index = @for(di,0,DIM, + @loop_index@di * @def@@stride@di); }" + ); + KernelIR new_ir(new_code); + ASSERT(new_ir.children.size()>=2 && + new_ir.children.back()->type == "loop" && + new_ir.children.front()->type == "define"); + auto& new_for = new_ir.children.back(); + auto* inner_for = new_for.get(); + for (uint di=0; di+1shape.size(); di++) { + ASSERT(inner_for->children.size()==1); + inner_for = inner_for->children[0].get(); + } + auto& prev_for = ir->children[fid]; + LOGvvvv << "new_ir\n" >> new_ir.to_string(); + LOGvvvv << "prev_for\n" >> prev_for->to_string(); + inner_for->insert( + inner_for->children.size(), + prev_for->children + ); + LOGvvvv << "new_ir\n" >> new_ir.to_string(); + prev_for->erase(); + fid += new_ir.children.size()-1; + ir->insert(sid, new_ir.children); + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/replace_for_num_pass.h b/python/jittor/src/opt/pass/replace_for_num_pass.h new file mode 100644 index 00000000..88fd9d16 --- /dev/null +++ b/python/jittor/src/opt/pass/replace_for_num_pass.h @@ -0,0 +1,26 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +// replace_for_num pass +// T num=opi_x->num; +// for (T i=0; i +// T opi_xshapej = opi_x->shape[j]; ... +// T opi_xstride{DIM-1} = 1; +// T opi_xstride{j} = opi_xstride{j+1} * opi_xshape{j+1} +// for (T i{d}=0; i{d}. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "op_compiler.h" +#include "opt/pass_manager.h" +#include "opt/pass/restride_pass.h" + +namespace jittor { + +// find t{number} in s +int findn(const string& s, const string&t) { + for (uint i=0; i+t.size()<=s.size(); i++) { + bool found = true; + for (uint j=0; jget_loop_option("restride"); + auto pf = op->get_loop_option("restride_profile"); + if (!choice) return; + vector q({ir}); + unordered_map replaces; + unordered_map> rloops; + unordered_map origin_defs; + vector defs; + for (uint i=0; itype == "define") { + vector loops, splits; + KernelIR* fa = ir->father; + // find all loop index affect this define + while (fa && fa->type=="loop" && fa->has_attr("loop_id")) { + auto idname = "id"+fa->attrs["loop_id"]; + if (findn(ir->attrs["rvalue"], idname) != -1) + loops.push_back(fa); + fa = fa->father; + } + if (loops.size()) { + string newid; + // create new id which is continuous + for (uint i=0; iget_attr("lvalue"); + bool found = false; + for (auto& split : splits) { + if (split->get_attr("split_id") != loops[i]->get_attr("loop_id")) + newid += "*"+split->get_attr("rvalue"); + else { + split = loops[i]; + found = true; + } + } + if (!found) splits.push_back(loops[i]); + } + auto& lvalue = ir->get_attr("lvalue"); + auto& rvalue = ir->get_attr("rvalue"); + if (replaces.count(lvalue) && (replaces[lvalue] != newid || origin_defs[lvalue] != rvalue)) { + // conflict stride, pass + replaces[lvalue] = ""; + } else { + replaces[lvalue] = newid; + rloops[lvalue] = loops; + origin_defs[lvalue] = rvalue; + } + defs.push_back(ir); + } + } + for (auto& c : ir->children) + q.push_back(c.get()); + } + string total_size = "0"; + string prev_name; + vector newdefs; + for (auto& kv : replaces) { + if (pf) break; + if (kv.second.size() == 0) continue; + string name = kv.first.substr(0, kv.first.size()-2); + uint op_id, opvar_id; + Op* op; + Var* var; + pm->oc->get_op_var_by_name(name, op_id, opvar_id, op, var); + std::stringstream ss; + ss << var->dtype() << "* __restrict__ " << name << "_new = (" << var->dtype() << "*)"; + if (prev_name.size()) + ss << "(((char*)" << prev_name << "_new)+" + prev_name + "->size);"; + else + ss << "&buffer[0];"; + prev_name = name; + total_size += "+" + name + "->size"; + newdefs.push_back(ss.str()); + + KernelIR* cir = ir; + std::stringstream s2; + auto& loops = rloops[kv.first]; + bool is_output = opvar_id >= op->inputs().size(); + for (int i=(int)loops.size()-1; i>=0; i--) { + if (!is_output && i==(int)loops.size()-1) { + cir->push_front(loops[i]->clone(false)); + cir = cir->children.front().get(); + continue; + } + cir->push_back(loops[i]->clone(false)); + cir = cir->children.back().get(); + } + auto org_id = origin_defs[kv.first]; + if (is_output) { + // this var is output + s2 << name << "p[" << org_id << "] = " << name << "_new[" << kv.second << "];"; + cir->push_back(s2.str()); + } else { + // this var is input + s2 << name << "_new[" << kv.second << "] = "<< name << "p[" << org_id << "];"; + cir->push_front(s2.str()); + } + } + if (total_size != "0") { + ir->push_back("auto total_size = "+total_size+";", nullptr, true); + ir->push_back("char* __restrict__ buffer = (char*)aligned_alloc(alignment, total_size);", nullptr, true); + for (auto& def : newdefs) + ir->push_back(def); + ir->move_loop_back(); + ir->push_back("::free(buffer);"); + } + // replace prev id with new id + for (auto ir : defs) { + auto& lvalue = ir->attrs["lvalue"]; + auto& rvalue = ir->attrs["rvalue"]; + if (replaces.count(lvalue) && replaces[lvalue] != "") { + string name = lvalue.substr(0, lvalue.size()-2); + rvalue = replaces[lvalue]; + ASSERT(ir->father); + if (!pf) + ir->father->replace({{name+"p", name+"_new"}}); + } + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/restride_pass.h b/python/jittor/src/opt/pass/restride_pass.h new file mode 100644 index 00000000..63058294 --- /dev/null +++ b/python/jittor/src/opt/pass/restride_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct RestridePass : Pass { + RestridePass() : Pass("restride") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/shared_reduce_pass.h b/python/jittor/src/opt/pass/shared_reduce_pass.h new file mode 100644 index 00000000..8a393d9b --- /dev/null +++ b/python/jittor/src/opt/pass/shared_reduce_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Zheng-Ning Liu . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct SharedReducePass : Pass { + SharedReducePass() : Pass("shared_reduce") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/solve_conflict_define_pass.cc b/python/jittor/src/opt/pass/solve_conflict_define_pass.cc new file mode 100644 index 00000000..bd068ce6 --- /dev/null +++ b/python/jittor/src/opt/pass/solve_conflict_define_pass.cc @@ -0,0 +1,18 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/solve_conflict_define_pass.h" + +namespace jittor { + +void SolveConflictDefinePass::run() { + ir->solve_conflict_define(); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/solve_conflict_define_pass.h b/python/jittor/src/opt/pass/solve_conflict_define_pass.h new file mode 100644 index 00000000..6b1c2877 --- /dev/null +++ b/python/jittor/src/opt/pass/solve_conflict_define_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct SolveConflictDefinePass : Pass { + SolveConflictDefinePass() : Pass("solve_conflict_define") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/split_loop_pass.cc b/python/jittor/src/opt/pass/split_loop_pass.cc new file mode 100644 index 00000000..ea7d6f96 --- /dev/null +++ b/python/jittor/src/opt/pass/split_loop_pass.cc @@ -0,0 +1,35 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/split_loop_pass.h" +#include "opt/pass/loop_var_analyze_pass.h" + +namespace jittor { + +void SplitLoopPass::run() { + auto* lva_pass = pm->get_pass("loop_var_analyze"); + ASSERT(lva_pass); + if (op->flags.get(NodeFlags::_cpu)) + ir->push_back("using namespace std;", &ir->before); + number_of_ranges_after_split = lva_pass->number_of_ranges; + for (int i=0; iget_loop_option("split"+S(i)); + if (choice > 1) { + int j = number_of_ranges_after_split++; + int split_size = std::max(1, choice); + auto loops = ir->find_loops(S(i)); + ASSERT(loops.size()); + ir->push_back(loops[0]->attrs["dtype"]+" stride"+S(i)+" = "+S(split_size)+";"); + ir->split_loop(i, j); + } + } + ir->move_loop_back(); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/split_loop_pass.h b/python/jittor/src/opt/pass/split_loop_pass.h new file mode 100644 index 00000000..be6f82f7 --- /dev/null +++ b/python/jittor/src/opt/pass/split_loop_pass.h @@ -0,0 +1,19 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct SplitLoopPass : Pass { + int number_of_ranges_after_split; + + SplitLoopPass() : Pass("split_loop"), number_of_ranges_after_split(0) {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/unroll_pass.cc b/python/jittor/src/opt/pass/unroll_pass.cc new file mode 100644 index 00000000..554089f2 --- /dev/null +++ b/python/jittor/src/opt/pass/unroll_pass.cc @@ -0,0 +1,90 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/unroll_pass.h" + +namespace jittor { + +void UnrollPass::run() { + auto choice = op->get_loop_option("unroll"); + if (!choice) return; + vector q({ir}); + vector loops; + for (uint i=0; ihas_attr("rvalue2")) + dont_unroll = true; + for (auto& c : ir->children) { + // non vectorized loop + if (c->type == "if") + dont_unroll = true; + if (c->type == "loop" && !c->has_attr("vectorized") && !c->has_attr("unrolled")) + dont_unroll = true; + q.push_back(c.get()); + } + ASSERT(!(ir->type=="loop" && !dont_unroll && !ir->has_attr("loop_id"))); + if (!dont_unroll && ir->has_attr("loop_id")) { + loops.push_back(ir); + } + } + for (auto loop : loops) { + if (loop->has_attr("vectorized") || loop->has_attr("unrolled")) + continue; + loop->attrs["unrolled"] = "1"; + if (choice==1) + loop->push_back("#pragma unroll", &loop->before); + else { + int num=0; + auto& split_id = loop->get_attr("split_id"); + auto& loop_id = loop->get_attr("loop_id"); + auto& rvalue = loop->get_attr("rvalue"); + if (!loop->get_number(rvalue, num)) { + if (split_id.size()) { + string& si = split_id; + ASSERT(loop->get_number("stride"+si, num)); + if (num>128) { + loop->push_back("#pragma unroll", &loop->before); + continue; + } + auto floop = loop->father; + while (floop && !floop->check_attr("loop_id", split_id)) + floop = floop->father; + ASSERT(floop) << loop->to_string(); + floop->resplit(); + // fully unrolled loops + auto loops2 = floop->find_loops(loop_id); + ASSERT(loops2.size()); + for (auto loop2 : loops2) { + loop2->before.clear(); + loop2->push_back("#pragma unroll("+S(num)+")", &loop2->before); + loop2->attrs["unrolled"] = "1"; + } + // partial unrolled loops in if + ASSERT(floop->after.size() && floop->after[0]->type == "if"); + auto loops = floop->after[0]->find_loops(loop_id); + ASSERT(loops.size()); + for (auto loop2 : loops) { + loop2->before.clear(); + loop2->push_back("#pragma unroll", &loop2->before); + loop2->attrs["unrolled"] = "1"; + } + continue; + } else { + loop->push_back("#pragma unroll", &loop->before); + continue; + } + } + loop->push_back("#pragma unroll("+S(num)+")", &loop->before); + } + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/unroll_pass.h b/python/jittor/src/opt/pass/unroll_pass.h new file mode 100644 index 00000000..ca496cb9 --- /dev/null +++ b/python/jittor/src/opt/pass/unroll_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct UnrollPass : Pass { + UnrollPass() : Pass("expand_empty_block") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/use_movnt_pass.cc b/python/jittor/src/opt/pass/use_movnt_pass.cc new file mode 100644 index 00000000..28f094cd --- /dev/null +++ b/python/jittor/src/opt/pass/use_movnt_pass.cc @@ -0,0 +1,29 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/use_movnt_pass.h" + +namespace jittor { + +void UseMovntPass::run() { + // TODO: need to test this pass + if (!op->get_loop_option("use_movnt")) + return; + + for (auto& c : ir->children) { + if (c->type != "loop") continue; + c->push_front("//@begin replace \"vmova(.*,.*\\(.*\\))\" \"vmovnt\\g<1>\"", &c->children, true); + c->push_back("//@end", &c->children, true); + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/use_movnt_pass.h b/python/jittor/src/opt/pass/use_movnt_pass.h new file mode 100644 index 00000000..3e15e737 --- /dev/null +++ b/python/jittor/src/opt/pass/use_movnt_pass.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct UseMovntPass : Pass { + UseMovntPass() : Pass("use_movnt") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass/vectorize_pass.cc b/python/jittor/src/opt/pass/vectorize_pass.cc new file mode 100644 index 00000000..75453301 --- /dev/null +++ b/python/jittor/src/opt/pass/vectorize_pass.cc @@ -0,0 +1,81 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/vectorize_pass.h" + +namespace jittor { + +void VectorizePass::run() { + auto choice = op->get_loop_option("vectorize"); + if (!choice) return; + vector q({ir}); + vector inner_loops; + for (uint i=0; ichildren) { + if (c->type == "loop") + has_loop = true; + q.push_back(c.get()); + } + if (!has_loop && ir->has_attr("loop_id")) + inner_loops.push_back(ir); + } + LOGvvvv << "Find" << inner_loops.size() << "inner loops"; + for (auto loop : inner_loops) { + if (choice == 1) { + loop->push_back("#pragma vector", &loop->before); + } else if (choice > 1) { + int num=0; + if (!loop->get_number(loop->get_attr("rvalue"), num)) { + if (loop->has_attr("split_id")) { + string si = loop->attrs["split_id"]; + string loop_id = loop->attrs["loop_id"]; + ASSERT(loop->get_number("stride"+si, num)); + int vectorlength = 64; + while (vectorlength && vectorlength/2 >= num) + vectorlength /= 2; + auto floop = loop->father; + while (floop && !floop->check_attr("loop_id", si)) + floop = floop->father; + ASSERT(floop); + auto loops = floop->find_loops(loop_id); + ASSERT(loops.size()); + for (auto loop2 : loops) { + loop2->before.clear(); + loop2->push_back("#pragma vector", &loop2->before); + loop2->attrs["vectorized"] = "1"; + } + floop->resplit(); + auto loops2 = floop->find_loops(loop_id); + ASSERT(loops2.size()); + for (auto loop2 : loops2) { + loop2->before.clear(); + loop2->push_back("#pragma vector vectorlength("+S(vectorlength)+")", &loop2->before); + loop2->attrs["vectorized"] = "1"; + } + continue; + } + loop->push_back("#pragma vector", &loop->before); + } else { + int vectorlength = 64; + while (vectorlength && vectorlength/2 >= num) + vectorlength /= 2; + if (vectorlength > 1) + loop->push_back("#pragma vector vectorlength("+S(vectorlength)+")", &loop->before); + else + loop->push_back("#pragma vector", &loop->before); + } + } + loop->push_back("#pragma ivdep", &loop->before); + loop->attrs["vectorized"] = "1"; + } +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/pass/vectorize_pass.h b/python/jittor/src/opt/pass/vectorize_pass.h new file mode 100644 index 00000000..87925e20 --- /dev/null +++ b/python/jittor/src/opt/pass/vectorize_pass.h @@ -0,0 +1,17 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "opt/pass/pass.h" + +namespace jittor { + +struct VectorizePass : Pass { + VectorizePass() : Pass("vectorize") {}; + void run() override; +}; + +} // jittor diff --git a/python/jittor/src/opt/pass_manager.cc b/python/jittor/src/opt/pass_manager.cc new file mode 100644 index 00000000..95d2df4b --- /dev/null +++ b/python/jittor/src/opt/pass_manager.cc @@ -0,0 +1,126 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var.h" +#include "opt/pass_manager.h" +#include "opt/pass/replace_for_num_pass.h" +#include "opt/pass/loop_var_analyze_pass.h" +#include "opt/pass/remove_loop_pass.h" +#include "opt/pass/rename_loop_index_pass.h" +#include "opt/pass/compile_shapes_pass.h" +#include "opt/pass/split_loop_pass.h" +#include "opt/pass/reorder_loop_pass.h" +#include "opt/pass/merge_loop_pass.h" +#include "opt/pass/merge_loop_var_pass.h" +#include "opt/pass/const_var_pass.h" +#include "opt/pass/expand_empty_block_pass.h" +#include "opt/pass/solve_conflict_define_pass.h" +#include "opt/pass/remove_intermediate_pass.h" +#include "opt/pass/restride_pass.h" +#include "opt/pass/vectorize_pass.h" +#include "opt/pass/unroll_pass.h" +#include "opt/pass/use_movnt_pass.h" +#include "opt/pass/loop_to_func_pass.h" +#include "opt/pass/assume_aligned_pass.h" +#include "opt/pass/parallel_pass.h" +#include "opt/pass/atomic_tuner_pass.h" +#include "opt/pass/shared_reduce_pass.h" +#include "opt/pass/float_atomic_fix_pass.h" +#include "opt/pass/insert_profile_loop_pass.h" +#include "opt/pass/fake_main_pass.h" +#include "opt/pass/check_cache_pass.h" +#include "opt/pass/mark_raw_pass.h" +#include "utils/str_utils.h" + +namespace jittor { + +DECLARE_FLAG(string, cc_type); +DEFINE_FLAG(string, exclude_pass, "", "Don't run certain pass."); +DEFINE_FLAG(string, log_op_hash, "", "Output compiler pass result of certain hash of op."); + + +PassManager::PassManager(OpCompiler* oc) : oc(oc), all(oc->get_src()) { + main_ir = nullptr; + for (auto& c : all.children) + if (c->type=="func" && c->attrs["lvalue"]=="jittor::FusedOp::jit_run") { + main_ir = c.get(); + break; + } + ASSERT(main_ir); +} + +bool PassManager::check(Pass* pass) { + if (exclude_pass=="*") return false; + if (exclude_pass==pass->name) return false; + if (startswith(exclude_pass, "after:")) { + auto n = (uint)stoi(exclude_pass.substr(6)); + if (finished_passes.size()>=n) + return false; + } + return true; +} + +void PassManager::run_passes() { + auto& ir = *main_ir; + + LOGvvvv << "KernelIR:\n" << ir.to_string(); + if (oc->op->ops.size() == 1 && oc->op->ops[0]->name() == string("array")) { + ir.remove_all_unused(); + if (oc->op->flags.get(NodeFlags::_cuda)) { + ir.children.back()->erase(); + string type = oc->op->ops[0]->outputs().front()->dtype().to_cstring(); + ir.push_back("kernel<<<1,1>>>(op0_outputp, op0_outputv);"); + auto jt_type = type == "bool" ? type : "jittor::" + type; + ir.push_back("__global__ static void kernel("+jt_type+"* xp, "+jt_type+" x) { xp[0] = x; } ", &ir.before, true); + } + return; + } + run_pass(); + run_pass(); + run_pass(); + run_pass(); + run_pass(); + run_pass(); + + run_pass(); + run_pass(); + run_pass(); + run_pass(); + run_pass(); + + run_pass(); + + run_pass(); + run_pass(); + // tmp disable ConstVarPass + // run_pass(); + + run_pass(); + + if (cc_type == "icc") { + // only icc supports pragma + run_pass(); + run_pass(); + run_pass(); + } + run_pass(); + run_pass(); + run_pass(); + run_pass(); + run_pass(); + run_pass(); + run_pass(); + run_pass(); + + run_pass(); + + run_pass(); + + run_pass(); +} + +} // jittor + diff --git a/python/jittor/src/opt/pass_manager.h b/python/jittor/src/opt/pass_manager.h new file mode 100644 index 00000000..a268a473 --- /dev/null +++ b/python/jittor/src/opt/pass_manager.h @@ -0,0 +1,65 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "fused_op.h" +#include "op_compiler.h" +#include "opt/kernel_ir.h" +#include "opt/pass/pass.h" + +namespace jittor { + +DECLARE_FLAG(string, exclude_pass); +DECLARE_FLAG(string, log_op_hash); + +struct PassManager { + OpCompiler* oc; + KernelIR all; + KernelIR* main_ir; + unordered_map pass_map; + vector> finished_passes; + + PassManager(OpCompiler* oc); + // run and store a pass + template void run_pass(); + // get a pass by pass name, return nullptr if not found + template T* get_pass(const string& name); + + bool check(Pass* pass); + + void run_passes(); + +}; + +template +void PassManager::run_pass() { + auto pass = std::make_unique(); + if (!check(pass.get())) { + LOGvvv << "exclude pass" << pass->name; + return; + } + LOGvvv << "run pass" << pass->name; + pass->init(this); + pass->run(); + LOGvvvv << "Kernel IR after pass" << pass->name << ":\n" + << main_ir->to_string(0, true); + + if (log_op_hash.size() && log_op_hash == oc->op->get_hash_name()) + LOGi << "hash mach:" << log_op_hash << "pass:" << pass->name + << main_ir->to_string(0, true); + pass_map.emplace(pass->name, pass.get()); + finished_passes.push_back(move(pass)); +} + +template +T* PassManager::get_pass(const string& name) { + auto iter = pass_map.find(name); + if (iter == pass_map.end()) return nullptr; + return (T*)iter->second; +} + +} // jittor diff --git a/python/jittor/src/opt/tuner/broadcast_tuner.cc b/python/jittor/src/opt/tuner/broadcast_tuner.cc new file mode 100644 index 00000000..6bac6275 --- /dev/null +++ b/python/jittor/src/opt/tuner/broadcast_tuner.cc @@ -0,0 +1,68 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "var.h" +#include "opt/tuner/broadcast_tuner.h" +#include "opt/pass_manager.h" +#include "opt/pass/loop_var_analyze_pass.h" +#include "opt/pass/split_loop_pass.h" + +namespace jittor { + +DEFINE_FLAG(int, l1_cache_size, 32768, "size of level 1 cache (byte)"); + +void BroadcastTuner::run(PassManager* pm, TunerManager* tm) { + confidence = 0; + FusedOp* fo=tm->oc->op; + if (!fo) return; + if (fo->flags.get(NodeFlags::_cuda)) return; + + int bc=0, rd=0; + for (uint i=0; iops.size(); i++) { + Op* op = fo->ops[i]; + if (op->name_ex() == "reindex") return; + if (op->name_ex() == "index") return; + if (op->type() == OpType::reduce) rd = 1; + if (op->type() == OpType::broadcast) bc = 1; + } + if (!bc || rd) return; + + auto* lva_pass = pm->get_pass("loop_var_analyze"); + auto* sl_pass = pm->get_pass("split_loop"); + if (!sl_pass || !lva_pass) return; + auto number_of_ranges = lva_pass->number_of_ranges; + if (number_of_ranges<2) return; + + confidence = 20; + if (number_of_ranges>2) confidence=9; + + int var_size = 0; + map var_map_input; + for (uint i=0; ivars.size(); i++) + if (fo->vars[i].type == 0){ + Var* var = fo->vars[i].var; + if (var_map_input.count((size_t)var)) continue; + var_map_input[(size_t)var] = 1; + var_size += var->dsize(); + } + + int st = -1; + if (var_size==0) var_size=1; + for (int i = l1_cache_size/var_size; i; st++, i>>=1); + + add_candidate("split1", 1< +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "var.h" +#include "opt/tuner_manager.h" + +namespace jittor { + +struct BroadcastTuner : Tuner { + BroadcastTuner() : Tuner("broadcast") {} + void run(PassManager* pm, TunerManager* tm); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner/conv_tuner.cc b/python/jittor/src/opt/tuner/conv_tuner.cc new file mode 100644 index 00000000..41c32646 --- /dev/null +++ b/python/jittor/src/opt/tuner/conv_tuner.cc @@ -0,0 +1,419 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "var.h" +#include "ops/reindex_op.h" +#include "ops/reindex_reduce_op.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "opt/tuner/conv_tuner.h" +#include "opt/pass_manager.h" +#include "opt/expr.h" +#include "ops/op_register.h" + +#include +#include + +namespace jittor { + +using namespace expr; +extern int use_cuda; + +struct OpInspector { + // binary mask for + // m1: exact dimension map + // m2: no relation + // m3: other + uint64 m1=0, m2=0, m3=0; + // which dimension map + vector mm; + Op* op; + bool failed=0; + + void init(ReindexOp* op) { + unordered_map p; + mm.resize(op->y->shape.size(), -1); + for (uint i=0; iy->shape.size(); i++) + p["i"+S(i)] = i; + for (uint i=0; ix->shape.size(); i++) { + if (p.count(op->indexes[i])) { + int j = p.at(op->indexes[i]); + if (mm[j]!=-1) failed=1; + mm[j] = i; + m1 |= 1ll<indexes[i]); + expr::dfs(e.get(), [&](expr::Expr* e) { + if (e->is_sym() && p.count(e->str)) { + int j = p.at(e->str); + if (mm[j]!=-1) failed=1; + m3 |= 1ll << j; + mm[j] = i; + } + }); + } + } + m2 = ((1ll< p; + mm.resize(op->y->shape.size(), -1); + for (uint i=0; iy->shape.size(); i++) + p["i"+S(i)] = i; + for (uint i=0; ix->shape.size(); i++) { + if (p.count(op->indexes[i])) { + int j = p.at(op->indexes[i]); + if (mm[j]!=-1) failed=1; + mm[j] = i; + m1 |= 1ll<indexes[i]); + expr::dfs(e.get(), [&](expr::Expr* e) { + if (e->is_sym() && p.count(e->str)) { + int j = p.at(e->str); + if (mm[j]!=-1) failed=1; + m3 |= 1ll << j; + mm[j] = i; + } + }); + } + } + m2 = ((1ll<z->shape.size(), 0); + m2 = op->bcast_mask; + m1 = ((1ll<z->shape.size(); i++) + if ((m1>>i)&1) mm[i] = j++; + } + + OpInspector(BroadcastToOp* op) : op(op) { init(op); } + + void init(ReduceOp* op) { + mm.resize(op->x->shape.size(), 0); + m2 = op->reduce_mask; + m1 = ((1ll<x->shape.size(); i++) + if ((m1>>i)&1) mm[i] = j++; + } + + OpInspector(ReduceOp* op) : op(op) { init(op); } + + OpInspector(Op* op) : op(op) { + if (strcmp(op->name(), "reduce") == 0) + init((ReduceOp*)op); + else if (strcmp(op->name(), "broadcast_to") == 0) + init((BroadcastToOp*)op); + else if (strcmp(op->name(), "reindex") == 0) + init((ReindexOp*)op); + else if (strcmp(op->name(), "reindex_reduce") == 0) + init((ReindexReduceOp*)op); + else + failed = 1; + } + + // get last one index of binary mask + void get_id(uint64 m, int& i) { + if (m==0) failed=1; + else { + i=0; + while (!(m&1)) i++,m>>=1; + if (m!=1) failed=1; + } + } + // get last two index of binary mask + void get_id(uint64 m, int& i, int& j) { + if (m==0) failed=1; + else { + i=j=0; + while (!(m&1)) i++,m>>=1; + if (m<=1) { + failed=1; + return; + } + j=i+1,m>>=1; + while (!(m&1)) j++,m>>=1; + if (m!=1) failed=1; + } + } + + // get last three index of binary mask + void get_id(uint64 m, int& i, int& j, int& k) { + if (m==0) failed=1; + else { + i=j=0; + while (!(m&1)) i++,m>>=1; + if (m<=1) { + failed=1; + return; + } + j=i+1,m>>=1; + while (!(m&1)) j++,m>>=1; + if (m<=1) { + failed=1; + return; + } + k=j+1, m>>=1; + while (!(m&1)) k++,m>>=1; + if (m!=1) { + failed=1; + return; + } + } + } + + bool check_overlap(const vector& v) { + uint64 sum=0; + for (auto a : v) { + if (sum & (1ll<& order) { + string new_fmt = fmt; + if (order.size() != fmt.size()) { + failed = 1; + return ""; + } + if (check_overlap(order)) + return ""; + for (uint i=0; i=(int)new_fmt.size()) { + failed = 1; + return ""; + } + new_fmt[order[i]] = fmt[i]; + } + return new_fmt; + } +}; + +std::ostream& operator<<(std::ostream& os, const OpInspector& oi) { + if (oi.failed) return os << "inspect failed"; + for (uint i=0; i>i)&1); + os << ','; + for (uint i=0; i>i)&1); + os << ','; + for (uint i=0; i>i)&1); + return os << ',' << oi.mm; +} + +void ConvTuner::forwardTune(FusedOp* fop) { + for (Op* op : fop->ops) + if (op->name_ex()=="reduce.add" || op->name_ex()=="reindex_reduce.add") { + // reduce op and reindex reduce op have the same memory layout + // it is ok to force cast. + auto op_iop = op->input(0)->input(); + if (!(op_iop + && op_iop->name_ex()=="binary.multiply" + && fop->has(op_iop))) + continue; + auto bop = (BinaryOp*)op_iop; + + if (!(bop->y->input() && bop->x->input() && fop->has(bop->x->input()) && fop->has(bop->y->input()))) continue; + if (!(bop->x->input()->type()==OpType::broadcast && bop->y->input()->type()==OpType::broadcast)) return; + + // only support float32,float16 currently + if (use_cuda) { + if (!bop->z->dtype().is_float()) + continue; + } else { + if (bop->z->dtype() != ns_float32) + continue; + } + Op* ops[3] = {op, bop->x->input(), bop->y->input()}; + int ok = 0; + LOGvvvv << "conv like op" << fop << fop->get_jit_key(get_jk()); + for (int y_id=0; y_id<3; y_id++) + for (int x_id=0; x_id<3; x_id++) + for (int w_id=0; w_id<3; w_id++) { + if (ok) break; + if (x_id == y_id || x_id == w_id || y_id == w_id) continue; + LOGvvvv << "try" << x_id << y_id << w_id; + OpInspector xoi(ops[x_id]); + OpInspector yoi(ops[y_id]); + OpInspector woi(ops[w_id]); + vector* xop_indexes; + if (strcmp(xoi.op->name(), "reindex") == 0) { + xop_indexes = &((ReindexOp*)xoi.op)->indexes; + } else + if (strcmp(xoi.op->name(), "reindex_reduce") == 0) { + xop_indexes = &((ReindexReduceOp*)xoi.op)->indexes; + } else + continue; + if (xoi.failed || yoi.failed || woi.failed) continue; + int xn, xc, xh, xw, wh, ww, wci, wco, yn, yc, yh, yw; + int zn, zg, zci, zco, zh, zw, zwh, zww; + zn = zci = zco = zh = zw = zwh = zww = 0; + if (bop->x->shape.size() == 7) { + xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn); + xoi.get_id(xoi.m1 & woi.m1 & yoi.m2, zci); + xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw); + xoi.get_id(xoi.m2 & woi.m1 & yoi.m1, zco); + xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zwh, zww); + LOGvvvv << "zn,zci,zco,zh,zw,zwh,zww =" << vector{zn,zci,zco,zh,zw,zwh,zww}; + xoi.check_overlap({zn,zci,zco,zh,zw,zwh,zww}); + zg = -1; + } else { + if (bop->x->shape.size() != 8) + continue; + // group conv + xoi.get_id(xoi.m1 & woi.m2 & yoi.m1, zn); + xoi.get_id(xoi.m3 & woi.m3 & yoi.m3, zg); + xoi.get_id(xoi.m3 & woi.m2 & yoi.m1, zh, zw); + xoi.get_id(xoi.m2 & woi.m3 & yoi.m3, zco); + xoi.get_id(xoi.m3 & woi.m1 & yoi.m2, zci, zwh, zww); + LOGvvvv << "zn,zg,zci,zco,zh,zw,zwh,zww =" << vector{zn,zg,zci,zco,zh,zw,zwh,zww}; + xoi.check_overlap({zn,zg,zci,zco,zh,zw,zwh,zww}); + } + if (xoi.failed) continue; + xn = xoi.mm[zn]; + xc = xoi.mm[zci]; + xh = xoi.mm[zh]; + xw = xoi.mm[zw]; + LOGvvvv << "xnchw =" << vector{xn,xc,xh,xw}; + auto xformat = xoi.format("abcd", {xn, xc, xh, xw}); + LOGvvvv << "xformat =" << xformat; + wci = woi.mm[zci]; + wco = woi.mm[zco]; + wh = woi.mm[zwh]; + ww = woi.mm[zww]; + auto wformat = xoi.format("iohw", {wci, wco, wh, ww}); + LOGvvvv << "wformat =" << wformat; + yn = yoi.mm[zn]; + yc = yoi.mm[zco]; + yh = yoi.mm[zh]; + yw = yoi.mm[zw]; + auto yformat = xoi.format("abcd", {yn, yc, yh, yw}); + LOGvvvv << "yformat =" << yformat; + + // mkl doesn't support "cdab" format + if (yformat == "cdab") continue; + // cuda doesn't support "iohw" format + if (fop->flags.get(NodeFlags::_cuda) && wformat == "iohw") continue; + if (xoi.failed) continue; + std::stringstream ss; + // i@zh*stride+i@zwh+padding + ss << "i" << zh << "*stride+i" << zwh << "*dilation+padding"; + auto expr_h = expr::make(ss.str()); + ss.str(""); + ss << "i" << zw << "*stride+i" << zww << "*dilation+padding"; + auto expr_w = expr::make(ss.str()); + + vector> rh, rw; + auto src_h = expr::make(xop_indexes->at(xh)); + if (!expr::match(src_h.get(), expr_h.get(), {"stride", "padding", "dilation"}, {"i"+S(zh), "i"+S(zwh)}, rh)) { + LOGvvvv << "Expr not match" << src_h << expr_h; + continue; + } + LOGvvvv << "H Expr matched" << src_h << expr_h; + if (!rh[0]->is(expr::_number) || !rh[1]->is(expr::_number) || !rh[2]->is(expr::_number)) return; + auto src_w = expr::make(xop_indexes->at(xw)); + if (!expr::match(src_w.get(), expr_w.get(), {"stride", "padding", "dilation"}, {"i"+S(zw), "i"+S(zww)}, rw)) + continue; + LOGvvvv << "W Expr matched" << src_w << expr_w; + if (!rw[0]->is(expr::_number) || !rw[1]->is(expr::_number) || !rw[2]->is(expr::_number)) return; + int stride_h = rh[0]->as_int(); + int padding_h = -rh[1]->as_int(); + int dilation_h = rh[2]->as_int(); + int stride_w = rw[0]->as_int(); + int padding_w = -rw[1]->as_int(); + int dilation_w = rw[2]->as_int(); + if (dilation_h < 1 || dilation_w < 1) continue; + LOGvvvv << "get stride padding and dilation" << stride_h << padding_h << dilation_h; + if (xformat == "bacd") { + LOGvvvv << "mkl not support bacd, continue"; + continue; + } + Var* x = x_id == 0 ? xoi.op->output(0) : xoi.op->input(0); + Var* w = w_id == 0 ? woi.op->output(0) : woi.op->input(0); + Var* y = y_id == 0 ? yoi.op->output(0) : yoi.op->input(0); + + int oh = (x->shape[xh]-w->shape[wh]*dilation_h+dilation_h-1+padding_h*2)/stride_h+1; + int ow = (x->shape[xw]-w->shape[ww]*dilation_w+dilation_w-1+padding_w*2)/stride_w+1; + if (oh != y->shape[yh] || ow != y->shape[yw]) { + LOGvvvv << "shape not match" << "(" >> oh >> "," >> ow >> ") !=" + << "(" >> y->shape[yh] >> "," >> y->shape[yw] >> ")"; + continue; + } + int groups = zg==-1 ? 1 : x->shape[xc] / w->shape[wci]; + LOGvvvv << "groups: " << groups; + if (groups>1 && wformat != "oihw") + continue; + + VarPtr rvar; + int rid; + string relay_conv_name; + + if (y_id == 0) { + relay_conv_name = fop->flags.get(NodeFlags::_cpu) ? + "mkl_conv" : "cudnn_conv"; + if (!has_op(relay_conv_name)) + continue; + auto make_conv = get_op_info(relay_conv_name) + .get_constructor(); + LOGvvvv << x << w << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat; + rvar = make_conv(x, w, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat); + } else + if (x_id == 0) { + relay_conv_name = fop->flags.get(NodeFlags::_cpu) ? + "mkl_conv_backward_x" : "cudnn_conv_backward_x"; + if (!has_op(relay_conv_name)) + continue; + auto height = x->shape[xformat.find("c")]; + auto width = x->shape[xformat.find("d")]; + auto make_conv_x = get_op_info(relay_conv_name) + .get_constructor(); + LOGvvvv << w << y << height << width << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat; + rvar = make_conv_x(w, y, height, width, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat); + } else { + relay_conv_name = fop->flags.get(NodeFlags::_cpu) ? + "mkl_conv_backward_w" : "cudnn_conv_backward_w"; + if (!has_op(relay_conv_name)) + continue; + auto kh = w->shape[wformat.find("h")]; + auto kw = w->shape[wformat.find("w")]; + LOGvvvv << x << y << kh << stride_h << stride_w << padding_h << padding_w << dilation_h << dilation_w << groups << xformat << wformat << yformat; + auto make_conv_w = get_op_info(relay_conv_name) + .get_constructor(); + rvar = make_conv_w(x, y, kh, kw, stride_h, stride_w, padding_h, padding_w, dilation_h, dilation_w, groups, xformat, wformat, yformat); + } + + LOGvvvv << relay_conv_name << "output:" << rvar; + rid = fop->context->vrm.add_relay_group({{rvar, op->output(0)}}); + if (rid>=0) { + auto srid = "relay"+S(rid); + add_candidate(srid, 1); + add_candidate(srid, 0); + confidence = 20; + ok = 1; + LOGvvvv << "ok" << x_id << y_id << w_id; + } + } + } +} + +void ConvTuner::run(PassManager* pm, TunerManager* tm) { + FusedOp* fop=tm->oc->op; + + forwardTune(fop); +} + +} diff --git a/python/jittor/src/opt/tuner/conv_tuner.h b/python/jittor/src/opt/tuner/conv_tuner.h new file mode 100644 index 00000000..bd028871 --- /dev/null +++ b/python/jittor/src/opt/tuner/conv_tuner.h @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "var.h" +#include "opt/tuner_manager.h" + +namespace jittor { + +struct ConvTuner : Tuner { + ConvTuner() : Tuner("conv") {} + void forwardTune(FusedOp* fop); + void run(PassManager* pm, TunerManager* tm); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner/matmul_tuner.cc b/python/jittor/src/opt/tuner/matmul_tuner.cc new file mode 100644 index 00000000..a54270ca --- /dev/null +++ b/python/jittor/src/opt/tuner/matmul_tuner.cc @@ -0,0 +1,104 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "var.h" +#include "ops/reduce_op.h" +#include "ops/binary_op.h" +#include "ops/broadcast_to_op.h" +#include "opt/tuner/matmul_tuner.h" +#include "opt/pass_manager.h" +#include "ops/op_register.h" + +namespace jittor { + +void MatmulTuner::run(PassManager* pm, TunerManager* tm) { + FusedOp* fop=tm->oc->op; + for (Op* op : fop->ops) { + if (op->name_ex()!="reduce.add") continue; + auto rop = (ReduceOp*)op; + if (!(rop->x->input() && rop->x->input()->name_ex()=="binary.multiply" && fop->has(rop->x->input()))) + continue; + auto bop = (BinaryOp*)(rop->x->input()); + if (!(bop->x->input() && bop->x->input()->name_ex()=="broadcast_to" && fop->has(bop->x->input()))) + continue; + if (!(bop->y->input() && bop->y->input()->name_ex()=="broadcast_to" && fop->has(bop->y->input()))) + continue; + auto bcop1 = (BroadcastToOp*)(bop->x->input()); + auto bcop2 = (BroadcastToOp*)(bop->y->input()); + if (bcop1->shape.size() != 3) continue; + if (bcop1->x->shape.size() != 2) continue; + if (bcop2->x->shape.size() != 2) continue; + Var* xx = bcop1->x; + Var* yy = bcop2->x; + bool is_matmul = false, t1 = false, t2 = false; + // xx : n m + // yy : m k + // out: (n,k) + if ((rop->reduce_mask == (1u<<1)) && (bcop1->bcast_mask == (1u<<2)) && (bcop2->bcast_mask == (1u<<0))) { + is_matmul = true; + t1 = false; + t2 = false; + } + if ((rop->reduce_mask == (1u<<1)) && (bcop1->bcast_mask == (1u<<0)) && (bcop2->bcast_mask == (1u<<2))) { + is_matmul = true; + t1 = false; + t2 = false; + std::swap(xx, yy); + } + // xx : m n + // yy : m k + // out: (n,k) + if ((rop->reduce_mask == (1u<<0)) && (bcop1->bcast_mask == (1u<<2)) && (bcop2->bcast_mask == (1u<<1))) { + is_matmul = true; + t1 = true; + t2 = false; + } + if ((rop->reduce_mask == (1u<<0)) && (bcop1->bcast_mask == (1u<<1)) && (bcop2->bcast_mask == (1u<<2))) { + is_matmul = true; + t1 = true; + t2 = false; + std::swap(xx, yy); + } + // xx : n m + // yy : k m + // out: (n,k) + if ((rop->reduce_mask == (1u<<2)) && (bcop1->bcast_mask == (1u<<1)) && (bcop2->bcast_mask == (1u<<0))) { + is_matmul = true; + t1 = false; + t2 = true; + } + if ((rop->reduce_mask == (1u<<2)) && (bcop1->bcast_mask == (1u<<0)) && (bcop2->bcast_mask == (1u<<1))) { + is_matmul = true; + t1 = false; + t2 = true; + std::swap(xx, yy); + } + if (!is_matmul) continue; + // TODO: support int8 * int8 + if (!(xx->dtype().is_float() && yy->dtype().is_float())) continue; + if (fop->flags.get(NodeFlags::_cpu)) + if (xx->dtype().dsize() != 4) continue; + + string relay_matmul_name = fop->flags.get(NodeFlags::_cpu) ? + "mkl_matmul" : "cublas_matmul"; + if (!has_op(relay_matmul_name)) + return; + auto make_matmul = get_op_info(relay_matmul_name) + .get_constructor(); + auto rvar = make_matmul(xx, yy, t1, t2); + auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}}); + auto srid = "relay"+S(rid); + add_candidate(srid, 1); + add_candidate(srid, 0); + confidence = 20; + } +} + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner/matmul_tuner.h b/python/jittor/src/opt/tuner/matmul_tuner.h new file mode 100644 index 00000000..93170b1b --- /dev/null +++ b/python/jittor/src/opt/tuner/matmul_tuner.h @@ -0,0 +1,21 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "var.h" +#include "opt/tuner_manager.h" + +namespace jittor { + +struct MatmulTuner : Tuner { + MatmulTuner() : Tuner("matmul") {} + void run(PassManager* pm, TunerManager* tm); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner/reduce_tuner.cc b/python/jittor/src/opt/tuner/reduce_tuner.cc new file mode 100644 index 00000000..dbb3e0fa --- /dev/null +++ b/python/jittor/src/opt/tuner/reduce_tuner.cc @@ -0,0 +1,76 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "var.h" +#include "ops/reduce_op.h" +#include "opt/tuner/reduce_tuner.h" +#include "opt/pass_manager.h" +#include "opt/pass/loop_var_analyze_pass.h" +#include "opt/pass/split_loop_pass.h" + +namespace jittor { + +DECLARE_FLAG(int, l1_cache_size); + +void ReduceTuner::run(PassManager* pm, TunerManager* tm) { + confidence = 0; + FusedOp* fo=tm->oc->op; + if (!fo) return; + if (fo->flags.get(NodeFlags::_cuda)) return; + int rd=0; + map dim_map; + for (uint i=0; iops.size(); i++) { + Op* op = fo->ops[i]; + if (op->name() == string("reindex_reduce")) return; + if (op->type() == OpType::reduce) { + rd = 1; + auto mask = ((ReduceOp*)op)->reduce_mask; + for (uint j=0; (1<>j&1) dim_map[j] = 1; + } + } + if (!rd) return; + + auto* lva_pass = pm->get_pass("loop_var_analyze"); + auto* sl_pass = pm->get_pass("split_loop"); + if (!sl_pass || !lva_pass) return; + auto number_of_ranges = lva_pass->number_of_ranges; + if (number_of_ranges<2) return; + + confidence = 20; + if (number_of_ranges>2) confidence = 9; + for (auto iter = dim_map.begin(); iter != dim_map.end(); iter++) + if (iter->first != 0) confidence = 9; + + int var_size = 0; + map var_map_input, var_map_output; + for (uint i=0; ivars.size(); i++) + if (fo->vars[i].type == 0){ + Var* var = fo->vars[i].var; + if (var_map_input.count((size_t)var)) continue; + var_map_input[(size_t)var] = 1; + var_size += var->dsize(); + } else if (fo->vars[i].type == 2){ + Var* var = fo->vars[i].var; + if (var_map_output.count((size_t)var)) continue; + var_map_output[(size_t)var] = 1; + var_size += var->dsize(); + } + + int st = -1; + for (int i = l1_cache_size/var_size; i; st++, i>>=1); + add_candidate("split1", 1< +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "var.h" +#include "ops/reduce_op.h" +#include "opt/tuner_manager.h" + +namespace jittor { + +struct ReduceTuner : Tuner { + ReduceTuner() : Tuner("reduce") {} + void run(PassManager* pm, TunerManager* tm); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner/reorder_tuner.cc b/python/jittor/src/opt/tuner/reorder_tuner.cc new file mode 100644 index 00000000..20f86493 --- /dev/null +++ b/python/jittor/src/opt/tuner/reorder_tuner.cc @@ -0,0 +1,27 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "opt/tuner/reorder_tuner.h" +#include "opt/pass_manager.h" +#include "opt/pass/loop_var_analyze_pass.h" +#include "opt/pass/split_loop_pass.h" + +namespace jittor { + +void ReorderTuner::run(PassManager* pm, TunerManager* tm) { + auto* lva_pass = pm->get_pass("loop_var_analyze"); + auto* sl_pass = pm->get_pass("split_loop"); + if (!sl_pass || !lva_pass) return; + auto number_of_ranges = lva_pass->number_of_ranges; + auto number_of_ranges_after_split = sl_pass->number_of_ranges_after_split; + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "opt/tuner_manager.h" + +namespace jittor { + +struct ReorderTuner : Tuner { + ReorderTuner() : Tuner("reorder") {} + void run(PassManager* pm, TunerManager* tm); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner/tuner.cc b/python/jittor/src/opt/tuner/tuner.cc new file mode 100644 index 00000000..0ab8ed8c --- /dev/null +++ b/python/jittor/src/opt/tuner/tuner.cc @@ -0,0 +1,19 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "opt/tuner/tuner.h" + +namespace jittor { + +Tuner::Tuner(const string& name) : name(name), confidence(0), candidates({}) {}; +Tuner::~Tuner() {} + +void Tuner::add_candidate(const string& key, int value) { + candidates[key].push_back(value); +} + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner/tuner.h b/python/jittor/src/opt/tuner/tuner.h new file mode 100644 index 00000000..84294843 --- /dev/null +++ b/python/jittor/src/opt/tuner/tuner.h @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +struct Tuner { + string name; + int confidence; + loop_option_candidates_t candidates; + + Tuner(const string& name); + void add_candidate(const string& key, int value); + virtual ~Tuner(); + virtual void run(PassManager* pm, TunerManager* tm) = 0; +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/tuner_manager.cc b/python/jittor/src/opt/tuner_manager.cc new file mode 100644 index 00000000..a2b3fad4 --- /dev/null +++ b/python/jittor/src/opt/tuner_manager.cc @@ -0,0 +1,66 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "opt/pass_manager.h" +#include "opt/tuner_manager.h" +#include "opt/tuner/reorder_tuner.h" +#include "opt/tuner/broadcast_tuner.h" +#include "opt/tuner/reduce_tuner.h" +#include "opt/tuner/matmul_tuner.h" +#include "opt/tuner/conv_tuner.h" + +namespace jittor { + +DEFINE_FLAG(int, enable_tuner, 1, "Enable tuner."); + +TunerManager::TunerManager(OpCompiler* oc) +: oc(oc), searcher(oc), best_tuner(nullptr) { +} + +template void TunerManager::run_tuner(PassManager* pm) { + auto tuner = std::make_unique(); + tuner->run(pm, this); + LOGvvv << "Run tuner" << tuner->name >> + ": confidence(" >> tuner->confidence >> + ") candidates(" >> tuner->candidates >> ")"; + if (best_tuner==nullptr || best_tuner->confidence < tuner->confidence) + best_tuner = tuner.get(); + tuners.push_back(move(tuner)); +} + +string TunerManager::tune() { + PassManager pm(oc); + string src_after_passes; + pm.run_passes(); + src_after_passes = pm.all.to_string(); + if (!enable_tuner) return src_after_passes; + + run_tuner(&pm); + run_tuner(&pm); + run_tuner(&pm); + run_tuner(&pm); + run_tuner(&pm); + + // use the best tuner if it is confidence enough + if (best_tuner && best_tuner->confidence) { + if (jit_search_kernel) + searcher.search(best_tuner->candidates); + else { + if (best_tuner->confidence >= 10) { + auto& loop_options = oc->op->get_loop_options_tuned(); + for (auto& kv : best_tuner->candidates) + loop_options[kv.first] = kv.second.front(); + oc->op->update_jit_key(); + PassManager pm(oc); + pm.run_passes(); + src_after_passes = pm.all.to_string(); + } + } + } + return src_after_passes; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/opt/tuner_manager.h b/python/jittor/src/opt/tuner_manager.h new file mode 100644 index 00000000..244ef6b9 --- /dev/null +++ b/python/jittor/src/opt/tuner_manager.h @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "opt/tuner/tuner.h" +#include "opt/jit_searcher.h" + +namespace jittor { + +struct TunerManager { + OpCompiler* oc; + Searcher searcher; + Tuner* best_tuner; + + vector> tuners; + + TunerManager(OpCompiler* oc); + string tune(); + + // run and store a tuner, return confidence + template void run_tuner(PassManager* pm); +}; + +} \ No newline at end of file diff --git a/python/jittor/src/opt/var_relay.cc b/python/jittor/src/opt/var_relay.cc new file mode 100644 index 00000000..3e67a86e --- /dev/null +++ b/python/jittor/src/opt/var_relay.cc @@ -0,0 +1,200 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "ops/op_register.h" +#include "opt/var_relay.h" +#include "fused_op.h" +#include "graph.h" + +namespace jittor { + + +VarRelayGroup::VarRelayGroup() { +} + +VarRelayGroup::VarRelayGroup(VarRelayGroup&& other) { + relayed_pairs = move(other.relayed_pairs); + removed_input_vars = move(other.removed_input_vars); + nodes = move(other.nodes); + oprcs = move(other.oprcs); +} + +VarRelayGroup::~VarRelayGroup() { + for (Node* node : nodes) + if (node->is_var()) + Var::number_of_lived_vars++; + else + Op::number_of_lived_ops++; +} + +int VarRelayManager::add_relay_group(const vector>& group) { + for (auto& g : relay_groups) + for (auto& p : g.relayed_pairs) + for (auto& p2 : group) + if (p.second == (fop->get_node_id(p2.second))) { + LOGvvvv << "Var allready relayed" << p2.second; + return -1; + } + relay_groups.emplace_back(); + auto& relay_group = relay_groups.back(); + relay_group.relayed_pairs.reserve(group.size()); + for (const auto& p : group) { + relay_group.relayed_pairs.push_back({p.first, fop->get_node_id(p.second)}); + ASSERTop(p.first->size,==,p.second->size); + } + + // break the input link between relay and target + std::unordered_set fnodes; + fnodes.reserve(fop->ops.size()+fop->vars.size()); + for (auto& op : fop->ops) fnodes.insert(op); + for (auto& v : fop->vars) fnodes.insert(v.var); + vector& q = relay_group.nodes; + vector s; + for (const auto& p : group) { + s.push_back(p.first->node()); + ASSERT(!fnodes.count(p.first)) << "Relayed source should not in fused_op"; + } + bfs_backward(s, q, [&](Node *node) -> bool { + return !fnodes.count(node); + }); + // currently, we only support single op relay + ASSERT(q.size()==2*group.size()); + for (Node* node : q) { + node->__release(); + if (node->is_var()) + continue; + Op* op = node->op(); + op->do_jit_prepare(get_jk()); + list new_inputs; + int removed = 0; + for (Var* v : op->inputs()) + if (!fnodes.count(v)) + new_inputs.push_back(v->node()); + else { + removed++; + relay_group.removed_input_vars.push_back(v); + } + if (removed) { + op->set_inputs(move(new_inputs)); + LOGvvv << "Remove" << removed << "inputs from" << op; + } + } + // generate OpRelayContext + relay_group.oprcs.resize(relay_group.relayed_pairs.size()); + for (uint i=0; iinput(); + auto op_info = get_op_info(oprc.op->name()); + oprc.relayed_members.resize(op_info.var_members.size()); + for (uint i=0; iget_node_id(v); + } + LOGvvvv << "Relay op" << oprc.op->name() >>".">> + op_info.var_members[i].first << "-->" << + oprc.relayed_members[i]; + } + } + return (int)relay_groups.size()-1; +} + +vector> VarRelayManager::get_op_relay_info(const vector& relay_switches) { + ASSERT(relay_switches.size()==relay_groups.size()); + auto num = fop->ops.size()+fop->vars.size(); + auto node_id = [&](Node* node) -> int { + if (node->is_var()) return fop->get_node_id(node); + return fop->get_node_id(node) + fop->vars.size(); + }; + vector deps(num); + // pair: first: group_id, second: relayed_pair id + vector> relay_source(num, {-1,-1}); + vector is_relayed(num); + for (uint i=0; i nodes(num); + for (auto v : fop->vars) { + auto vid = node_id(v.var); + nodes[vid] = v.var; + // if is input, continue + if (v.type==0) continue; + // add input op dependency + deps[node_id(v.var->input())]++; + // if var is relayed + if (is_relayed[vid]) continue; + // if is output, add dependency + if (v.type==2) + deps[vid]++; + } + for (auto op : fop->ops) { + nodes[node_id(op)] = op; + for (auto var : op->inputs()) { + deps[node_id(var)]++; + } + } + vector q; + q.reserve(num); + for (uint i=0; iis_var() && fop->vars[nid].type==0) + continue; + for (auto i : node->_inputs) { + auto nnid = node_id(i.node); + deps[nnid]--; + if (!deps[nnid]) { + q.push_back(nnid); + relay_source[nnid] = relay_source[nid]; + } + } + } + relay_source.erase(relay_source.begin(), relay_source.begin()+fop->vars.size()); + return relay_source; +} + +string VarRelayManager::get_relay_src(int group_id, int op_id) { + auto& oprc = relay_groups[group_id].oprcs[op_id]; + Op* op = oprc.op; + string name = op->name(); + auto op_info = get_op_info(name); + string name2 = Op::op_name_to_file_name(name); + string name3 = Op::file_name_to_class_name(name2); + std::stringstream ss; + string relay_op_name = "rop_"+S(group_id)+"_"+S(op_id); + ss << "\n // @relay_op\n"; + ss << " Op* "<vrm.relay_groups["<do_run();\n"; + LOGvvv << "get_relay_src\n" << ss.str(); + return ss.str(); +} + +} // jittor diff --git a/python/jittor/src/opt/var_relay.h b/python/jittor/src/opt/var_relay.h new file mode 100644 index 00000000..3d99f400 --- /dev/null +++ b/python/jittor/src/opt/var_relay.h @@ -0,0 +1,51 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "var.h" + +namespace jittor { + +struct OpRelayContext { + Op* op; + // j=relayed_members[i] represents: op's i-th member is relay to fused_op's j-th vars + vector relayed_members; +}; + +struct VarRelayGroup { + // pair: VarPtr: relay source, uint64: relay target var id in fused_op + vector> relayed_pairs; + vector removed_input_vars; + // nodes of relay source + vector nodes; + vector oprcs; + VarRelayGroup(); + VarRelayGroup(const VarRelayGroup&) = delete; + VarRelayGroup(VarRelayGroup&&); + ~VarRelayGroup(); +}; + +struct VarRelayManager { + FusedOp* fop = nullptr; + vector relay_groups; + + void set_fused_op(FusedOp* fop) {this->fop=fop;} + /* add_relay_group: add relay group into current fused_op + group: list of pair of source and target vars + return: relay group id + */ + int add_relay_group(const vector>& group); + /* get_op_relay_info + relay_switches: switches control the on or off of each relay + return: relay group id and op id + */ + vector> get_op_relay_info(const vector& relay_switches); + + string get_relay_src(int group_id, int op_id); +}; + +} // jittor diff --git a/python/jittor/src/parallel_compiler.cc b/python/jittor/src/parallel_compiler.cc new file mode 100644 index 00000000..e99a9b9c --- /dev/null +++ b/python/jittor/src/parallel_compiler.cc @@ -0,0 +1,351 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include +#include +#include +#include + +#include "parallel_compiler.h" +#include "op_compiler.h" +#include "executor.h" +#include "lock.h" +#include "opt/jit_searcher.h" +#include "fused_op.h" +#include "mem/mem_info.h" + + +namespace jittor { + +DEFINE_FLAG(int, use_parallel_op_compiler, 16, "Number of threads that parallel op comiler used, default 16, set this value to 0 will disable parallel op compiler."); + +// from log.cc +EXTERN_LIB int segfault_happen; + +// simple thread used for parallel compilation +struct SimpleThread { + int id; + typedef std::function Func; + Func func; + std::mutex mtx; + std::condition_variable cv; + std::thread thread; + void run() { + get_thread_name() = "C"+S(id); + try { + std::unique_lock lck(mtx); + if (func) + func(id); + while (true) { + cv.wait(lck); + if (func) { + func(id); + } else + return; + } + } catch (const std::exception& e) { + LOGe << e.what(); + } + } + void launch_one(Func func) { + std::unique_lock lck(mtx); + this->func = func; + cv.notify_all(); + } + SimpleThread(int id) : id(id), func(nullptr), thread(&SimpleThread::run, this) {} + ~SimpleThread() { + join(); + } + void join() { + if (thread.joinable()) { + launch_one(nullptr); + thread.join(); + } + } +}; + +struct SimpleThreads; +EXTERN_LIB SimpleThreads threads; +EXTERN_LIB vector cleanup_callback; + +struct SimpleThreads { + list threads; + static void stop() { + jittor::threads.threads.clear(); + } + void create_threads(int n) { + if (threads.size()) return; + for (int i=0; i& queue, vector& range, FusedOp& fused_op, vector& fuse_ops, vector& ops, int64 tt, int force_compile) { + // jit_search_kernel require compile at runtime + if (!force_compile) + if (jit_search_kernel || !use_parallel_op_compiler || not_compile_window > 100000) + return; + + // try not use parallel compile if no op needs compile + if (last_compiled_op_num != jit_key_mapper.size()) { + not_compile_window = 0; + last_compiled_op_num = jit_key_mapper.size(); + } else { + not_compile_window += queue.size(); + } + + + vector op_needs_compile; + string_view_map map; + vector> fop_needs_compile; + auto& jkl = get_jk(); + + for (uint rid=0; ridtype() != OpType::other) { + op = &fused_op; + is_fused_op = true; + int ll = (riddo_prepare(jkl); + if (jkl.empty()) continue; + + const char* jit_key = jkl.to_cstring(); + auto iter = jit_key_mapper.find(jit_key); + if (iter != jit_key_mapper.end()) continue; + + auto iter2 = map.find(jit_key); + if (iter2 != map.end()) continue; + + map[jit_key] = 1; + if (is_fused_op) { + op_needs_compile.push_back(-1-(int)fop_needs_compile.size()); + fop_needs_compile.emplace_back(std::make_unique(fused_op)); + } else { + op_needs_compile.push_back(rid); + } + + + LOGvv << "Op needs compile:" << op; + } catch (const std::exception& e) { + // log jit_key and file location + op->do_prepare(jkl); + string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc"); + LOGe << "[Error] source file location:" << jit_src_path; + if (is_fused_op) { + LOGf << "Compile fused operator(" >> rid >> '/' >> queue.size() >> ")" + << "failed:" << fused_op.ops << "\n\nReason: " >> e.what(); + } else + LOGf << "Compile operator(" >> rid >> '/' >> queue.size() >> ")" + << "failed:" << op << "\n\nReason: " >> e.what(); + } + } + // if too less op needs compile, don't use parallel compiler + // if (op_needs_compile.size() < 3) return; + if (op_needs_compile.size() == 0) return; + + static int thread_num = std::max(1, std::min(use_parallel_op_compiler, + int(mem_info.total_cpu_ram/(1024ll*1024*1024*3)))); + #ifdef NODE_MEMCHECK + // only use one thread in debug mode + // because global id map has no lock + thread_num = 1; + #endif + static std::atomic ai; + static volatile int has_error; + static string error_msg; + static vector>> op_entrys(thread_num); + // represents: task id, is_fused_op, entry or context, new_jit_key + threads.create_threads(thread_num); + static std::mutex entry_lock; + ai = 0; + has_error = 0; + error_msg = ""; + int n = op_needs_compile.size(); + LOGvv << "Total number of op needs compile" << op_needs_compile.size() + << "thread_num:" << thread_num; + + // backup number + auto bk_var = Var::number_of_lived_vars, bk_op = Op::number_of_lived_ops; + jittor::lock_guard lg; + auto func = [&](int tid) { + auto& entrys = op_entrys.at(tid); + entrys.clear(); + auto& jkl = get_jk(); + while (!has_error && !segfault_happen) { + int i = ai++; + if (i >= n) break; + int rid = op_needs_compile[i]; + Op* op; + bool is_fused_op = rid<0; + try { + if (!is_fused_op) { + int root = queue[rid]; + op = ops[root]; + LOGvv << "Compile Op:" << op; + op->do_prepare(jkl); + auto op_entry = OpCompiler::do_compile(op); + entrys.emplace_back(std::make_tuple(i, 0, (void*)op_entry, op->get_jit_key(jkl))); + } else { + FusedOp& fused_op = *fop_needs_compile[-rid-1]; + op = &fused_op; + LOGvv << "Compile FusedOp:" << op; + LOGV(11) << "FusedOps:" << fused_op.ops; + fused_op.context = new FusedOpContext(); + fused_op.context->setup(&fused_op); + fused_op.do_prepare(jkl); + auto op_entry = OpCompiler::do_compile(op); + fused_op.context->entry = op_entry; + entrys.emplace_back(std::make_tuple(i, 1, (void*)fused_op.context, op->get_jit_key(jkl))); + + // compile relay operators + for (auto& vrg : fused_op.context->vrm.relay_groups) { + for (auto& orc : vrg.oprcs) { + orc.op->do_prepare(jkl); + bool needs_compile; + { + std::lock_guard lock(entry_lock); + auto iter = jit_ops.find(jkl.to_cstring()); + needs_compile = (iter == jit_ops.end()); + if (needs_compile) { + jit_ops[jkl.to_cstring()] = nullptr; + } + } + if (!needs_compile) continue; + string s = jkl.to_string(); + auto op_entry = OpCompiler::do_compile(orc.op); + { + std::lock_guard lock(entry_lock); + jit_ops[s] = op_entry; + } + } + } + } + } catch (const std::exception& e) { + // log jit_key and file location + op->do_prepare(jkl); + string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc"); + std::stringstream ss; + ss << "[Error] source file location:" << jit_src_path << '\n'; + + if (is_fused_op) { + ss << "Compile fused operator(" << i << '/' << n << ")" + << "failed:" << ((FusedOp*)op)->ops << "\n\nReason: " << e.what() << '\n'; + } else + ss << "Compile operator(" << i << '/' << n << ")" + << "failed:" << op << "\n\nReason: " << e.what() << '\n'; + error_msg = ss.str(); + has_error = 1; + break; + } + } + }; // end of threads.launch_all + + typedef std::chrono::high_resolution_clock Time; + auto start = Time::now(); + int active_threads = std::min(thread_num, (int)op_needs_compile.size()); + threads.launch_all(active_threads, func); + int prev_i = 0; + bool change_line = false; + int sleep_us = 10; + while (prev_i < n && !has_error && !segfault_happen) { + int i = std::max(std::min(ai-active_threads, n), 0); + if (i == prev_i) { + // std::this_thread::sleep_for(std::chrono::milliseconds(100)); + std::this_thread::sleep_for(std::chrono::microseconds(sleep_us)); + sleep_us = std::min(sleep_us*2, 1000000); // max 0.1s + continue; + } + prev_i = i; + auto diff = (Time::now() - start).count(); + if (diff > 2e9) { + if (!change_line) { + std::cerr << "\n"; + change_line = true; + } + // delay output progress in 2s + float eta = diff / 1e9 / i * (n-i); + std::cerr << "Compiling Operators(" << i << '/' << n << ")" + << " used: " << std::setprecision(3) << std::setw(4) << diff/1e9 << "s eta: " + << std::setprecision(3) << std::setw(4) << eta << "s \r"; + } + } + if (change_line) + std::cerr << std::endl; + Var::number_of_lived_vars = bk_var; Op::number_of_lived_ops = bk_op; + + if (segfault_happen) { + LOGe << "Segfault happen, main thread exit"; + threads.wait_all(); + exit(1); + } + + if (has_error) { + threads.wait_all(); + LOGf << "Error happend during compilation:\n" << error_msg; + } + + // fill all op entry + for (int i=0; i(t)); + int is_fused_op = std::get<1>(t); + auto& new_jit_key = std::get<3>(t); + if (is_fused_op) + jit_fused_ops[new_jit_key] = jit_fused_ops[prev_jit_key] = (FusedOpContext*)std::get<2>(t); + else + jit_ops[new_jit_key] = jit_ops[prev_jit_key] = (jit_op_entry_t)std::get<2>(t); + jit_key_mapper[prev_jit_key] = new_jit_key; + } + } +} + + +} // jittor diff --git a/python/jittor/src/parallel_compiler.h b/python/jittor/src/parallel_compiler.h new file mode 100644 index 00000000..0cf3485d --- /dev/null +++ b/python/jittor/src/parallel_compiler.h @@ -0,0 +1,14 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +void parallel_compile_all_ops(vector& queue, vector& range, FusedOp& fused_op, vector& fuse_ops, vector& ops, int64 tt, int force_compile=0); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/cache_info.cc b/python/jittor/src/profiler/cache_info.cc new file mode 100644 index 00000000..c7ca63aa --- /dev/null +++ b/python/jittor/src/profiler/cache_info.cc @@ -0,0 +1,26 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "profiler/cache_info.h" + +namespace jittor { +CacheInfo::CacheInfo(unique_ptr* mm) { + check_times = mm->get()->check_times; + tlb_miss_times = mm->get()->tlb->miss_time; + cache_miss_times.clear(); + for (int i = 0; i < (int)mm->get()->caches.size(); ++i) + cache_miss_times.push_back(mm->get()->caches[i]->miss_time); +} + +CacheInfo::CacheInfo() { + check_times = tlb_miss_times = 0; + cache_miss_times.clear(); +} + +} //jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/cache_info.h b/python/jittor/src/profiler/cache_info.h new file mode 100644 index 00000000..95efd525 --- /dev/null +++ b/python/jittor/src/profiler/cache_info.h @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include "profiler/memory_checker.h" + +namespace jittor { +struct CacheInfo { + int64_t check_times, tlb_miss_times; + vector cache_miss_times; + CacheInfo(unique_ptr* mm); + CacheInfo(); +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/memory_checker.cc b/python/jittor/src/profiler/memory_checker.cc new file mode 100644 index 00000000..f5ca1a4b --- /dev/null +++ b/python/jittor/src/profiler/memory_checker.cc @@ -0,0 +1,62 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "profiler/memory_checker.h" + +int virt_to_phys_user(uintptr_t* paddr, uintptr_t vaddr); + +namespace jittor { +MemoryChecker::MemoryChecker(Cache* tlb, vector caches, size_t page_size, size_t vtop) +: tlb(tlb), caches(caches), page_size(page_size), vtop(vtop) { + clear(); +} + +MemoryChecker::~MemoryChecker() { + delete tlb; + for (int i = 0; i < (int)caches.size(); ++i) + delete caches[i]; +} + +string MemoryChecker::get_replace_strategy(int id) { + if (id == 0) + return "DefaultReplacementCache"; + if (id == 1) + return "LRUCache"; + return "DefaultReplacementCache"; +} + +void MemoryChecker::clear() { + check_times = 0; + tlb->clear(); + for (int i = 0; i < (int)caches.size(); ++i) + caches[i]->clear(); +} + +void MemoryChecker::print_miss() { + LOGi << "Total:" << check_times; + LOGi << "TLB Misses:" << tlb->miss_time; + for (int i = 0; i < (int)caches.size(); ++i) + LOGi << "L" >> (i+1) << "Cache Misses:" << caches[i]->miss_time; +} + +void MemoryChecker::check_hit(size_t vaddr) { + size_t paddr; + if (vtop) + CHECK(virt_to_phys_user(&paddr, vaddr)==0) + << "FAILED to translate vaddr to paddr"; + else + paddr = vaddr; + ++check_times; + for (int i = 0; i < (int)caches.size(); ++i) + if (caches[i]->check_hit(paddr)) + break; + tlb->check_hit(size_t(vaddr)/page_size); +} + +} //jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/memory_checker.h b/python/jittor/src/profiler/memory_checker.h new file mode 100644 index 00000000..67b8e742 --- /dev/null +++ b/python/jittor/src/profiler/memory_checker.h @@ -0,0 +1,36 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include +#include +#include "common.h" +#include "profiler/replacement.h" + +namespace jittor { +struct MemoryChecker { + Cache* tlb; + vector caches; + size_t page_size; + int64_t check_times; + // translate virtual address to physical address or not + size_t vtop; + + //TODO auto build MemoryChecker + MemoryChecker(Cache* tlb, vector caches, size_t page_size, size_t vtop); + ~MemoryChecker(); + static string get_replace_strategy(int id); + void clear(); + void print_miss(); + void check_hit(size_t vaddr); +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/profiler.cc b/python/jittor/src/profiler/profiler.cc new file mode 100644 index 00000000..b9fb5363 --- /dev/null +++ b/python/jittor/src/profiler/profiler.cc @@ -0,0 +1,613 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include +#ifdef _WIN32 +#include +#else +#include +#endif +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#endif +#include "misc/cuda_flags.h" +#include "profiler/profiler.h" +#include "op.h" +#include "fused_op.h" +#include "profiler/memory_checker.h" +#include "misc/deleter.h" +#include "executor.h" +#include "utils/str_utils.h" + +namespace jittor { + +Profiler profiler; + +DEFINE_FLAG(int, profiler_warmup, 0, "Profiler warmup."); +DEFINE_FLAG(int, profiler_rerun, 0, "Profiler rerun."); +DEFINE_FLAG(int, profiler_record_peek, 0, "Profiler record peek mem bandwidth."); +DEFINE_FLAG(int, profiler_record_shape, 0, "Profiler record shape for op."); +DEFINE_FLAG(int, profiler_hide_relay, 0, "Profiler hide relayed op."); +DEFINE_FLAG_WITH_SETTER(int, profiler_enable, 0, "Enable profiler."); + +void setter_profiler_enable(int value) { + if (value) + Profiler::start(); + else + Profiler::stop(); +} + +Profiler::~Profiler() { + if (profiler_enable) { + Profiler::stop(); + Profiler::report(); + } +} + +void Profiler::start(int64 warmup, int64 rerun) { + if (warmup==0) warmup = profiler_warmup; + if (rerun==0) rerun = profiler_rerun; + profiler_enable = 1; + profiler.records.clear(); + profiler.marks.clear(); + profiler.warmup = warmup; + profiler.rerun = rerun; + profiler.relay_extra_cost = 0; + profiler.relay_fop = 0; +} + +void Profiler::stop() { + profiler_enable = 0; +} + +unique_ptr* load_memory_checker(string name) { + const char* msg = ""; + LOGvv << "Opening jit lib:" << name; + #ifdef _WIN32 + void* handle = (void*)LoadLibrary(name.c_str()); + #elif defined(__linux__) + void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_DEEPBIND | RTLD_LOCAL); + msg = dlerror(); + #else + void* handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); + msg = dlerror(); + #endif + + CHECK(handle) << "Cannot open library" << name << ":" << msg; + + #ifdef _WIN32 + auto mm = (unique_ptr*)GetProcAddress((HINSTANCE)handle, "memory_checker"); + #else + //dlerror(); + auto mm = (unique_ptr*)dlsym(handle, "memory_checker"); + msg = dlerror(); + #endif + CHECK(!msg) << "Loading symbol memory_checker from" << name << "failed:" << msg; + + return mm; +} + +EXTERN_LIB string _get_stack_info(Node* node, const char* change_line=""); + +static string get_stack_info(Op* op) { + string stack_info = "stack info:\n"; + if (string("fused") == op->name()) { + auto fop = (FusedOp*)op; + map stacks; + for (Op* op : fop->ops) { + stacks[_get_stack_info(op)] = 1; + } + for (auto& kv : stacks) { + stack_info += kv.first; + stack_info += '\n'; + } + if (trace_py_var >= 3) { + std::stringstream ss; + ss << "input from:\n"; + for (auto& vi : fop->vars) { + if (vi.type == 0) { + auto v = vi.var; + ss << v->shape << ',' << v->dtype() << ',' << v->name << ','; + if (v->input()) + ss << v->input()->name_ex() << ',' << _get_stack_info(v->input()); + else + ss << _get_stack_info(v); + ss << '\n'; + } + } + stack_info += ss.str(); + } + return stack_info; + } else { + stack_info += _get_stack_info(op); + stack_info += '\n'; + if (trace_py_var >= 3) { + std::stringstream ss; + ss << "input from:\n"; + for (auto v : op->inputs()) { + ss << v->shape << ',' << v->dtype() << ',' << v->name << ','; + if (v->input()) + ss << v->input()->name_ex() << ',' << _get_stack_info(v->input()); + else + ss << _get_stack_info(v); + ss << '\n'; + } + stack_info += ss.str(); + } + return stack_info; + } +} + +static void stat_peek_bandwidth(uint64 in, uint64 out, uint64 loop, uint64& peek_time_total) { + auto size = (in+out) / 2; + // memcpy in some not aligned case will drop performance + size &= ~((1 << 12)-1); + // size = 7680000*4; + auto temp1 = exe.alloc_temp(size); + auto temp2 = exe.alloc_temp(size); + loop = 1 << loop; + int warmup = std::max(loop/8, (uint64)1); + for (int i=0; i(finish-start).count(); + peek_time_total += total_ns; +} + +struct RecordExtraCost { + int ck; + std::chrono::high_resolution_clock::time_point start; + + RecordExtraCost(int ck) : ck(ck) { + if (!ck) return; + start = std::chrono::high_resolution_clock::now(); + } + + ~RecordExtraCost() { + if (!ck) return; + auto finish = std::chrono::high_resolution_clock::now(); + auto total_ns = (int64_t)std::chrono::duration_cast(finish-start).count(); + profiler.relay_extra_cost += total_ns; + } +}; + +static string get_marks(Op* op, bool is_fused) { + loop_options_t* options = nullptr; + if (is_fused) { + auto* fop = (FusedOp*)op; + options = fop->loop_options; + } else { + if (op->outputs().size()) { + if (op->outputs().front()->loop_options) + options = &op->outputs().front()->loop_options.data(); + } + } + if (!options) return string(); + for (auto& kv : *options) { + if (startswith(kv.first, "_marks:")) { + return kv.first.substr(7); + } + } + return string(); +} + +static string origin_key(const string& s) { + if (s.size() && s[0]=='[') { + return s.substr(s.find("]")+1); + } + return s; +} + +void Profiler::record_and_run( + jit_op_entry_t jit_entry, + Op* op, + const char* jit_key +) { + if (!profiler_enable) + jit_entry(op); + else { + auto ikey=jit_key_mapper.find(jit_key); + const char* key = ikey==jit_key_mapper.end() ? + jit_key : ikey->second.c_str(); + bool is_fused = op->name() == string("fused"); + string marks = get_marks(op, is_fused); + string new_key; + if (marks.size()) { + // add marks into key, for better report + new_key = string("[marks:"); + new_key += marks; + new_key += "]"; + new_key += key; + key = new_key.c_str(); + } + + auto iter = profiler.records.find(key); + uint64_t in, out, compute; + if (profiler.relay_fop) + profiler.relay_fop->statistics(in, out, compute); + else + op->statistics(in, out, compute); + if (iter == profiler.records.end()) { + profiler.records[key] = Info{ + 0, 0, -1ull, 0, + 0, 0, 0 + }; + iter = profiler.records.find(key); + if (trace_py_var) { + iter->second.stack_info = get_stack_info(op); + } + } + + uint64* shape_time = nullptr; + if (trace_py_var || profiler_record_shape) { + // record shape + NanoVector shape; + int64 num = 0; + Op** ops = &op; + int op_num = 1; + if (is_fused) { + ops = &(((FusedOp*)op)->ops[0]); + op_num = ((FusedOp*)op)->ops.size(); + } + for (int i=0; iinputs()) { + if (v->num > num) { + num = v->num; + shape = v->shape; + } + } + for (auto v : o->outputs()) { + if (v->num > num) { + num = v->num; + shape = v->shape; + } + } + } + iter->second.shapes[shape].second += 1; + shape_time = &iter->second.shapes[shape].first; + } + int64_t warmup = profiler.warmup; + int64_t rerun = profiler.rerun + 1; + rerun = std::max(NanoVector::get_nbits(rerun) - 2, 0); + int loop = 0; + Deleter _d; + if (is_fused) { + auto fop = ((FusedOp*)op); + if (fop->context && fop->context->vrm.relay_groups.size()) { + // relay op + loop = rerun; + profiler.relay_extra_cost = 0; + profiler.relay_fop = fop; + _d.del = [&]() { + profiler.relay_extra_cost = 0; + profiler.relay_fop = 0; + }; + } else + loop = fop->get_loop_option("insert_profile_loop") ? 10 : 0; + } + int64 num = 1<<(rerun - loop); + + { + profiler_enable = 0; + Deleter del([&]() { profiler_enable = 1;}); + RecordExtraCost rec(profiler.relay_fop && profiler.relay_fop != op); + for (int64_t i=0; i(finish-start).count(); + if (profiler.relay_fop == op) { + total_ns -= profiler.relay_extra_cost; + } + // 24ns function call overhead + total_ns = std::max((int64_t)1, total_ns-24); + iter->second.update(rerun, total_ns, in, out, compute); + if (shape_time) shape_time[0] += total_ns; + + // add markers record + if ((profiler.relay_fop == op || profiler.relay_fop == nullptr) + && marks.size()) { + // only record not relay op + auto vs = split(marks, ","); + for (auto& mark : vs) { + if (mark.size()) { + auto& mark_info = profiler.marks[mark]; + mark_info.count += 1; + mark_info.time_total += total_ns; + } + } + } + + RecordExtraCost rec(profiler.relay_fop && profiler.relay_fop != op); + if (profiler_record_peek) + stat_peek_bandwidth(in, out, rerun, iter->second.peek_time_total); + LOGvvvv << "Duration" << total_ns >> "ns running" << op; + if (is_fused && + ((FusedOp*)op)->get_loop_option("check_cache")) { + auto fname = Op::get_filename_from_jit_key(origin_key(key), ".so"); + unique_ptr* mc = load_memory_checker(fname); + iter->second.cache_info.reset(new CacheInfo(mc)); + } + } +} + +vector> Profiler::report(const string& sort_key) { + vector> rep = {{"Name", "FileName", "Count", "TotalTime", "AvgTime", "MinTime", "MaxTime", "Input", "Output", "InOut", "Compute"}}; + if (profiler_record_peek) + rep[0].push_back("Peek"); + vector names, fnames; + vector> info; + vector order; + int sort_key_id = 0; + for (; sort_key_id<(int)rep[0].size(); sort_key_id++) + if (rep[0][sort_key_id] == sort_key) + break; + ASSERT(sort_key_id<(int)rep[0].size()) << "Key not supported:" << sort_key; + double total_time = 0; + double total_mem_access = 0; + for (auto& kv : profiler.records) { + auto& kinfo = kv.second; + names.push_back(kv.first); + fnames.push_back(Op::get_filename_from_jit_key(origin_key(kv.first), ".cc")); + if (kv.second.stack_info.size()) { + fnames.back() += '\n'; + fnames.back() += kv.second.stack_info.c_str(); + } + if (kv.second.shapes.size()) { + // show shapes + vector,NanoVector>> shapes; + shapes.reserve(kv.second.shapes.size()); + for (auto& kv2 : kv.second.shapes) { + shapes.push_back(std::make_pair(kv2.second, kv2.first)); + } + std::sort(shapes.begin(), shapes.end()); + std::stringstream ss; + ss << "shapes:\n"; + for (int i=0; i<10; i++) { + if (i>=shapes.size()) break; + auto& sp = shapes[shapes.size() - i - 1]; + auto rate = sp.first.first * 100.0 / kinfo.time_total; + ss << sp.second << ':' << sp.first.second << + "("<< std::setprecision(3) << rate << "%), "; + } + if (shapes.size()>10) + ss << "... total " << shapes.size() << '\n'; + fnames.back() += ss.str(); + } + order.push_back(order.size()); + // do not count relay op time + if (kv.first.find("relay") == string::npos) { + total_time += kinfo.time_total; + total_mem_access += kinfo.in_total + kinfo.out_total; + } + info.push_back({ + (double)kinfo.count, // Count + (double)kinfo.time_total, // TotalTime + (double)kinfo.time_total*1.0 / kinfo.count, // AvgTime + (double)kinfo.time_min, // MinTime + (double)kinfo.time_max, // MaxTime + (double)kinfo.in_total*1e9 / kinfo.time_total, // Input + (double)kinfo.out_total*1e9 / kinfo.time_total, // Output + (double)(kinfo.in_total+kinfo.out_total)*1e9 / kinfo.time_total, // InOut + (double)kinfo.compute_total*1e9 / kinfo.time_total, // Compute + }); + if (profiler_record_peek) + info.back().push_back( + (double)(kinfo.in_total+kinfo.out_total)*1e9 / kinfo.peek_time_total // Peek + ); + } + if (sort_key_id>=2) + std::sort(order.begin(), order.end(), [&](int i, int j) { + return info[i][sort_key_id-2] > info[j][sort_key_id-2]; + }); + else + std::sort(order.begin(), order.end(), [&](int i, int j) { + return names[i] > names[j]; + }); + std::stringstream ss; + ss << "Profile result, sorted by " << sort_key << "\n" + << "('it/s' represent number of iterations per sec)\n"; + uint w = 10, p=3; + for (auto& s : rep[0]) { + ss << std::setw(w) << s; + if (s == "TotalTime") + ss << std::setw(w) << "%,cum%"; + } + ss << '\n'; + auto output_float = [&](const string& scale, int base, const string& suffix, double k) { + ss << ' ' << std::setw(w-2-suffix.size()); + ss << std::setprecision(p); + uint i=0; + for (; i+1= w-1) + ss << "\n" << std::setw(w) << " "; + ss << std::setw(w) << fname; + if (fname.size() >= w-1) + ss << "\n" << std::setw(w*2) << " "; + for (uint j=0; j> '\n' >> ss.str() >> '\n'; + + //cache rep + // TODO: report_cache sort_key + vector> rep_cache = report_cache("CheckTimes"); + if (rep_cache.size() > 1) + rep.insert(rep.end(), rep_cache.begin(), rep_cache.end()); + return rep; +} + +vector> Profiler::report_cache(const string& sort_key) { + vector> rep = {{"Name", "FileName", "CheckTimes", "TLBMissRate"}}; + vector names, fnames; + vector> info; + vector> int_info; + vector order; + int sort_key_id = 0; + for (; sort_key_id<(int)rep[0].size(); sort_key_id++) + if (rep[0][sort_key_id] == sort_key) + break; + ASSERT(sort_key_id<(int)rep[0].size()) << "Key not supported:" << sort_key; + sort_key_id--; + for (auto& kv : profiler.records) { + if (!kv.second.cache_info) + continue; + names.push_back(kv.first); + fnames.push_back(Op::get_filename_from_jit_key(origin_key(kv.first), ".cc")); + CacheInfo& kinfo = *kv.second.cache_info; + order.push_back(order.size()); + vector one_info = {(double)kinfo.check_times, ((double)kinfo.tlb_miss_times) / kinfo.check_times}; + vector one_int_info = {(int)kinfo.check_times, (int)kinfo.tlb_miss_times}; + for (int i = 0; i < (int)kinfo.cache_miss_times.size(); ++i) { + if ((int)rep[0].size() < 4 + i + 1) { + std::stringstream ss; + ss << "L" << i + 1 << "MissRate"; + rep[0].push_back(ss.str()); + } + one_info.push_back(((double)kinfo.cache_miss_times[i]) / kinfo.check_times); + one_int_info.push_back((int)kinfo.cache_miss_times[i]); + } + info.push_back(one_info); + int_info.push_back(one_int_info); + } + if (sort_key_id>0) + std::sort(order.begin(), order.end(), [&](int i, int j) { + return info[i][sort_key_id-1] > info[j][sort_key_id-1]; + }); + else + std::sort(order.begin(), order.end(), [&](int i, int j) { + return names[i] > names[j]; + }); + std::stringstream ss; + ss << "Memory profile result, sorted by " << sort_key << "\n"; + uint w = 15; + for (auto& s : rep[0]) + ss << std::setw(w) << s; + ss << '\n'; + for (auto i : order) { + auto& name = names[i]; + auto& fname = fnames[i]; + rep.push_back({name, fname}); + ss << std::setw(w) << name; + if (name.size() >= w-1) + ss << "\n" << std::setw(w) << " "; + ss << std::setw(w) << fname; + if (fname.size() >= w-1) + ss << "\n" << std::setw(w*2) << " "; + for (uint j=0; j> '\n' >> ss.str() >> '\n'; + + return rep; +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/profiler.h b/python/jittor/src/profiler/profiler.h new file mode 100644 index 00000000..c050f335 --- /dev/null +++ b/python/jittor/src/profiler/profiler.h @@ -0,0 +1,76 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "profiler/cache_info.h" +#include "op_compiler.h" +#include "misc/cstr.h" +#include "misc/nano_vector.h" + +namespace jittor { + +// @pyjt(profiler) +// @attrs(submodule) +struct Profiler { + struct Info { + uint64_t count; + // time in us + uint64_t time_max, time_min, time_total; + // thoughtput in byte + uint64_t in_total, out_total; + // compute thoughtput in ops + uint64_t compute_total; + // peek time use memcopy + uint64_t peek_time_total; + // cache test info + unique_ptr cache_info; + cstr stack_info; + unordered_map> shapes; + + void update(int c, uint64_t t, uint64_t in, uint64_t out, uint64_t comp) { + count += 1<>c); + time_min = std::min(time_min, t>>c); + time_total += t; + in_total += in<> report(const string& sort_key="TotalTime"); + static vector> report_cache(const string& sort_key="TotalTime"); + + static void record_and_run( + jit_op_entry_t jit_entry, + Op* op, + const char* jit_key + ); + + int64_t warmup=0, rerun=0; + unordered_map records; + int64 relay_extra_cost; + FusedOp* relay_fop; + + struct MarkInfo { + uint64_t count, time_total; + }; + unordered_map marks; + + ~Profiler(); +}; + +EXTERN_LIB Profiler profiler; + +DECLARE_FLAG(int, profiler_enable); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/profiler_guard.h b/python/jittor/src/profiler/profiler_guard.h new file mode 100644 index 00000000..b044e0ef --- /dev/null +++ b/python/jittor/src/profiler/profiler_guard.h @@ -0,0 +1,62 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "profiler/cache_info.h" +#include "profiler/profiler.h" +#include "op_compiler.h" + +namespace jittor { + +struct ProfilerGuard { + const char* key; + bool alive; + std::chrono::high_resolution_clock::time_point start_time; + std::chrono::high_resolution_clock::time_point stop_time; + + inline void start(int64 warmup=0, int64 rerun=0) { + alive = true; + start_time = std::chrono::high_resolution_clock::now(); + } + + inline void stop() { + if (!alive) return; + alive = false; + stop_time = std::chrono::high_resolution_clock::now(); + + auto iter = profiler.records.find(key); + if (iter == profiler.records.end()) { + profiler.records[key] = Profiler::Info{ + 0, 0, -1ull, 0, + 0, 0, 0 + }; + iter = profiler.records.find(key); + } + + auto total_ns = (int64_t)std::chrono::duration_cast(stop_time-start_time).count(); + // 24ns function call overhead + total_ns = std::max((int64_t)1, total_ns-24); + iter->second.update(1, total_ns, 0, 0, 0); + } + + inline ProfilerGuard(const char* _key) { + key = _key; + if (profiler_enable) { + ProfilerGuard::start(); + } + } + + inline ~ProfilerGuard() { + if (profiler_enable) { + ProfilerGuard::stop(); + } + } +}; + +DECLARE_FLAG(int, profiler_enable); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/replacement.cc b/python/jittor/src/profiler/replacement.cc new file mode 100644 index 00000000..01ccca6d --- /dev/null +++ b/python/jittor/src/profiler/replacement.cc @@ -0,0 +1,74 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "profiler/replacement.h" + +namespace jittor { +CacheConfig::CacheConfig(size_t size, size_t ways, size_t line_size) : size(size), ways(ways), line_size(line_size) {} + + +Cache::Cache(const CacheConfig config) : config(config), miss_time(0) { +} + +Cache::~Cache() { +} + +void Cache::clear() { + miss_time = 0; + clear_(); +} + +bool Cache::check_hit(size_t paddr) { + bool hit = check_hit_(paddr); + if (!hit) ++miss_time; + return hit; +} + +DefaultReplacementCache::DefaultReplacementCache(const CacheConfig config) : Cache(config) {} + +bool DefaultReplacementCache::check_hit_(size_t paddr) { + size_t cache_set = paddr % (config.size/config.ways) / config.line_size; + size_t tag = paddr / (config.size/config.ways); + for (auto t : data[cache_set]) + if (t == tag) return true; + if (data[cache_set].size() >= config.ways) return false; + data[cache_set].push_back(tag); + return false; +} + +void DefaultReplacementCache::clear_() { + data.clear(); +} + +LRUCache::LRUCache(const CacheConfig config) : Cache(config) {} + +bool LRUCache::check_hit_(size_t paddr) { + size_t cache_set = paddr % (config.size/config.ways) / config.line_size; + size_t tag = paddr / (config.size/config.ways); + + for (int i = 0; i < (int)data[cache_set].size(); ++i) { + size_t t = data[cache_set][i]; + if (t == tag) { + data[cache_set].erase(data[cache_set].begin() + i); + data[cache_set].insert(data[cache_set].begin(), tag); + return true; + } + } + data[cache_set].insert(data[cache_set].begin(), tag); + if (data[cache_set].size() > config.ways) { + data[cache_set].pop_back(); + } + return false; +} + +void LRUCache::clear_() { + data.clear(); +} + +} //jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/replacement.h b/python/jittor/src/profiler/replacement.h new file mode 100644 index 00000000..df0ea533 --- /dev/null +++ b/python/jittor/src/profiler/replacement.h @@ -0,0 +1,49 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include + +namespace jittor { +struct CacheConfig { + size_t size, ways, line_size; + CacheConfig(size_t size, size_t ways, size_t line_size=64); +}; + +struct Cache { + CacheConfig config; + int miss_time; + + Cache(const CacheConfig config); + virtual ~Cache(); + void clear(); + bool check_hit(size_t paddr); + virtual bool check_hit_(size_t paddr) = 0; + virtual void clear_() = 0; +}; + +struct DefaultReplacementCache : Cache { + std::map> data; + + DefaultReplacementCache(const CacheConfig config); + bool check_hit_(size_t paddr); + void clear_(); +}; + +struct LRUCache : Cache { + std::map> data; + + LRUCache(const CacheConfig config); + bool check_hit_(size_t paddr); + void clear_(); +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/simple_profiler.h b/python/jittor/src/profiler/simple_profiler.h new file mode 100644 index 00000000..9a92f23c --- /dev/null +++ b/python/jittor/src/profiler/simple_profiler.h @@ -0,0 +1,104 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include "common.h" +#include "misc/intrin.h" + +namespace jittor { + +struct SimpleProfiler { + string name; + int64 cnt; + int64 total_ns; + int64 sum; + int64 pcnt[7] = {0}; + int64 pns[7] = {0}; + int64 last[7] = {0}; + + void report() { + std::cerr << "=============================\nSimpleProfiler [" << name << "] cnt: " << cnt + << " sum: " << sum << " speed: " << std::setprecision(3) << (sum*1.0/total_ns) + << " total: " ; + if (total_ns < 1.e3) + std::cerr << total_ns << " ns" << std::endl; + else if (total_ns < 1.e6) + std::cerr << std::setprecision(3) << total_ns/1.e3 << " us" << std::endl; + else if (total_ns < 1.e9) + std::cerr << std::setprecision(3) << total_ns/1.e6 << " ms" << std::endl; + else + std::cerr << std::setprecision(3) << total_ns/1.e9 << " s" << std::endl; + std::cerr << " <32ns <1us <32us <1ms <32ms <1s >1s\n"; + std::cerr << "cnt: "; + for (int i=0; i<7; i++) std::cerr << std::setw(9) << pcnt[i]; + std::cerr << "\n "; + for (int i=0; i<7; i++) std::cerr << std::setw(9) << std::setprecision(3) << pcnt[i]*1.0/cnt; + std::cerr << "\ntime:"; + for (int i=0; i<7; i++) std::cerr << std::setw(9) << std::setprecision(3) << pns[i]*1.0/total_ns; + std::cerr << "\nlast:"; + for (int i=0; i<7; i++) std::cerr << std::setw(9) << last[i]; + std::cerr << std::endl; + } + + inline SimpleProfiler(string&& name): name(move(name)), cnt(0), total_ns(0), sum(0) {} + inline ~SimpleProfiler() { report(); } + inline void add(int64 time, int64 s) { + auto nbit = 64 - lzcnt(time); + auto i = (nbit-1) / 5; + if (i>6) i=6; + cnt ++; + sum += s; + total_ns += time; + pcnt[i] ++; + pns[i] += time; + last[i] = cnt; + } + + inline void reset() { + cnt = 0; + total_ns = 0; + sum = 0; + for (int i=0; i<7; i++) { + pcnt[i] = 0; + pns[i] = 0; + last[i] = 0; + } + } +}; + +/* +example: + { + static SimpleProfiler _("array"); + SimpleProfilerGuard __(_); + ...... + } + */ +struct SimpleProfilerGuard { + SimpleProfiler* p; + int64 s; + std::chrono::high_resolution_clock::time_point start; + inline SimpleProfilerGuard(SimpleProfiler& p, int64 s=1) : p(&p), s(s) { + start = std::chrono::high_resolution_clock::now(); + } + void finish() { + this->~SimpleProfilerGuard(); + s = 0; + } + inline ~SimpleProfilerGuard() { + if (!s) return; + auto finish = std::chrono::high_resolution_clock::now(); + auto total_ns = (int64_t)std::chrono::duration_cast(finish-start).count(); + p->add(total_ns, s); + } +}; + + +DECLARE_FLAG(int, profiler_enable); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/profiler/vtop.cc b/python/jittor/src/profiler/vtop.cc new file mode 100644 index 00000000..6d8250a8 --- /dev/null +++ b/python/jittor/src/profiler/vtop.cc @@ -0,0 +1,89 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include /* uint64_t */ +#include /* printf */ +#include /* size_t */ +#ifndef _WIN32 +#define _XOPEN_SOURCE 700 +#include /* open */ +#include /* pread, sysconf */ + +typedef struct { + uint64_t pfn : 54; + unsigned int soft_dirty : 1; + unsigned int file_page : 1; + unsigned int swapped : 1; + unsigned int present : 1; +} PagemapEntry; + +/* Parse the pagemap entry for the given virtual address. + * + * * [out] entry: the parsed entry + * * [in] pagemap_fd: file descriptor to an open /proc/pid/pagemap file + * * [in] vaddr: virtual address to get entry for + * @return 0 for success, 1 for failure + */ +int pagemap_get_entry(PagemapEntry* entry, int pagemap_fd, uintptr_t vaddr) +{ + size_t nread; + int64_t ret; + uint64_t data; + uintptr_t vpn; + + vpn = vaddr / sysconf(_SC_PAGE_SIZE); + nread = 0; + while (nread < sizeof(data)) { + ret = pread(pagemap_fd, &data, sizeof(data) - nread, + vpn * sizeof(data) + nread); + nread += ret; + if (ret <= 0) { + return 1; + } + } + entry->pfn = data & (((uint64_t)1 << 54) - 1); + entry->soft_dirty = (data >> 54) & 1; + entry->file_page = (data >> 61) & 1; + entry->swapped = (data >> 62) & 1; + entry->present = (data >> 63) & 1; + return 0; +} + +/* Convert the given virtual address to physical using /proc/self/pagemap. + * + * * [out] paddr: physical address + * * [in] vaddr: virtual address to get entry for + * @return 0 for success, 1 for failure + */ +int virt_to_phys_user(uintptr_t* paddr, uintptr_t vaddr) +{ + int pagemap_fd; + pagemap_fd = open("/proc/self/pagemap", O_RDONLY); + if (pagemap_fd < 0) { + return 1; + } + PagemapEntry entry; + if (pagemap_get_entry(&entry, pagemap_fd, vaddr)) { + return 1; + } + close(pagemap_fd); + if (entry.pfn == 0) + return 1; + *paddr = (entry.pfn * sysconf(_SC_PAGE_SIZE)) + (vaddr % sysconf(_SC_PAGE_SIZE)); + return 0; +} + +#else +int virt_to_phys_user(uintptr_t* paddr, uintptr_t vaddr) +{ + paddr[0] = vaddr; + return 1; +} + +#endif diff --git a/python/jittor/src/pybind/core.cc b/python/jittor/src/pybind/core.cc new file mode 100644 index 00000000..42dae808 --- /dev/null +++ b/python/jittor/src/pybind/core.cc @@ -0,0 +1,41 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "core.h" +#include "grad.h" +#include "pyjt/py_obj_holder.h" +#include "init.h" +#include "pyjt/numpy.h" +#include "utils/seh.h" + +namespace jittor { + +SEH_HOOK; + +// Those function is generated by python +EXTERN_LIB void pyjt_def_all(PyObject* m); + +vector _grad(VarHolder* loss, const vector& targets, bool retain_graph) { + vector vs; + vs.reserve(targets.size()); + for (auto* v : targets) vs.push_back(v->var); + auto grads = grad(loss->var, vs, retain_graph); + vector grads_hold; + grads_hold.reserve(targets.size()); + for (auto& grad : grads) + grads_hold.push_back(new VarHolder(move(grad))); + return grads_hold; +} + +} // jittor + +static void init_module(PyModuleDef* mdef, PyObject* m) { + mdef->m_doc = "Inner c++ core of jittor"; + jittor::init(); + jittor::numpy_init(); + jittor::pyjt_def_all(m); +} +PYJT_MODULE_INIT(jittor_core); diff --git a/python/jittor/src/pybind/py_var_tracer.cc b/python/jittor/src/pybind/py_var_tracer.cc new file mode 100644 index 00000000..5cc51596 --- /dev/null +++ b/python/jittor/src/pybind/py_var_tracer.cc @@ -0,0 +1,445 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "pyjt/py_obj_holder.h" +#include "pyjt/py_converter.h" +#include "pybind/py_var_tracer.h" +#include "utils/str_utils.h" +#include "op.h" +#include "var.h" +#include "fused_op.h" + +namespace jittor { + +DEFINE_FLAG(int, trace_py_var, 0, "Trace py stack max depth for debug."); +DEFINE_FLAG(int, trace_var_data, 0, "Trace py stack max depth for debug."); +Op* trace_grad_op = nullptr; + +TraceData trace_data; +int64 cnt = 0; + +static PyObject* my_import(const char* module_name, const char* attr) { + // LOGir << module_name << attr; + PyObjHolder a(PyImport_ImportModule(module_name)); + PyObjHolder b(PyObject_GetAttrString(a.obj, attr)); + // LOGir << "Done"; + return b.obj; +} + +static PyObject* find_obj_name(PyFrameObject* f, PyObject* obj, const char* default_name="_model") { + #if PY_MINOR_VERSION>=11 + #pragma message( "PY_MAJOR_VERSION333 " PY_VERSION ) + LOGf << "python3.11 not supported yet"; + return nullptr; + #else + auto co = f->f_code; + auto map = co->co_varnames; + + auto fast = f->f_localsplus; + auto j = PyTuple_GET_SIZE(map); + if (j > co->co_nlocals) + j = co->co_nlocals; + if (co->co_nlocals) { + for (int i=0; ico_cellvars); + auto nfreevars = PyTuple_GET_SIZE(co->co_freevars); + if (ncells || nfreevars) { + for (int i=0; ico_nlocals] == obj) { + auto s = PyTuple_GET_ITEM(co->co_cellvars, i); + Py_INCREF(s); + return s; + } + } + for (int i=0; ico_nlocals+ncells] == obj) { + auto s = PyTuple_GET_ITEM(co->co_freevars, i); + Py_INCREF(s); + return s; + } + } + } + // LOGw << "not found name" << map << co->co_cellvars << co->co_freevars << (PyObject*)f; + return PyUnicode_FromString(default_name); + #endif +} + +static string to_string(PyObject* obj) { + Py_ssize_t size; + const char* s = PyUnicode_AsUTF8AndSize(obj, &size); + return string(s, size); +} + +static vector get_stack_info() { + #if PY_MINOR_VERSION>=11 + LOGf << "python3.11 not supported yet"; + return {}; + #else + // cnt ++; + // if (cnt % 100 != 0) return {}; + vector stacks; + static auto getframe = my_import("sys", "_getframe"); + static auto jt_module = my_import("jittor", "Module"); + static auto jt_optimizer = my_import("jittor.optim", "Optimizer"); + static auto fill_module_name = my_import("jittor.utils.tracer", "fill_module_name"); + static auto _trace_name = PyUnicode_FromString("_trace_name"); + + PyObjHolder ret(PyObject_CallFunctionObjArgs(getframe, nullptr)); + + auto frame = (PyFrameObject*)ret.obj; + int n=0; + while (frame) n++, frame = frame->f_back; + STACK_ALLOC(PyFrameObject*, frames, n); + frame = (PyFrameObject*)ret.obj; + int i=n; + while (i) frames[--i] = frame, frame = frame->f_back; + PyObject* prev_obj = nullptr; + if (trace_py_var >= 3) { + // trace raw stack + // auto start = std::max(0, n-5); + auto start = 0; + for (int i=start; if_code->co_filename); + auto lineno = (int)PyFrame_GetLineNumber(f); + stacks.emplace_back(Stack{ + filename+":"+S(lineno), + to_string(f->f_code->co_name), + filename, + lineno}); + } + return stacks; + } + for (int i=0; if_code->co_varnames)) { + auto fast = f->f_localsplus; + auto obj = fast[0]; + if (obj == prev_obj) continue; + prev_obj = obj; + if (obj == nullptr) + // normal function first argument is null + continue; + auto tp_mro = obj->ob_type->tp_mro; + auto base_type = PyTuple_GET_ITEM(tp_mro, Py_SIZE(tp_mro)-2); + auto prev_f = i? frames[i-1] : f; + if (base_type == jt_optimizer) { + string init_name = string(obj->ob_type->tp_name) + "_init"; + PyObjHolder ret(find_obj_name(f->f_back, obj, init_name.c_str())); + stacks.emplace_back(Stack{ + to_string(ret.obj), + string(obj->ob_type->tp_name), + to_string(prev_f->f_code->co_filename), + (int)PyFrame_GetLineNumber(prev_f)}); + break; + } + if (base_type != jt_module) + continue; + PyObjHolder ret; + _PyObject_LookupAttr(obj, _trace_name, &ret.obj); + string scope_name; + if (!ret.obj) { + // find base name + auto co_name = to_string(f->f_code->co_name); + if (co_name == "__init__") { + scope_name = string(obj->ob_type->tp_name) + "_init"; + } else + if (co_name == "__call__") { + if (i) { + ret.assign(find_obj_name(f->f_back, obj)); + scope_name = to_string(ret.obj); + } else { + ret.assign(PyUnicode_FromString("_model")); + scope_name = "_model"; + } + PyObjHolder _(PyObject_CallFunctionObjArgs( + fill_module_name, obj, ret.obj, nullptr)); + } + } else { + scope_name = to_string(ret.obj); + } + stacks.emplace_back(Stack{ + move(scope_name), + string(obj->ob_type->tp_name), + to_string(prev_f->f_code->co_filename), + (int)PyFrame_GetLineNumber(prev_f)}); + } + } + if (stacks.size() == 0) { + auto m = std::min(3,n); + for (int i=0; if_code->co_filename); + auto num = (int)PyFrame_GetLineNumber(f); + stacks.emplace_back(Stack{ + s+":"+S(num), + "", + s, + num}); + } + } + return stacks; + #endif +} + +template +string get_str(T* t, int64 num) { + string s = ""; + for (int64 i=0; idtype() == ns_int8) + return get_str(v->ptr(), v->num); + if (v->dtype() == ns_int16) + return get_str(v->ptr(), v->num); + if (v->dtype() == ns_int32) + return get_str(v->ptr(), v->num); + if (v->dtype() == ns_int64) + return get_str(v->ptr(), v->num); + + + if (v->dtype() == ns_uint8) + return get_str(v->ptr(), v->num); + if (v->dtype() == ns_uint16) + return get_str(v->ptr(), v->num); + if (v->dtype() == ns_uint32) + return get_str(v->ptr(), v->num); + if (v->dtype() == ns_uint64) + return get_str(v->ptr(), v->num); + + if (v->dtype() == ns_float32) + return get_str(v->ptr(), v->num); + if (v->dtype() == ns_float64) + return get_str(v->ptr(), v->num); + return ""; +} + +void TraceData::record_node(Node* node, bool record_stack) { + if (get_thread_name().size()) return; + NodeData data; + data.id = node_data_cnt++; + id_map[node] = data.id; + if (trace_py_var) { + if (record_stack) { + if (trace_grad_op) { + auto iter = trace_data.id_map.find(trace_grad_op); + data.stacks.emplace_back(Stack{"grad", "Grad", "", 0}); + if (iter != trace_data.id_map.end()) { + data.attrs["grad_op_id"] = S(iter->second); + auto& prev_stack = trace_data.node_data[iter->second].stacks; + for (auto& s : prev_stack) + data.stacks.push_back(s); + } + } else + data.stacks = get_stack_info(); + } + } else { + } + data.attrs["__id"] = S(node->id); + data.attrs["is_var"] = node->is_var() ? "1" : "0"; + data.attrs["name"] = "unname"; + node_data[data.id] = move(data); +} + +static int64 get_node_id(Node* node) { + auto iter = trace_data.id_map.find(node); + if (iter != trace_data.id_map.end()) + return iter->second; + trace_data.record_node(node, false); + return trace_data.node_data_cnt - 1; +} + +void TraceData::release_node(Node* node) { + if (get_thread_name().size()) return; + auto iter = trace_data.id_map.find(node); + if (iter == trace_data.id_map.end()) + return; + auto node_id = iter->second; + id_map.erase(node); + if (trace_py_var < 2 || execute_op_info.size() > 100000) { + node_data.erase(node_id); + } +} + +void TraceData::record_exe_node(Node* node) { + auto node_id = get_node_id(node); + auto& data = node_data[node_id]; + auto name_iter = data.attrs.find("name"); + if (data.inputs.size() != node->inputs().size() || data.attrs.size() == 0 || name_iter == data.attrs.end() || name_iter->second == "unname") { + data.inputs.clear(); + data.inputs.reserve(node->inputs().size()); + for (auto i : node->inputs()) { + auto iid = get_node_id(i); + data.inputs.push_back(iid); + node_data[iid].outputs.push_back(node_id); + } + if (node->is_var()) { + auto v = node->var(); + std::stringstream ss; + ss << v->shape; + data.attrs["shape"] = ss.str(); + data.attrs["ndim"] = S(v->shape.size()); + data.attrs["dtype"] = v->dtype().to_cstring(); + data.attrs["dsize"] = S(v->dtype().dsize()); + data.attrs["name"] = v->name.c_str(); + data.attrs["is_var"] = "1"; + if (trace_var_data && v->mem_ptr) + data.attrs["data"] = get_var_data_str(v); + } else { + auto op = node->op(); + data.attrs["name"] = op->name_ex(); + data.attrs["is_var"] = "0"; + // TODO: add other op attrs + } + } +} + +void TraceData::record_op(Op* op) { + record_exe_node(op); + for (auto o : op->outputs()) + record_exe_node(o); +} + +void TraceData::record_execution(Op* op, bool is_fused_op, JK& jk) { + if (execute_op_info.size() > 100000) return; + ExecuteOpInfo& einfo = execute_op_info[execute_op_info_cnt++]; + if (is_fused_op) { + FusedOp* fop = (FusedOp*)op; + for (auto op : fop->ops) { + record_op(op); + einfo.fused_ops.push_back(get_node_id(op)); + } + } else { + record_op(op); + einfo.fused_ops.push_back(get_node_id(op)); + } + op->do_prepare(jk); + if (jk.empty()) return; + const char* jit_key = jk.to_cstring(); + auto iter = jit_key_mapper.find(jit_key); + if (iter == jit_key_mapper.end()) + einfo.jit_key = jit_key; + else + einfo.jit_key = iter->second; + jit_key_map[einfo.jit_key].push_back(execute_op_info_cnt-1); + einfo.file_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc"); +} + +template +static void fill_dict(PyObject* dict, T key, PyObject* value) { + PyObjHolder k(to_py_object(key)); + PyObjHolder v(value); + PyDict_SetItem(dict, k.obj, value); +} + +// template<> +// PyObject* to_py_object(const Stack& stack) { +// return nullptr; +// } + +DEF_IS(Stack, PyObject*) to_py_object(const T& a) { + PyObjHolder dict(PyDict_New()); + fill_dict(dict.obj, string("name"), to_py_object(a.module_name)); + fill_dict(dict.obj, string("type"), to_py_object(a.module_type)); + fill_dict(dict.obj, string("file_path"), to_py_object(a.file_path)); + fill_dict(dict.obj, string("lineno"), to_py_object(a.lineno)); + return dict.release(); +} + +PyObject* dump_trace_data() { + PyObjHolder dict(PyDict_New()); + PyObjHolder node_data(PyDict_New()); + PyObjHolder execute_op_info(PyDict_New()); + for (auto& kv : trace_data.node_data) { + if (kv.second.attrs.size() == 0) + continue; + auto name_iter = kv.second.attrs.find("name"); + // if don't have name, this node is not executed + if (name_iter == kv.second.attrs.end() || name_iter->second == "unname") + continue; + PyObjHolder dict(PyDict_New()); + fill_dict(dict.obj, string("id"), to_py_object(kv.second.id)); + fill_dict(dict.obj, string("inputs"), to_py_object(kv.second.inputs)); + fill_dict(dict.obj, string("outputs"), to_py_object(kv.second.outputs)); + fill_dict(dict.obj, string("stacks"), to_py_object(kv.second.stacks)); + fill_dict(dict.obj, string("attrs"), to_py_object(kv.second.attrs)); + fill_dict(node_data.obj, kv.first, dict.release()); + } + for (auto& kv : trace_data.execute_op_info) { + PyObjHolder dict(PyDict_New()); + fill_dict(dict.obj, string("fused_ops"), to_py_object(kv.second.fused_ops)); + fill_dict(dict.obj, string("jit_key"), to_py_object(kv.second.jit_key)); + fill_dict(dict.obj, string("file_path"), to_py_object(kv.second.file_path)); + fill_dict(dict.obj, string("attrs"), to_py_object(kv.second.attrs)); + fill_dict(execute_op_info.obj, kv.first, dict.release()); + } + fill_dict(dict.obj, string("node_data"), node_data.release()); + fill_dict(dict.obj, string("execute_op_info"), execute_op_info.release()); + return dict.release(); +} + +void clear_trace_data() { + trace_data.execute_op_info.clear(); + trace_data.jit_key_map.clear(); + trace_data.id_map.clear(); + trace_data.node_data.clear(); +} + +string _get_stack_info(Node* node, const char* change_line) { + string stack_info = ""; + auto iter = trace_data.id_map.find(node); + if (iter == trace_data.id_map.end()) + return stack_info; + auto node_id = iter->second; + auto iter2 = trace_data.node_data.find(node_id); + if (iter2 == trace_data.node_data.end()) + return stack_info; + for (auto& stack : iter2->second.stacks) { + stack_info += stack.module_name; + stack_info += '('; + stack_info += stack.module_type; + stack_info += ')'; + stack_info += " -> "; + stack_info += change_line; + } + return stack_info; +} + +void print_node_trace(const Node* node, std::ostream& os) { + os << _get_stack_info((Node*)node, "\n"); +} + +vector get_node_trace(Node* node) { + auto iter = trace_data.id_map.find(node); + if (iter == trace_data.id_map.end()) + return vector(); + auto node_id = iter->second; + auto iter2 = trace_data.node_data.find(node_id); + if (iter2 == trace_data.node_data.end()) + return vector(); + return iter2->second.stacks; +} + + +} // jittor diff --git a/python/jittor/src/pybind/py_var_tracer.h b/python/jittor/src/pybind/py_var_tracer.h new file mode 100644 index 00000000..732bde7f --- /dev/null +++ b/python/jittor/src/pybind/py_var_tracer.h @@ -0,0 +1,71 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +DECLARE_FLAG(int, trace_py_var); +EXTERN_LIB Op* trace_grad_op; +struct JitKey; + +struct Stack { + string module_name; + string module_type; + string file_path; + int lineno; +}; + +struct NodeData { + int64 id; + vector inputs; + vector outputs; + vector stacks; + /* + if is var, then contain: + is_var: 1 + shape: [a,b,c,d] + ndim: x + dtype: floatxx + dsize: 4 or 8 + name: xxx + if is op, then contain: + is_var: 0 + name: xxx + other op attr + */ + unordered_map attrs; +}; + +struct ExecuteOpInfo { + vector fused_ops; + string jit_key; + string file_path; + unordered_map attrs; +}; + +struct TraceData { + int64 node_data_cnt; + int64 execute_op_info_cnt; + unordered_map node_data; + unordered_map execute_op_info; + // jit_key map to id of execute_op_info + unordered_map> jit_key_map; + unordered_map id_map; + + void record_node(Node* node, bool record_stack=true); + void release_node(Node*); + void record_op(Op* op); + void record_exe_node(Node* node); + void record_execution(Op* op, bool is_fused_op, JitKey& jk); +}; + +EXTERN_LIB TraceData trace_data; + +void print_node_trace(const Node* node, std::ostream& os); +vector get_node_trace(Node* node); +} // jittor diff --git a/python/jittor/src/pybind/py_var_tracer_interface.h b/python/jittor/src/pybind/py_var_tracer_interface.h new file mode 100644 index 00000000..c1dcaa26 --- /dev/null +++ b/python/jittor/src/pybind/py_var_tracer_interface.h @@ -0,0 +1,18 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include + +namespace jittor { + +// @pyjt(dump_trace_data) +PyObject* dump_trace_data(); + +// @pyjt(clear_trace_data) +void clear_trace_data(); + +} // jittor diff --git a/python/jittor/src/pyjt/numpy.cc b/python/jittor/src/pyjt/numpy.cc new file mode 100644 index 00000000..bd5bca2b --- /dev/null +++ b/python/jittor/src/pyjt/numpy.cc @@ -0,0 +1,77 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "pyjt/numpy.h" + +namespace jittor { + +NanoString npy2ns[] = { + ns_bool, + ns_int8, ns_uint8, + ns_int16, ns_uint16, + ns_int32, ns_uint32, + #ifdef _WIN32 + ns_int32, ns_uint32, + #else + ns_int64, ns_uint64, + #endif + ns_int64, ns_uint64, + ns_float32, ns_float64, ns_float64, + ns_void, ns_void, ns_void, + ns_void, // 17 + ns_void, ns_void, ns_void, ns_void, ns_void, // 22 + ns_float16, // 23 +}; + +NPY_TYPES ns2npy[] = { + NPY_OBJECT, // placeholder for ns_void + NPY_BOOL, + #ifdef _WIN32 + NPY_BYTE, NPY_SHORT, NPY_LONG, NPY_LONGLONG, + NPY_UBYTE, NPY_USHORT, NPY_ULONG, NPY_ULONGLONG, + #else + NPY_BYTE, NPY_SHORT, NPY_INT, NPY_LONGLONG, + NPY_UBYTE, NPY_USHORT, NPY_UINT, NPY_ULONGLONG, + #endif + NPY_HALF, NPY_FLOAT, NPY_DOUBLE, + NPY_USHORT // fake half +}; + +void** PyArray_API; +PyTypeObject *PyArray_Type; +PyTypeObject *PyNumberArrType_Type; +PyTypeObject *PyArrayDescr_Type; +PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, void *, int, int, PyObject *); +PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *); +unsigned int (*PyArray_GetNDArrayCFeatureVersion)(); +int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj); +PyObject* (*PyArray_NewCopy)(PyObject *, int); +int (*PyArray_CopyInto)(PyObject *, PyObject *); +void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode); + +tmp_data_t tmp_data; + +void numpy_init() { + PyObjHolder np(PyImport_ImportModule("numpy.core.multiarray"), "numpy is not installed"); + PyObjHolder api(PyObject_GetAttrString(np.obj, "_ARRAY_API"), "numpy _ARRAY_API not found, you may need to reinstall numpy"); + PyArray_API = (void **) PyCapsule_GetPointer(api.obj, NULL); + + #define fill(name, i) name = (decltype(name))PyArray_API[i] + fill(PyArray_Type, 2); + fill(PyArrayDescr_Type, 3); + fill(PyNumberArrType_Type, 11); + fill(PyArray_FromAny, 69); + fill(PyArray_New, 93); + fill(PyArray_GetNDArrayCFeatureVersion, 211); + fill(PyArray_SetBaseObject, 282); + fill(PyArray_NewCopy, 85); + fill(PyArray_CopyInto, 82); + fill(PyArray_CastScalarToCtype, 63); + + ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7); +} + +} // jittor diff --git a/python/jittor/src/pyjt/numpy.h b/python/jittor/src/pyjt/numpy.h new file mode 100644 index 00000000..ce06a1f3 --- /dev/null +++ b/python/jittor/src/pyjt/numpy.h @@ -0,0 +1,130 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "pyjt/py_obj_holder.h" +#include "common.h" +#include "misc/nano_string.h" +#include "ops/array_op.h" + +namespace jittor { + +struct PyArrayDescr_Proxy { + PyObject_HEAD + PyObject* typeobj; + char kind; + char type; + char byteorder; + char flags; + int type_num; + int elsize; + int alignment; + char* subarray; + PyObject *fields; + PyObject *names; +}; + +struct PyArray_Proxy { + PyObject_HEAD + char* data; + int nd; + Py_ssize_t* dimensions; + Py_ssize_t* strides; + PyObject *base; + PyArrayDescr_Proxy *descr; + int flags; +}; + +enum NPY_TYPES { + NPY_BOOL=0, + NPY_BYTE, NPY_UBYTE, + NPY_SHORT, NPY_USHORT, + NPY_INT, NPY_UINT, + NPY_LONG, NPY_ULONG, + NPY_LONGLONG, NPY_ULONGLONG, + NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, + NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE, + NPY_OBJECT=17, + NPY_HALF=23, + NPY_END=24, +}; + +EXTERN_LIB NanoString npy2ns[]; +EXTERN_LIB NPY_TYPES ns2npy[]; + +#define NPY_ARRAY_C_CONTIGUOUS 0x0001 +#define NPY_ARRAY_ALIGNED 0x0100 +#define NPY_ARRAY_WRITEABLE 0x0400 +// NPY_ARRAY_C_CONTIGUOUS=1 +inline bool is_c_style(PyArray_Proxy* obj) { return obj->flags & 1; } +inline NanoString get_type_str(PyArray_Proxy* obj) { + NanoString type = ns_void; + if (obj->descr->type_num < NPY_END) + type = npy2ns[obj->descr->type_num]; + CHECK(type != ns_void) << "Numpy type not support, type_num:" + << obj->descr->type_num + << "type_char:" << obj->descr->type << NPY_END << npy2ns[obj->descr->type_num]; + return type; +} + +inline int get_typenum(NanoString ns) { + return ns2npy[ns.index()]; +} + +typedef Py_intptr_t npy_intp; + +EXTERN_LIB unordered_map np_typenum_map; + +EXTERN_LIB void** PyArray_API; +EXTERN_LIB PyTypeObject *PyArray_Type; +EXTERN_LIB PyTypeObject *PyNumberArrType_Type; +EXTERN_LIB PyTypeObject *PyArrayDescr_Type; +EXTERN_LIB PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, void *, int, int, PyObject *); +EXTERN_LIB PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *); +EXTERN_LIB unsigned int (*PyArray_GetNDArrayCFeatureVersion)(); +EXTERN_LIB int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj); +EXTERN_LIB PyObject* (*PyArray_NewCopy)(PyObject *, int); +EXTERN_LIB int (*PyArray_CopyInto)(PyObject *, PyObject *); +EXTERN_LIB void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode); + +#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0) + +#define NPY_ARRAY_ALIGNED 0x0100 +#define NPY_ARRAY_WRITEABLE 0x0400 +#define NPY_ARRAY_BEHAVED (NPY_ARRAY_ALIGNED | \ + NPY_ARRAY_WRITEABLE) + +#define NPY_ARRAY_CARRAY (NPY_ARRAY_C_CONTIGUOUS | \ + NPY_ARRAY_BEHAVED) + +#define PyArray_SimpleNew(nd, dims, typenum) \ + PyArray_New(PyArray_Type, nd, dims, typenum, NULL, NULL, 0, 0, NULL) + +#define PyArray_SimpleNewFromData(nd, dims, typenum, data) \ + PyArray_New(&PyArray_Type, nd, dims, typenum, NULL, \ + data, 0, NPY_ARRAY_CARRAY, NULL) + +#define PyArray_FROM_O(m) PyArray_FromAny(m, NULL, 0, 0, 0, NULL) + +inline int64 PyArray_Size(PyArray_Proxy* arr) { + int64 size = 1; + for (int i=0; ind; i++) + size *= arr->dimensions[i]; + size *= arr->descr->elsize; + return size; +} + +union tmp_data_t { + int32 i32; + float32 f32; + int8 i8; +}; + +EXTERN_LIB tmp_data_t tmp_data; + +void numpy_init(); + +} // jittor diff --git a/python/jittor/src/pyjt/py_arg_printer.cc b/python/jittor/src/pyjt/py_arg_printer.cc new file mode 100644 index 00000000..34a129e2 --- /dev/null +++ b/python/jittor/src/pyjt/py_arg_printer.cc @@ -0,0 +1,64 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "pyjt/py_arg_printer.h" +#include "pyjt/py_obj_holder.h" +#include "pyjt/py_converter.h" + +namespace jittor { + +std::ostream& operator<<(std::ostream& os, const PyArgPrinter& arg) { + os << " " << arg.name << "\t= "; + if (!arg.obj) return os << "null,"; + return os << _PyType_Name(Py_TYPE(arg.obj)) << ",\n"; +} + +std::ostream& operator<<(std::ostream& os, const PyTupleArgPrinter& args) { + os << " " << args.name << "\t= ("; + auto size = Py_SIZE(args.obj); + auto arr = PySequence_Fast_ITEMS(args.obj); + for (int i=0; i(key) << "=" << + _PyType_Name(Py_TYPE(value)) << ", "; + } + return os << "},\n"; +} + +std::ostream& operator<<(std::ostream& os, const PyFastCallArgPrinter& args) { + os << " args\t= ("; + auto size = args.n; + auto arr = args.obj; + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "common.h" + +namespace jittor { + +struct PyArgPrinter { + PyObject* obj; + const char* name; +}; +std::ostream& operator<<(std::ostream& os, const PyArgPrinter& arg); + +struct PyTupleArgPrinter { + PyObject* obj; + const char* name; +}; +std::ostream& operator<<(std::ostream& os, const PyTupleArgPrinter& args); + +struct PyKwArgPrinter { + PyObject* obj; +}; +std::ostream& operator<<(std::ostream& os, const PyKwArgPrinter& args); + +struct PyFastCallArgPrinter { + PyObject** obj; + int64 n; + PyObject* kw; +}; +std::ostream& operator<<(std::ostream& os, const PyFastCallArgPrinter& args); + +} diff --git a/python/jittor/src/pyjt/py_array_op.cc b/python/jittor/src/pyjt/py_array_op.cc new file mode 100644 index 00000000..3638834f --- /dev/null +++ b/python/jittor/src/pyjt/py_array_op.cc @@ -0,0 +1,218 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "mem/allocator.h" +#include "mem/allocator/cuda_dual_allocator.h" +#include "event_queue.h" +#endif +#include "mem/allocator/foreign_allocator.h" +#include +#include "pyjt/py_obj_holder.h" +#include "pyjt/py_converter.h" +#include "pyjt/numpy.h" +#include "ops/array_op.h" +#include "var.h" +#include "ops/op_register.h" +#include "var_holder.h" +#include "mem/swap.h" + +namespace jittor { + + +DEFINE_FLAG(int, auto_convert_64_to_32, 1, "auto convert 64bit numpy array into 32bit jittor array"); +DEFINE_FLAG(uint8, reuse_array, 0, "try reuse np.array memory into jt.array"); +DECLARE_FLAG(int, use_cuda); +DECLARE_FLAG(int, use_cuda_host_allocator); + + +static auto make_array = get_op_info("array") + .get_constructor(); + +PyObject* make_pyjt_array(const vector& shape, const string& dtype, const void* data) { + // return nullptr; + auto vh = new VarHolder(make_array(data, shape, dtype)); + return to_py_object(vh); +} + +void get_pyjt_array(PyObject* obj, vector& shape, string& dtype, void*& data) { + CHECK(Py_TYPE(obj) == &PyjtVarHolder.ht_type) << "Not a jittor array" << Py_TYPE(obj); + auto vh = GET_RAW_PTR(VarHolder, obj); + if (!vh->var->mem_ptr) + vh->sync(); + ASSERT(vh->var->mem_ptr); + shape = vh->shape().to_vector(); + dtype = vh->dtype().to_cstring(); + data = vh->var->mem_ptr; +} + +VarHolder* reuse_np_array(PyObject* obj) { + CHECK(Py_TYPE(obj) == PyArray_Type); + auto arr = (PyArray_Proxy*)obj; + NanoVector shape; + NanoString dtype; + if (arr->nd) + shape = NanoVector::make(arr->dimensions, arr->nd); + else + shape.push_back(1); + dtype = get_type_str(arr); + CHECK(is_c_style(arr)); + + VarPtr vp(shape, dtype); + vp->finish_pending_liveness(); + vp->mem_ptr = arr->data; + + Allocation allocation; + make_foreign_allocation(allocation, + vp->mem_ptr, vp->size, + [obj]() { + Py_DECREF(obj); + }); + Py_INCREF(obj); + vp->allocator = allocation.allocator; + vp->allocation = allocation.allocation; + allocation.ptr = nullptr; + allocation.allocator = nullptr; + allocation.allocation = 0; + + return new VarHolder(std::move(vp)); +} + +ArrayOp::ArrayOp(PyObject* obj) { + ArrayArgs args; + PyObjHolder holder; + args.ptr = nullptr; + allocation.ptr = nullptr; + void* ori_ptr = nullptr; + if (PyFloat_CheckExact(obj)) { + tmp_data.f32 = PyFloat_AS_DOUBLE(obj); + args = {&tmp_data, 1, ns_float32}; + } else + if (PyLong_CheckExact(obj)) { + tmp_data.i32 = PyLong_AsLong(obj); + args = {&tmp_data, 1, ns_int32}; + } else + if (PyBool_Check(obj)) { + tmp_data.i8 = obj == Py_True; + args = {&tmp_data, 1, ns_bool}; + } else + if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) { + auto ptr = GET_RAW_PTR(VarHolder, obj); + args = move(fetch_sync({ptr}).at(0)); + } else + if (Py_TYPE(obj) == PyArray_Type || + PyList_CheckExact(obj) || PyTuple_CheckExact(obj) || + PyObject_TypeCheck(obj, PyNumberArrType_Type) + ) { + if (Py_TYPE(obj) != PyArray_Type) { + holder.assign(PyArray_FROM_O(obj)); + obj = holder.obj; + } + auto arr = (PyArray_Proxy*)obj; + if (arr->nd) + args.shape = NanoVector::make(arr->dimensions, arr->nd); + else + args.shape.push_back(1); + args.dtype = get_type_str(arr); + if (is_c_style(arr)) { + args.ptr = arr->data; + ori_ptr = arr->data; + } + + // use 32-bit by default + if ((auto_convert_64_to_32 || holder.obj) + && args.dtype.dsize() == 8 && args.ptr) { + auto size = PyArray_Size(arr); + args.buffer.reset(new char[size]); + auto pre_data = args.ptr; + args.ptr = args.buffer.get(); + auto num = size/8; + if (args.dtype.is_int()) { + auto* __restrict__ i64 = (int64*)pre_data; + auto* __restrict__ i32 = (int32*)args.ptr; + for (int i=0; i> Py_TYPE(obj)->tp_name >> "> not support for jittor array"; + } + NanoVector shape = args.shape; + output = create_output(shape, args.dtype); + int64 size = output->size; + if (shape.size() == 1 && shape[0] == 1) { + output->flags.set(NodeFlags::_force_fuse); + output->flags.set(NodeFlags::_is_scalar); + set_type(OpType::element); + } + void* host_ptr = nullptr; + #ifdef HAS_CUDA + if (use_cuda && !save_mem && !use_cuda_host_allocator) { + flags.set(NodeFlags::_cpu, 0); + flags.set(NodeFlags::_cuda, 1); + if (!output->flags.get(NodeFlags::_force_fuse)) { + // free prev allocation first + event_queue.flush(); + // alloc new allocation + auto size = output->size; + new (&allocation) Allocation(&cuda_dual_allocator, size); + host_ptr = cuda_dual_allocator.get_dual_allocation(allocation.allocation).host_ptr; + } + } + #endif + if (!host_ptr) { + new (&allocation) Allocation(cpu_allocator, output->size); + host_ptr = allocation.ptr; + } + + if (args.ptr) { + // if has ptr, copy from ptr + if (reuse_array && !use_cuda && args.ptr == ori_ptr) { + allocation.~Allocation(); + make_foreign_allocation(allocation, + ori_ptr, output->size, + [obj]() { + Py_DECREF(obj); + }); + Py_INCREF(obj); + } else { + std::memcpy(host_ptr, args.ptr, size); + } + } else { + // this is non-continue numpy array +#if defined(__linux__) || defined(_WIN32) + STACK_ALLOC(int64_t, dims, args.shape.size()); +#elif defined(__APPLE__) + long dims[args.shape.size()]; +#endif + for (int i=0; i. +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "pyjt/py_obj_holder.h" +#include "pyjt/py_converter.h" +#include "pyjt/py_caller.h" + +namespace jittor { + +string py_caller(const string& mod_func, const vector& args, const map& kw) { + PyObjHolder mod(PyImport_ImportModule("jittor")); + PyObjHolder func(PyObject_GetAttrString(mod.obj, "python_pass_wrapper")); + PyObjHolder py_name(to_py_object(mod_func)); + PyObjHolder py_args(to_py_tuple(args)); + PyObjHolder py_kw(to_py_object(kw)); + PyObjHolder ret(PyObject_CallFunctionObjArgs(func.obj, py_name.obj, py_args.obj, py_kw.obj, nullptr)); + CHECK(is_type(ret.obj)) << "expect return type string."; + return from_py_object(ret.obj); +} + +} diff --git a/python/jittor/src/pyjt/py_caller.h b/python/jittor/src/pyjt/py_caller.h new file mode 100644 index 00000000..bcadc4e8 --- /dev/null +++ b/python/jittor/src/pyjt/py_caller.h @@ -0,0 +1,16 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +string py_caller(const string& mod_func, const vector& args, const map& kw); + +} diff --git a/python/jittor/src/pyjt/py_converter.h b/python/jittor/src/pyjt/py_converter.h new file mode 100644 index 00000000..5101b411 --- /dev/null +++ b/python/jittor/src/pyjt/py_converter.h @@ -0,0 +1,919 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guowei Yang <471184555@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "pyjt/py_obj_holder.h" +#include "pyjt/numpy.h" +#include "common.h" +#include "misc/hash.h" +#include "misc/nano_string.h" +#include "misc/fast_shared_ptr.h" +#include "profiler/simple_profiler.h" +#ifdef IS_CUDA +#include "misc/cuda_flags.h" +#endif + +namespace jittor { + +template +struct vector_to_tuple { + typedef T value_type; + vector x; + vector_to_tuple(vector&& _) :x(move(_)) {} +}; + +#define DEF_IS(check_type, return_type) \ + template \ + typename std::enable_if::value, return_type>::type + +#define GET_PY_NONE(code) ((code), Py_INCREF(Py_None), Py_None) + +// string +DEF_IS(string, bool) is_type(PyObject* obj) { + return PyUnicode_CheckExact(obj); +} + +DEF_IS(string, PyObject*) to_py_object(const string& a) { + return PyUnicode_FromStringAndSize(a.c_str(), a.size()); +} + +DEF_IS(string, string) from_py_object(PyObject* obj) { + Py_ssize_t size; + const char* s = PyUnicode_AsUTF8AndSize(obj, &size); + CHECK(s); + return string(s, size); +} + +// bool +DEF_IS(bool, bool) is_type(PyObject* obj) { + return PyBool_Check(obj) || PyLong_CheckExact(obj); +} + +DEF_IS(bool, PyObject*) to_py_object(const T& a) { + if (a) Py_RETURN_TRUE; + Py_RETURN_FALSE; +} + +DEF_IS(bool, T) from_py_object(PyObject* obj) { + if (PyBool_Check(obj)) + return obj == Py_True; + return PyLong_AsLong(obj); +} + +// int +DEF_IS(int, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(int, PyObject*) to_py_object(const T& a) { + return PyLong_FromLong(a); +} + +DEF_IS(int, T) from_py_object(PyObject* obj) { + return PyLong_AsLong(obj); +} + +// size_t +DEF_IS(size_t, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(size_t, PyObject*) to_py_object(const T& a) { + return PyLong_FromUnsignedLongLong(a); +} + +DEF_IS(size_t, T) from_py_object(PyObject* obj) { + return PyLong_AsUnsignedLongLong(obj); +} + +// int64 +DEF_IS(int64, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(int64, PyObject*) to_py_object(const T& a) { + return PyLong_FromLongLong(a); +} + +DEF_IS(int64, T) from_py_object(PyObject* obj) { + return PyLong_AsLongLong(obj); +} + +#ifdef __linux__ +// int64_t +DEF_IS(int64_t, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(int64_t, PyObject*) to_py_object(const T& a) { + return PyLong_FromLongLong(a); +} + +DEF_IS(int64_t, T) from_py_object(PyObject* obj) { + return PyLong_AsLongLong(obj); +} +#endif + +#ifdef __APPLE__ +// uint64 +DEF_IS(uint64, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(uint64, PyObject*) to_py_object(const T& a) { + return PyLong_FromUnsignedLongLong(a); +} + +DEF_IS(uint64, T) from_py_object(PyObject* obj) { + return PyLong_AsUnsignedLongLong(obj); +} +#endif + +// float64 +DEF_IS(float64, bool) is_type(PyObject* obj) { + return PyFloat_CheckExact(obj) || PyLong_CheckExact(obj); +} + +DEF_IS(float64, PyObject*) to_py_object(const T& a) { + return PyFloat_FromDouble(a); +} + +DEF_IS(float64, T) from_py_object(PyObject* obj) { + if (PyFloat_CheckExact(obj)) + return PyFloat_AS_DOUBLE(obj); + return PyLong_AsDouble(obj); +} + +struct Slice; +// Slice +DEF_IS(Slice, bool) is_type(PyObject* obj) { + return PySlice_Check(obj); +} +DEF_IS(Slice, T) from_py_object(PyObject* obj) { + Py_ssize_t start, stop, step; + auto slice = (PySliceObject*)obj; + + PySlice_Unpack(obj, &start, &stop, &step); + return {start, stop, step, + (slice->start == Py_None) | + ((slice->stop == Py_None) << 1) | + ((slice->step == Py_None) << 2)}; +} + +#define GET_RAW_PTR(T, obj) ((T*)(((char*)obj) + sizeof(PyObject))) +#define GET_OBJ_FROM_RAW_PTR(obj) ((PyObject*)(((char*)obj) - sizeof(PyObject))) +#define GET_OBJ_SIZE(T) (sizeof(PyObject)+sizeof(T)) + +// DumpGraphs +struct DumpGraphs; +EXTERN_LIB PyTypeObject PyjtDumpGraphs; +DEF_IS(DumpGraphs, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtDumpGraphs; +} + + +DEF_IS(DumpGraphs, PyObject*) to_py_object(T&& a) { + PyObjHolder obj(_PyObject_New(&PyjtDumpGraphs)); + auto ptr = GET_RAW_PTR(T, obj.obj); + new (ptr) T(); + ptr->hold_vars = std::move(a.hold_vars); + ptr->nodes_info = std::move(a.nodes_info); + ptr->inputs = std::move(a.inputs); + ptr->outputs = std::move(a.outputs); + return obj.release(); +} + +DEF_IS(DumpGraphs, const T&) from_py_object(PyObject* obj) { + return GET_RAW_PTR(T, obj); +} + +// MemInfo +struct MemInfo; +EXTERN_LIB PyTypeObject PyjtMemInfo; +DEF_IS(MemInfo, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtMemInfo; +} + + +DEF_IS(MemInfo, PyObject*) to_py_object(const T& a) { + PyObjHolder obj(_PyObject_New(&PyjtMemInfo)); + auto ptr = GET_RAW_PTR(T, obj.obj); + new (ptr) T(a); + return obj.release(); +} + +DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) { + return GET_RAW_PTR(T, obj); +} + +// MemInfo +struct ZipFile; +EXTERN_LIB PyTypeObject PyjtZipFile; +DEF_IS(ZipFile, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtZipFile; +} + + +DEF_IS(ZipFile, PyObject*) to_py_object(const T& a) { + PyObjHolder obj(_PyObject_New(&PyjtZipFile)); + auto ptr = GET_RAW_PTR(T, obj.obj); + new (ptr) T(a); + return obj.release(); +} + +DEF_IS(ZipFile, const T&) from_py_object(PyObject* obj) { + return GET_RAW_PTR(T, obj); +} + + +// NanoString +struct NanoString; +EXTERN_LIB PyTypeObject PyjtNanoString; +DEF_IS(NanoString, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtNanoString || + PyUnicode_CheckExact(obj) || + PyType_CheckExact(obj) || + // jt.float.__name__ + PyCallable_Check(obj) || + // numpy.dtype.type + PyObject_HasAttrString(obj, "type"); +} + +DEF_IS(NanoString, PyObject*) to_py_object(T a) { + PyObjHolder obj(_PyObject_New(&PyjtNanoString)); + auto ptr = GET_RAW_PTR(T, obj.obj); + new (ptr) T(a); + return obj.release(); +} + +DEF_IS(NanoString, T) from_py_object(PyObject* obj) { + if (Py_TYPE(obj) == &PyjtNanoString) + return *GET_RAW_PTR(T, obj); + if (PyUnicode_CheckExact(obj)) + return T(PyUnicode_AsUTF8(obj)); + // PyType + if (PyType_CheckExact(obj)) + return T(_PyType_Name((PyTypeObject *)obj)); + // jt.float.__name__ + if (PyCallable_Check(obj)) { + PyObjHolder t(PyObject_GetAttrString(obj, "__name__")); + return T(PyUnicode_AsUTF8(t.obj)); + } + PyObjHolder t(PyObject_GetAttrString(obj, "type")); + CHECK(PyType_CheckExact(t.obj)) << "Not a valid type:" << t.obj; + return T(_PyType_Name((PyTypeObject *)t.obj)); +} + +// NanoVector +struct NanoVector; +EXTERN_LIB PyTypeObject PyjtNanoVector; +DEF_IS(NanoVector, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtNanoVector || + PyList_CheckExact(obj) || PyTuple_CheckExact(obj); +} +DEF_IS(NanoVector*, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtNanoVector; +} + +DEF_IS(NanoVector, PyObject*) to_py_object(T a) { + PyObjHolder obj(_PyObject_New(&PyjtNanoVector)); + auto ptr = GET_RAW_PTR(T, obj.obj); + new (ptr) T(a); + return obj.release(); +} + +DEF_IS(NanoVector*, T) from_py_object(PyObject* obj) { + return GET_RAW_PTR(typename std::remove_pointer::type, obj); +} + +DEF_IS(NanoVector, T) from_py_object(PyObject* obj) { + if (Py_TYPE(obj) == &PyjtNanoVector) + return *GET_RAW_PTR(T, obj); + auto size = Py_SIZE(obj); + T a; + auto arr = PySequence_Fast_ITEMS(obj); + for (int64 i=0; i(oi)); + a.push_back_check_overflow(from_py_object(oi)); + } + return a; +} + +// ArrayArgs +struct ArrayArgs; +struct VarHolder; +vector fetch_sync(const vector& vh); +EXTERN_LIB PyHeapTypeObject PyjtVarHolder; +DEF_IS(ArrayArgs, bool) is_type(PyObject* obj) { + return + Py_TYPE(obj) == &PyjtVarHolder.ht_type || + Py_TYPE(obj) == PyArray_Type || + PyFloat_CheckExact(obj) || + PyLong_CheckExact(obj) || + PyBool_Check(obj) || + PyList_CheckExact(obj) || + PyObject_TypeCheck(obj, PyNumberArrType_Type); +} + +DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) { +#if defined(__linux__) || defined(_WIN32) + STACK_ALLOC(int64_t, dims, a.shape.size()); +#elif defined(__APPLE__) + long dims[a.shape.size()]; +#endif + for (int i=0; idata; + int64 num = size/4; + for (int64 i=0; idata, (void*)a.ptr, size); + } + return obj.release(); +} + +DEF_IS(ArrayArgs, T) from_py_object(PyObject* obj) { + if (PyFloat_CheckExact(obj)) { + tmp_data.f32 = PyFloat_AS_DOUBLE(obj); + return {&tmp_data, 1, ns_float32}; + } + if (PyLong_CheckExact(obj)) { + tmp_data.i32 = PyLong_AsLong(obj); + return {&tmp_data, 1, ns_int32}; + } + if (PyBool_Check(obj)) { + tmp_data.i8 = obj == Py_True; + return {&tmp_data, 1, ns_bool}; + } + if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) { + auto ptr = GET_RAW_PTR(VarHolder, obj); + return move(fetch_sync({ptr}).at(0)); + } + // PyArray_Type + auto arr = (PyArray_Proxy*)obj; + if (Py_TYPE(obj) != PyArray_Type || !is_c_style(arr)) { + PyObjHolder holder( + Py_TYPE(obj) != PyArray_Type ? + PyArray_FROM_O(obj) : + PyArray_Copy(obj)); + auto arr = (PyArray_Proxy*)holder.obj; + int64 size = PyArray_Size(arr); + T args; + if (arr->nd) + args.shape = NanoVector::make(arr->dimensions, arr->nd); + else + args.shape.push_back(1); + args.dtype = get_type_str(arr); + args.buffer.reset(new char[size]); + args.ptr = (void*)args.buffer.get(); + memcpy((void*)args.buffer.get(), (void*)arr->data, size); + if (Py_TYPE(obj) != PyArray_Type && args.dtype.dsize()==8) { + // convert to 32bit + auto num = size/8; + if (args.dtype.is_int()) { + auto* __restrict__ i64 = (int64*)args.ptr; + auto* __restrict__ i32 = (int32*)args.ptr; + for (int i=0; idata; + if (arr->dimensions) + for (int i=0; ind; i++) + args.shape.push_back(arr->dimensions[i]); + else + args.shape = 1; + args.dtype = get_type_str(arr); + return args; +} + +// VarHolder +struct VarHolder; +EXTERN_LIB PyHeapTypeObject PyjtVarHolder; +namespace jit_op_maker { +EXTERN_LIB VarHolder* array_(ArrayArgs&&); +EXTERN_LIB VarHolder* array__(PyObject* obj); +} +DEF_IS(VarHolder*, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtVarHolder.ht_type || + is_type(obj); +} + +DEF_IS(VarHolder*, PyObject*) to_py_object(T a) { + PyObjHolder obj(_PyObject_New(&PyjtVarHolder.ht_type)); + auto ptr = GET_RAW_PTR(T, obj.obj); + ((PyObject**)(((char*)obj.obj) + sizeof(PyObject) + sizeof(typename std::remove_pointer::type)))[0] = PyDict_New(); + // new attr_dict + // will move and delete a + new (ptr) typename std::remove_pointer::type (a); + return obj.release(); +} + + +DEF_IS(VarHolder*, T) from_py_object(PyObject* obj) { + CHECK(Py_TYPE(obj) == &PyjtVarHolder.ht_type); + return GET_RAW_PTR(VarHolder, obj); +} + +DEF_IS(VarHolder*, T) from_py_object(PyObject* obj, unique_ptr& holder) { + if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) + return GET_RAW_PTR(VarHolder, obj); + holder.reset(jit_op_maker::array__(obj)); + return holder.get(); +} + +struct DataView; +DEF_IS(DataView, PyObject*) to_py_object(T a) { +#if defined(__linux__) || defined(_WIN32) + STACK_ALLOC(int64_t, dims, a.shape.size()); +#elif defined(__APPLE__) + long dims[a.shape.size()]; +#endif + for (int i=0; i struct is_##check_type : public std::false_type {}; \ + template \ + struct is_##check_type> : public std::true_type {}; + +#define DEF_IS_1(check_type, return_type) \ + template \ + typename std::enable_if::value, return_type>::type + + +#define CHECK_IS_2(check_type) \ + template struct is_##check_type : public std::false_type {}; \ + template \ + struct is_##check_type> : public std::true_type {}; + +#define DEF_IS_2(check_type, return_type) \ + template \ + typename std::enable_if::value, return_type>::type + +CHECK_IS_1(vector); +CHECK_IS_1(vector_to_tuple); + +CHECK_IS_2(map); +DEF_IS_2(map, bool) is_type(PyObject* obj); +DEF_IS_2(map, PyObject*) to_py_object(const T& a); + +DEF_IS_1(vector, bool) is_type(PyObject* obj) { + if (!(PyList_CheckExact(obj) || PyTuple_CheckExact(obj))) + return false; + auto size = Py_SIZE(obj); + if (!size) + return true; + auto arr = PySequence_Fast_ITEMS(obj); + return is_type(arr[0]); +} + +DEF_IS_1(vector, PyObject*) to_py_object(const T& a) { + PyObjHolder list(PyList_New(a.size())); + for (uint i=0; i(a[i]); + CHECK(o); + // PyList_SET_ITEM borrow ownership, we do not hold this + PyList_SET_ITEM(list.obj, i, o); + } + return list.release(); +} + +DEF_IS_1(vector, PyObject*) to_py_tuple(const T& a) { + PyObjHolder list(PyTuple_New(a.size())); + for (uint i=0; i(a[i]); + CHECK(o); + // PyTuple_SET_ITEM borrow ownership, we do not hold this + PyTuple_SET_ITEM(list.obj, i, o); + } + return list.release(); +} + +DEF_IS_1(vector_to_tuple, PyObject*) to_py_object(const T& a) { + PyObjHolder list(PyTuple_New(a.x.size())); + for (uint i=0; i(a.x[i]); + CHECK(o); + // PyTuple_SET_ITEM borrow ownership, we do not hold this + PyTuple_SET_ITEM(list.obj, i, o); + } + return list.release(); +} + +DEF_IS_1(vector, PyObject*) to_py_object(T&& a) { + PyObjHolder list(PyList_New(a.size())); + for (uint i=0; i(std::move(a[i])); + CHECK(o); + // PyList_SET_ITEM borrow ownership, we do not hold this + PyList_SET_ITEM(list.obj, i, o); + } + return list.release(); +} + +DEF_IS_1(vector, T) from_py_object(PyObject* obj) { + auto size = Py_SIZE(obj); + T a(size); + auto arr = PySequence_Fast_ITEMS(obj); + for (int64 i=0; i(oi)); + a[i] = from_py_object(oi); + } + return a; +} + +struct FetchFunc; + +DEF_IS(FetchFunc, bool) is_type(PyObject* obj) { + return PyCallable_Check(obj); +} + +DEF_IS(FetchFunc, T) from_py_object(PyObject* obj) { + // PyObject_Call + Py_INCREF(obj); + T func( + // callback + [obj](typename T::R* result) { + PyObjHolder arrays(to_py_tuple>(result->arrays)); + PyObjHolder ret(PyObject_Call(obj, arrays.obj, nullptr)); + }, + // deleter + [obj]() { Py_DECREF(obj); } + ); + return func; +} + +struct SimpleFunc; + +DEF_IS(SimpleFunc, bool) is_type(PyObject* obj) { + return PyCallable_Check(obj); +} + +DEF_IS(SimpleFunc, T) from_py_object(PyObject* obj) { + // PyObject_Call + Py_INCREF(obj); + T func( + // callback + [obj](int64 result) { + // check python version macro >= 3.9 + #if PY_VERSION_HEX >= 0x03090000 + PyObjHolder args(to_py_object(result)); + PyObjHolder ret(PyObject_CallOneArg(obj, args.obj)); + #else + LOGf << "Not supported python version"; + #endif + }, + // deleter + [obj]() { Py_DECREF(obj); } + ); + return func; +} + +CHECK_IS_2(unordered_map); + +DEF_IS_2(unordered_map, bool) is_type(PyObject* obj) { + return PyDict_CheckExact(obj); +} + +DEF_IS_2(unordered_map, PyObject*) to_py_object(const T& a) { + PyObjHolder dict(PyDict_New()); + for (const auto& kv : a) { + PyObjHolder key(to_py_object(kv.first)); + PyObjHolder value(to_py_object(kv.second)); + PyDict_SetItem(dict.obj, key.obj, value.obj); + } + return dict.release(); +} + +DEF_IS_2(unordered_map, T) from_py_object(PyObject* obj) { + auto size = Py_SIZE(obj); + T a; + a.reserve(size); + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(obj, &pos, &key, &value)) { + CHECK(is_type(key) + && is_type(value)); + a.emplace( + from_py_object(key), + from_py_object(value) + ); + } + return a; +} + +// copy from unordered_map +// CHECK_IS_2(map); + +DEF_IS_2(map, bool) is_type(PyObject* obj) { + return PyDict_CheckExact(obj); +} + +DEF_IS_2(map, PyObject*) to_py_object(const T& a) { + PyObjHolder dict(PyDict_New()); + for (const auto& kv : a) { + PyObjHolder key(to_py_object(kv.first)); + PyObjHolder value(to_py_object(kv.second)); + PyDict_SetItem(dict.obj, key.obj, value.obj); + } + return dict.release(); +} + +DEF_IS_2(map, T) from_py_object(PyObject* obj) { + T a; + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(obj, &pos, &key, &value)) { + CHECK(is_type(key) + && is_type(value)); + a.emplace( + from_py_object(key), + from_py_object(value) + ); + } + return a; +} + + +CHECK_IS_1(fast_shared_ptr); + +DEF_IS_1(fast_shared_ptr, bool) is_type(PyObject* obj) { + return is_type(obj); +} + +DEF_IS_1(fast_shared_ptr, PyObject*) to_py_object(const T& a) { + if (a) + return to_py_object(a.data()); + return to_py_object(a); +} + +DEF_IS_1(fast_shared_ptr, T) from_py_object(PyObject* obj) { + return from_py_object(obj); +} + +CHECK_IS_1(Maybe); + +DEF_IS_1(Maybe, bool) is_type(PyObject* obj) { + return obj == Py_None || + is_type(obj); +} + +DEF_IS_1(Maybe, PyObject*) to_py_object(T a) { + if (a) + return to_py_object(a.ptr); + Py_INCREF(Py_None); + return Py_None; +} + +DEF_IS_1(Maybe, T) from_py_object(PyObject* obj) { + if (obj == Py_None) return T(); + return T(from_py_object(obj)); +} + +DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) { + // PyObject_Call + Py_INCREF(obj); + T func( + // callback + [obj](typename T::R* result) { + // import numpy + string npstr="numpy"; + #ifdef IS_CUDA + if (use_cuda) npstr="cupy"; + #endif + + PyObjHolder np(PyImport_ImportModule(npstr.data())); + // data = {} + PyObjHolder data(to_py_object(result->varrays)); + PyObjHolder data2(to_py_object(result->ints)); + PyObjHolder data3(to_py_object(result->arrays)); + PyDict_Update(data.obj, data2.obj); + PyDict_Update(data.obj, data3.obj); + + // args = [] + PyObjHolder args(PyTuple_New(2)); + PyTuple_SET_ITEM(args.obj, 0, np.release()); + PyTuple_SET_ITEM(args.obj, 1, data.release()); + + #ifdef IS_CUDA + if (npstr=="cupy") { + PyObjHolder jt(PyImport_ImportModule("jittor")); + PyObjHolder pFunc(PyObject_GetAttrString(jt.obj,"numpy2cupy")); + PyObjHolder ret1(PyObject_Call(pFunc.obj, args.obj, nullptr)); + } + #endif + + PyObjHolder ret2(PyObject_Call(obj, args.obj, nullptr)); + }, + // deleter + [obj]() { Py_DECREF(obj); }, + // inc_ref + [obj]() { Py_INCREF(obj); } + ); + return func; +} + + +struct GradCallback; + +DEF_IS(GradCallback, bool) is_type(PyObject* obj) { + return PyCallable_Check(obj); +} + +DEF_IS(GradCallback, T) from_py_object(PyObject* obj) { + // PyObject_Call + Py_INCREF(obj); + T func( + // callback + [obj](int n_o, typename T::Var** douts, int n_i, typename T::VarPtr* dins) { + PyObjHolder list(PyTuple_New(n_o)); + for (int i=0; itp_name<<") is not jittor variable"; + auto vh = from_py_object(obj); + dins[i] = vh->var; + } + }; + if (!is_seq) { + CHECKop(n_i,==,1) << n_i >> " returned grad required, but 1 given."; + check(0, ret.obj); + } else { + auto size = Py_SIZE(ret.obj); + CHECKop(n_i,==,size) << n_i >> " returned grad required, but " >> size >> " given."; + auto arr = PySequence_Fast_ITEMS(ret.obj); + for (int i=0; i(obj); +} + +template +void load_var_slice(PyObject* obj, T* var_slice, vector>& holders) { + if (PyLong_CheckExact(obj)) { + var_slice->set_int(PyLong_AsLong(obj)); + } else + if (PySlice_Check(obj)) { + var_slice->slice = from_py_objectslice)>(obj); + } else + if (Py_TYPE(obj) == &PyEllipsis_Type) { + var_slice->set_ellipsis(); + } else + if (PyUnicode_CheckExact(obj)) { + var_slice->set_str(from_py_object(obj)); + } else + if (obj == Py_None) { + var_slice->set_none(); + } else + if (PyObject_TypeCheck(obj, PyNumberArrType_Type)) { + PyArrayDescr_Proxy array_descr; + array_descr.type_num = 5; // 5: int32 + int value; + PyArray_CastScalarToCtype(obj, &value, &array_descr); + var_slice->set_int(value); + } else { + holders.emplace_back(); + auto* vh = from_py_object(obj, holders.back()); + auto vv = (decltype(var_slice->var)*)vh; + CHECK(vv[0]->dtype() != ns_bool) << "Please convert bool slice into jt.array, example:\n" + "a[[True,False,False]] ---> a[jt.array([True,False,False])"; + var_slice->set_var(vv[0]); + } +} + +DEF_IS(VarSlices, T) from_py_object(PyObject* obj, vector>& holders) { + if (PyTuple_CheckExact(obj)) { + auto size = Py_SIZE(obj); + T vs(size); + auto arr = PySequence_Fast_ITEMS(obj); + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "common.h" + +namespace jittor { + +struct PyObjHolder { + PyObject* obj; + inline PyObjHolder() : obj(nullptr) { + } + inline void assign(PyObject* obj) { + if (!obj) { + LOGf << "Python error occur"; + } + this->obj = obj; + } + inline PyObjHolder(PyObject* obj) : obj(obj) { + if (!obj) { + LOGf << "Python error occur"; + } + } + inline void assign(PyObject* obj, const char* err_msg) { + if (!obj) { + LOGf << err_msg; + } + this->obj = obj; + } + inline PyObjHolder(PyObject* obj, const char* err_msg) : obj(obj) { + if (!obj) { + LOGf << err_msg; + } + } + inline ~PyObjHolder() { + if (obj) Py_DECREF(obj); + } + inline PyObject* release() { + auto tmp = obj; + obj = nullptr; + return tmp; + } +}; + + +inline Log& operator<<(Log& os, PyObject* objp) { + PyObjHolder repr_obj(PyObject_Repr(objp)); + + if (PyUnicode_CheckExact(repr_obj.obj)) { + return os << Py_TYPE(objp)->tp_name << + PyUnicode_AsUTF8(repr_obj.obj); + } else { + return os << "unknown(" >> (void*)objp >> ")"; + } +} + +} + +#define PYJT_MODULE_INIT(name) \ +PyMODINIT_FUNC PyInit_##name() { \ + PyObject *m; \ + try { \ + PyModuleDef *def = new PyModuleDef(); \ + memset(def, 0, sizeof(PyModuleDef)); \ + def->m_name = #name; \ + def->m_doc = ""; \ + def->m_size = -1; \ + Py_INCREF(def); \ + jittor::PyObjHolder holder(m = PyModule_Create(def)); \ + init_module(def, m); \ + holder.release(); \ + } catch(const std::exception& e) { \ + PyErr_SetString(PyExc_RuntimeError, e.what()); \ + return nullptr; \ + } \ + return m; \ +} + diff --git a/python/jittor/src/pyjt/py_ring_buffer.cc b/python/jittor/src/pyjt/py_ring_buffer.cc new file mode 100644 index 00000000..e7354788 --- /dev/null +++ b/python/jittor/src/pyjt/py_ring_buffer.cc @@ -0,0 +1,245 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "pyjt/py_ring_buffer.h" +#include "pyjt/py_obj_holder.h" +#include "pyjt/py_converter.h" +#include "ops/array_op.h" +#include "var_holder.h" + +namespace jittor { + +static void push_py_object_pickle(RingBuffer* rb, PyObject* obj, uint64& __restrict__ offset) { + PyObjHolder pickle(PyImport_ImportModule("pickle")); + PyObjHolder dumps(PyObject_GetAttrString(pickle.obj, "dumps")); + PyObjHolder proto(PyObject_GetAttrString(pickle.obj, "HIGHEST_PROTOCOL")); + rb->push_t(6, offset); + PyObjHolder ret(PyObject_CallFunctionObjArgs(dumps.obj, obj, proto.obj, nullptr)); + obj = ret.obj; + Py_ssize_t size; + char* s; + ASSERT(0 == PyBytes_AsStringAndSize(ret.obj, &s, &size)); + rb->push_t(size, offset); + rb->push(size, offset); + // LOGir << string(rb->get_ptr(size, offset), size); + std::memcpy(rb->get_ptr(size, offset), s, size); + return; +} + +static PyObject* pop_py_object_pickle(RingBuffer* rb, uint64& __restrict__ offset) { + PyObjHolder pickle(PyImport_ImportModule("pickle")); + PyObjHolder loads(PyObject_GetAttrString(pickle.obj, "loads")); + + auto size = rb->pop_t(offset); + rb->pop(size, offset); + PyObjHolder s(PyBytes_FromStringAndSize(rb->get_ptr(size, offset), size)); + + PyObjHolder ret(PyObject_CallFunctionObjArgs(loads.obj, s.obj, nullptr)); + return ret.release(); +} + + +static void push_py_object(RingBuffer* rb, PyObject* obj, uint64& __restrict__ offset) { + if (PyLong_CheckExact(obj)) { + int64 x = PyLong_AsLongLong(obj); + rb->push_t(0, offset); + rb->push_t(x, offset); + return; + } + if (PyFloat_CheckExact(obj)) { + float64 x = PyFloat_AS_DOUBLE(obj); + rb->push_t(1, offset); + rb->push_t(x, offset); + return; + } + if (PyUnicode_CheckExact(obj)) { + Py_ssize_t size; + const char* s = PyUnicode_AsUTF8AndSize(obj, &size); + rb->push_t(2, offset); + rb->push_t(size, offset); + rb->push(size, offset); + std::memcpy(rb->get_ptr(size, offset), s, size); + return; + } + if (PyList_CheckExact(obj) || PyTuple_CheckExact(obj)) { + rb->push_t(3, offset); + auto size = Py_SIZE(obj); + auto arr = PySequence_Fast_ITEMS(obj); + rb->push_t(size, offset); + for (int64 i=0; ipush_t(4, offset); + auto size = Py_SIZE(obj); + rb->push_t(size, offset); + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(obj, &pos, &key, &value)) { + push_py_object(rb, key, offset); + push_py_object(rb, value, offset); + } + return; + } + if (Py_TYPE(obj) == &PyjtVarHolder.ht_type || + Py_TYPE(obj) == PyArray_Type) { + ArrayArgs args; + int64 size=0; + uint8 protocol = Py_TYPE(obj) == PyArray_Type ? 5 : 7; + rb->push_t(protocol, offset); + if (Py_TYPE(obj) == &PyjtVarHolder.ht_type) { + auto ptr = GET_RAW_PTR(VarHolder, obj); + args = move(fetch_sync({ptr}).at(0)); + size = ptr->var->size; + } else { + auto arr = (PyArray_Proxy*)obj; + if (arr->nd) + args.shape = NanoVector::make(arr->dimensions, arr->nd); + else + args.shape.push_back(1); + args.dtype = get_type_str(arr); + size = PyArray_Size(arr); + if (!is_c_style(arr)) { + rb->push_t(args.shape, offset); + rb->push_t(args.dtype, offset); + rb->push(size, offset); + args.ptr = rb->get_ptr(size, offset); +#if defined(__linux__) || defined(_WIN32) + STACK_ALLOC(int64_t, dims, args.shape.size()); +#elif defined(__APPLE__) + long dims[args.shape.size()]; +#endif + for (int i=0; idata; + } + } + rb->push_t(args.shape, offset); + rb->push_t(args.dtype, offset); + rb->push(size, offset); + std::memcpy(rb->get_ptr(size, offset), args.ptr, size); + return; + } + push_py_object_pickle(rb, obj, offset); +} + + +static PyObject* to_py_object3(ArrayArgs&& a) { + return to_py_object(jit_op_maker::array_(move(a))); +} + +static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset, bool keep_numpy_array) { + auto t = rb->pop_t(offset); + if (t==0) { + auto x = rb->pop_t(offset); + return PyLong_FromLongLong(x); + } + if (t==1) { + auto x = rb->pop_t(offset); + return PyFloat_FromDouble(x); + } + if (t==2) { + auto size = rb->pop_t(offset); + rb->pop(size, offset); + return PyUnicode_FromStringAndSize(rb->get_ptr(size, offset), size); + } + if (t==3) { + auto size = rb->pop_t(offset); + PyObjHolder list(PyList_New(size)); + for (uint i=0; ipop_t(offset); + PyObjHolder dict(PyDict_New()); + for (int64 i=0; ipop_t(offset); + args.dtype = rb->pop_t(offset); + int64 size = args.dtype.dsize(); + for (int i=0; ipop(size, offset); + args.ptr = rb->get_ptr(size, offset); + if (!keep_numpy_array || t==7) + // become jittor var + return to_py_object3(move(args)); + else + return to_py_object(args); + } + if (t==6) { + return pop_py_object_pickle(rb, offset); + } + if (t == 255) { + LOGf << "WorkerError:" << rb->pop_string(offset); + } else + LOGf << "unsupport type:" << (int)t; + return nullptr; +} + +void PyMultiprocessRingBuffer::push(PyObject* obj) { + auto offset = rb->r; + auto offset_bk = offset; + try { + push_py_object(rb, obj, offset); + } catch (const std::exception& e) { + offset = offset_bk; + rb->push_t(255, offset); + rb->push_string(string(e.what()), offset); + } + rb->commit_push(offset); +} + +PyObject* PyMultiprocessRingBuffer::pop() { + auto offset = rb->l; + auto obj = pop_py_object(rb, offset, _keep_numpy_array); + rb->commit_pop(offset); + return obj; +} + +PyMultiprocessRingBuffer::PyMultiprocessRingBuffer(uint64 size, uint64 buffer, bool init) { + this->buffer = buffer; + this->init = init; + if (buffer) { + auto mobj = (PyObject*)buffer; + auto buf = PyMemoryView_GET_BUFFER(mobj); + buffer = (uint64)buf->buf; + } + rb = RingBuffer::make_ring_buffer(size, 1, buffer, init); +} + +PyMultiprocessRingBuffer::~PyMultiprocessRingBuffer() { + RingBuffer::free_ring_buffer(rb, buffer, init); +} + +} diff --git a/python/jittor/src/pyjt/py_ring_buffer.h b/python/jittor/src/pyjt/py_ring_buffer.h new file mode 100644 index 00000000..f6b7b5d6 --- /dev/null +++ b/python/jittor/src/pyjt/py_ring_buffer.h @@ -0,0 +1,57 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include "misc/ring_buffer.h" + +namespace jittor { + +// @pyjt(RingBuffer) +struct PyMultiprocessRingBuffer { + RingBuffer* rb; + uint64 buffer; + bool _keep_numpy_array = false; + bool init; + // @pyjt(__init__) + PyMultiprocessRingBuffer(uint64 size, uint64 buffer=0, bool init=true); + // @pyjt(__dealloc__) + ~PyMultiprocessRingBuffer(); + // @pyjt(push,send) + void push(PyObject* obj); + // @pyjt(pop,recv) + PyObject* pop(); + // @pyjt(clear) + inline void clear() { rb->clear(); } + // @pyjt(keep_numpy_array) + inline void keep_numpy_array(bool keep) { _keep_numpy_array = keep; } + // @pyjt(stop) + inline void stop() { rb->stop(); } + // @pyjt(is_stop) + inline bool is_stop() { return rb->is_stop; } + + // @pyjt(total_pop) + inline uint64 total_pop() { return rb->l; } + // @pyjt(total_push) + inline uint64 total_push() { return rb->r; } + // @pyjt(__repr__) + inline string to_string() { + string s="Buffer(free="; + auto size = rb->size; + auto used = rb->r - rb->l; + s += S(100 - used*100.0/size); + s += "% size="; + s += S(size); + s += ")"; + return s; + } + + // @pyjt(__get__size) + inline uint64 size() { return rb->size; } +}; + + +} diff --git a/python/jittor/src/pyjt/pyjt_console.h b/python/jittor/src/pyjt/pyjt_console.h new file mode 100644 index 00000000..fb239492 --- /dev/null +++ b/python/jittor/src/pyjt/pyjt_console.h @@ -0,0 +1,500 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#ifdef _WIN32 +#include +#else +#include +#endif +#include +#include +#include +#include +#include +#include +#include +#include + +namespace jittor { + +typedef int8_t int8; +typedef int16_t int16; +typedef int int32; +typedef int64_t int64; +typedef uint8_t uint8; +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; +typedef float float32; +typedef double float64; +typedef uint32_t uint; + +using string = std::string; +using std::move; +template using vector = std::vector; +template using list = std::list; +template using set = std::set; +template using shared_ptr = std::shared_ptr; +template using unique_ptr = std::unique_ptr; +template using unordered_set = std::unordered_set; +template using pair = std::pair; +template using map = std::map; +template using unordered_map = std::unordered_map; + +#define JT_CHECK(cond) \ + if (!(cond)) throw std::runtime_error("JT_CHECK failed: " #cond " "); + +struct PyObjHolder { + +PyObject* obj; +inline PyObjHolder() : obj(nullptr) { +} +inline void assign(PyObject* obj) { + if (!obj) { + PyErr_Print(); + throw std::runtime_error("Python Error Occurred."); + } + this->obj = obj; +} +inline PyObjHolder(PyObject* obj) : obj(obj) { + if (!obj) { + PyErr_Print(); + throw std::runtime_error("Python Error Occurred."); + } +} +inline ~PyObjHolder() { + if (obj) Py_DECREF(obj); +} +inline PyObject* release() { + auto tmp = obj; + obj = nullptr; + return tmp; +} + +inline void free() { + if (obj) Py_DECREF(obj); + obj = nullptr; +} + +}; + +inline std::ostream& operator<<(std::ostream& os, PyObjHolder& objp) { + PyObjHolder repr_obj(PyObject_Repr(objp.obj)); + + if (PyUnicode_CheckExact(repr_obj.obj)) { + return os << Py_TYPE(objp.obj)->tp_name << ' ' << + PyUnicode_AsUTF8(repr_obj.obj); + } else { + return os << "unknown(" << (void*)objp.obj << ")"; + } +} + + +#define DEF_IS(check_type, return_type) \ + template \ + typename std::enable_if::value, return_type>::type + +#define GET_PY_NONE(code) ((code), Py_INCREF(Py_None), Py_None) + +// string +DEF_IS(string, bool) is_type(PyObject* obj) { + return PyUnicode_CheckExact(obj); +} + +DEF_IS(string, PyObject*) to_py_object(const string& a) { + return PyUnicode_FromStringAndSize(a.c_str(), a.size()); +} + +DEF_IS(string, string) from_py_object(PyObject* obj) { + Py_ssize_t size; + const char* s = PyUnicode_AsUTF8AndSize(obj, &size); + JT_CHECK(s); + return string(s, size); +} + + +// size_t +DEF_IS(size_t, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(size_t, PyObject*) to_py_object(const T& a) { + return PyLong_FromUnsignedLongLong(a); +} + +DEF_IS(size_t, T) from_py_object(PyObject* obj) { + return PyLong_AsUnsignedLongLong(obj); +} + +DEF_IS(uint32, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(uint32, PyObject*) to_py_object(const T& a) { + return PyLong_FromUnsignedLong(a); +} + +DEF_IS(uint32, T) from_py_object(PyObject* obj) { + return PyLong_AsUnsignedLong(obj); +} + +// int64 +DEF_IS(int64, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(int64, PyObject*) to_py_object(const T& a) { + return PyLong_FromLongLong(a); +} + +DEF_IS(int64, T) from_py_object(PyObject* obj) { + return PyLong_AsLongLong(obj); +} +DEF_IS(int32, bool) is_type(PyObject* obj) { + return PyLong_CheckExact(obj); +} + +DEF_IS(int32, PyObject*) to_py_object(const T& a) { + return PyLong_FromLong(a); +} + +DEF_IS(int32, T) from_py_object(PyObject* obj) { + return PyLong_AsLong(obj); +} + +// float64 +DEF_IS(float64, bool) is_type(PyObject* obj) { + return PyFloat_CheckExact(obj) || PyLong_CheckExact(obj); +} + +DEF_IS(float64, PyObject*) to_py_object(const T& a) { + return PyFloat_FromDouble(a); +} + +DEF_IS(float64, T) from_py_object(PyObject* obj) { + if (PyFloat_CheckExact(obj)) + return PyFloat_AS_DOUBLE(obj); + return PyLong_AsDouble(obj); +} +DEF_IS(float32, bool) is_type(PyObject* obj) { + return PyFloat_CheckExact(obj) || PyLong_CheckExact(obj); +} + +DEF_IS(float32, PyObject*) to_py_object(const T& a) { + return PyFloat_FromFloat(a); +} + +DEF_IS(float32, T) from_py_object(PyObject* obj) { + if (PyFloat_CheckExact(obj)) + return PyFloat_AS_DOUBLE(obj); + return PyFloat_AS_DOUBLE(obj); +} + + +#define CHECK_IS_1(check_type) \ + template struct is_##check_type : public std::false_type {}; \ + template \ + struct is_##check_type> : public std::true_type {}; + +#define DEF_IS_1(check_type, return_type) \ + template \ + typename std::enable_if::value, return_type>::type + +CHECK_IS_1(vector); + +DEF_IS_1(vector, bool) is_type(PyObject* obj) { + if (!(PyList_CheckExact(obj) || PyTuple_CheckExact(obj))) + return false; + auto size = Py_SIZE(obj); + if (!size) + return true; + auto arr = PySequence_Fast_ITEMS(obj); + return is_type(arr[0]); +} + +DEF_IS_1(vector, PyObject*) to_py_object(const T& a) { + PyObjHolder list(PyList_New(a.size())); + for (uint i=0; i(a[i]); + JT_CHECK(o); + // PyList_SET_ITEM borrow ownership, we do not hold this + PyList_SET_ITEM(list.obj, i, o); + } + return list.release(); +} + +DEF_IS_1(vector, PyObject*) to_py_tuple(const T& a) { + PyObjHolder list(PyTuple_New(a.size())); + for (uint i=0; i(a[i]); + JT_CHECK(o); + // PyTuple_SET_ITEM borrow ownership, we do not hold this + PyTuple_SET_ITEM(list.obj, i, o); + } + return list.release(); +} + +DEF_IS_1(vector, PyObject*) to_py_object(T&& a) { + PyObjHolder list(PyList_New(a.size())); + for (uint i=0; i(std::move(a[i])); + JT_CHECK(o); + // PyList_SET_ITEM borrow ownership, we do not hold this + PyList_SET_ITEM(list.obj, i, o); + } + return list.release(); +} + +DEF_IS_1(vector, T) from_py_object(PyObject* obj) { + auto size = Py_SIZE(obj); + T a(size); + auto arr = PySequence_Fast_ITEMS(obj); + for (int64 i=0; i(oi)); + a[i] = from_py_object(oi); + } + return a; +} + + +#define CHECK_IS_2(check_type) \ + template struct is_##check_type : public std::false_type {}; \ + template \ + struct is_##check_type> : public std::true_type {}; + +#define DEF_IS_2(check_type, return_type) \ + template \ + typename std::enable_if::value, return_type>::type + +CHECK_IS_2(unordered_map); + +DEF_IS_2(unordered_map, bool) is_type(PyObject* obj) { + return PyDict_CheckExact(obj); +} + +DEF_IS_2(unordered_map, PyObject*) to_py_object(const T& a) { + PyObjHolder dict(PyDict_New()); + for (const auto& kv : a) { + PyObjHolder key(to_py_object(kv.first)); + PyObjHolder value(to_py_object(kv.second)); + PyDict_SetItem(dict.obj, key.obj, value.obj); + } + return dict.release(); +} + +DEF_IS_2(unordered_map, T) from_py_object(PyObject* obj) { + auto size = Py_SIZE(obj); + T a; + a.reserve(size); + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(obj, &pos, &key, &value)) { + JT_CHECK(is_type(key) + && is_type(value)); + a.emplace( + from_py_object(key), + from_py_object(value) + ); + } + return a; +} + +// copy from unordered_map +CHECK_IS_2(map); + +DEF_IS_2(map, bool) is_type(PyObject* obj) { + return PyDict_CheckExact(obj); +} + +DEF_IS_2(map, PyObject*) to_py_object(const T& a) { + PyObjHolder dict(PyDict_New()); + for (const auto& kv : a) { + PyObjHolder key(to_py_object(kv.first)); + PyObjHolder value(to_py_object(kv.second)); + PyDict_SetItem(dict.obj, key.obj, value.obj); + } + return dict.release(); +} + +DEF_IS_2(map, T) from_py_object(PyObject* obj) { + T a; + PyObject *key, *value; + Py_ssize_t pos = 0; + while (PyDict_Next(obj, &pos, &key, &value)) { + JT_CHECK(is_type(key) + && is_type(value)); + a.emplace( + from_py_object(key), + from_py_object(value) + ); + } + return a; +} + +template +struct array { + +typedef T _type; +static constexpr int _ndim = N; + +int64 shape[N]; +unique_ptr data; + +inline bool is_float() const { return std::is_floating_point::value; } +inline bool is_unsigned() const { return std::is_unsigned::value; } +inline int64 size() const { + int64 s=1; + for (auto x : shape) s *= x; + return s; +} +inline int64 nbyte() const { return size()*sizeof(T); } +inline string dtype() const { + return DTYPE(); +} +inline int ndim() const { return N; } + +inline static string DTYPE() { + string dtype(std::is_floating_point::value ? "float" : + std::is_unsigned::value ? "uint" : "int"); + if (sizeof(T)==1) dtype += "8"; else + if (sizeof(T)==2) dtype += "16"; else + if (sizeof(T)==4) dtype += "32"; else + if (sizeof(T)==8) dtype += "64"; else + throw std::runtime_error("Not support type"); + return dtype; +} + +inline array(const vector& shape) { + if (shape.size() != N) throw std::runtime_error("Dim not match"); + for (int i=0; ishape[i] = shape[i]; + data.reset(new T[size()]); +} + +inline array(const vector& shape, const T* data) : array(shape) { + memcpy(this->data.get(), data, nbyte()); +} + +inline array(const vector& shape, const vector& data) : array(shape, &data[0]) { +} + +template +inline int64 get_offset(int64 offset, Ti i, Targs... Fargs) { + if constexpr (I+1==N) + return offset*shape[I]+i; + else + return get_offset(offset*shape[I]+i, Fargs...); +} + +template +T& operator()(Targs... Fargs) { + return data[get_offset<0>(0, Fargs...)]; +} + +}; + +struct Console { + +PyObjHolder globals, locals; +PyObject* (*make_pyjt_array)(const vector& shape, const string& dtype, const void* data); +void (*get_pyjt_array)(PyObject* obj, vector& shape, string& dtype, void*& data); + +inline Console() { + Py_Initialize(); + globals.assign(PyDict_New()); + locals.assign(PyDict_New()); + + #if PY_VERSION_HEX < 0x03080000 + PyObjHolder builtins(PyImport_ImportModule("builtins")); + PyDict_SetItemString(globals.obj, "__builtins__", builtins.obj); + #endif + + run("import jittor as jt"); + #ifdef __APPLE__ + auto symbol_make_pyjt_array = "__ZN6jittor15make_pyjt_arrayERKNSt3__16vectorIxNS0_9allocatorIxEEEERKNS0_12basic_stringIcNS0_11char_traitsIcEENS2_IcEEEEPKv"; + auto symbol_gen_pyjt_array = "__ZN6jittor14get_pyjt_arrayEP7_objectRNSt3__16vectorIxNS2_9allocatorIxEEEERNS2_12basic_stringIcNS2_11char_traitsIcEENS4_IcEEEERPv"; + #else + auto symbol_make_pyjt_array = "_ZN6jittor15make_pyjt_arrayERKSt6vectorIxSaIxEERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEPKv"; + auto symbol_gen_pyjt_array = "_ZN6jittor14get_pyjt_arrayEP7_objectRSt6vectorIxSaIxEERNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEERPv"; + #endif + make_pyjt_array = (PyObject* (*)(const vector& shape, const string& dtype, const void* data))dlsym(RTLD_DEFAULT, symbol_make_pyjt_array); + get_pyjt_array = (void (*)(PyObject* obj, vector& shape, string& dtype, void*& data))dlsym(RTLD_DEFAULT, symbol_gen_pyjt_array); + if (!make_pyjt_array || !get_pyjt_array) { + std::cerr << "get symbol failed." << std::endl; + exit(1); + } +} + +inline ~Console() { + globals.free(); + locals.free(); + Py_FinalizeEx(); +} + +inline void run(const char* src) { + PyObjHolder ret(PyRun_String(src, Py_file_input, globals.obj, nullptr)); +} + +inline void run(const string& src) { run(src.c_str()); } + +template +inline void set(const char* s, const T& data) { + PyObjHolder py_data(to_py_object(data)); + PyDict_SetItemString(globals.obj, s, py_data.obj); +} + +template +inline void set(const string& s, const T& data) { + set(s.c_str(), data); +} + +template +inline T get(const char* s) { + auto obj = PyDict_GetItemString(globals.obj, s); + if (!obj) obj = PyDict_GetItemString(globals.obj, s); + if (!obj) throw std::runtime_error(string("KeyError: ")+s); + if (!is_type(obj)) throw std::runtime_error(string("TypeError: key<")+s+"> is "+Py_TYPE(obj)->tp_name); + return from_py_object(obj); +}; + +template +inline T get(const string& s) { + return get(s.c_str()); +} + + + +template +inline void set_array(const string& s, const array& data) { + PyObjHolder obj(make_pyjt_array( + vector(data.shape, data.shape+N), + data.dtype(), + data.data.get())); + PyDict_SetItemString(globals.obj, s.c_str(), obj.obj); +} + +template +inline array get_array(const string& s) { + auto obj = PyDict_GetItemString(globals.obj, s.c_str()); + if (!obj) obj = PyDict_GetItemString(globals.obj, s.c_str()); + if (!obj) throw std::runtime_error(string("KeyError: ")+s); + vector shape; + string dtype; + void* data; + get_pyjt_array(obj, shape, dtype, data); + string dtype2 = array::DTYPE(); + if (dtype2 != dtype) + throw new std::runtime_error(string("array dtype not match: ")+dtype+"!="+dtype2); + if (shape.size() != N) + throw new std::runtime_error(string("array ndim not match: ")+std::to_string(shape.size())+"!="+std::to_string(N)); + return array(shape, (T*)data); +} + +}; + +} \ No newline at end of file diff --git a/python/jittor/src/test/test_expr.cc b/python/jittor/src/test/test_expr.cc new file mode 100644 index 00000000..a3bdde34 --- /dev/null +++ b/python/jittor/src/test/test_expr.cc @@ -0,0 +1,213 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "opt/expr.h" + +namespace jittor { + +using namespace expr; + +JIT_TEST(expr) { + auto check = [&](const string& src, const string& exp, int debug=1, int try_reduce_braces=0) { + LOGv << "test" << src << exp; + unique_ptr expr(new Expr(src)); + ASSERTop(expr->to_string(try_reduce_braces, debug),==,exp); + + string nexp = expr->to_string(try_reduce_braces); + expr.reset(new Expr(nexp)); + ASSERTop(expr->to_string(try_reduce_braces),==,nexp); + if (try_reduce_braces) return; + + string exp2 = expr->to_string(1); + expr.reset(new Expr(exp2)); + ASSERTop(expr->to_string(0),==,nexp) << "Test reduce braces failed" << "\n" + << src << "\n" << exp2; + }; + auto check_error = [&](const string& src) { + expect_error([&]() { + unique_ptr expr(new Expr(src)); + }); + }; + check("asd", "asd"); + check("a+b", "(/*f:is_binary_op,is_op,;s:+;c:2*/a+b)"); + check("a+b*c", "(/*f:is_binary_op,is_op,;s:+;c:2*/a+(/*f:is_binary_op,is_op,;s:*;c:2*/b*c))"); + check("a * (b+c )/d", "((a*(b+c))/d)", 0); + check("-a * b", "(-(a*b))", 0); + check("a+(b+c)", "(a+b+c)", 0); + check("a(b+c)", "(a((b+c)))", 0); + check("a[b+c]", "(a[(b+c)])", 0); + check("a{b+c}", "(a{(b+c)})", 0); + check_error("a{b+c)"); + check_error("a(b+c)a"); + check("a+x(b+c)", "(a+(x((b+c))))", 0); + check("a::x(b+c)", "((a::x)((b+c)))", 0); + check("a+b? c & d: x && y", "((a+b)?(c&d):(x&&y))", 0); + check("a=b?x:y", "(a=(b?x:y))", 0); + check_error("a=b?x,c:y"); + check("a=b?c:y,b=c;d+=p", "((a=(b?c:y)),((b=c);(d+=p)))", 0); + check("a,b+=c", "(a,(b+=c))", 0); + check("a(x,y,z)", "(a(x,y,z))", 0); + check("(a+b){x,y,z}+k", "(((a+b){x,y,z})+k)", 0); + check("++a_b", "(++a_b)", 0); + check("++a++", "((++a)++)", 0); + check("*a", "(*a)", 0); + check("a*", "(a*)", 0); + check("a***", "(((a*)*)*)", 0); + check_error("*"); + check("a***", "a***", 0, 1); + // this test can not pass + // check("***a", "***a", 0, 1); + check("a((x),(y,z))", "a(x,(y,z))", 0, 1); +} + +JIT_TEST(expr_bug) { + // unique_ptr expr(new Expr("op0_yp[op0_i]=((float32)(std::tanh((op0_xp[op0_i]))))")); + unique_ptr expr(new Expr("op0_yp[op0_i]=((float32)std::tanh((op0_xp[op0_i])))")); +} + +JIT_TEST(expr_get_tokens) { + auto check = [&]( + const string& src, + const vector>& tokens, + const vector& flags + ) { + vector> t; + vector f; + get_tokens(src, t, f); + ASSERTop(t,==,tokens); + ASSERTop(f,==,flags); + }; + check("a+b",{{0,1},{1,2},{2,3}},{0, _op, 0}); + check(" a + b ",{{1,2},{3,4},{5,6}},{0, _op, 0}); + check("'a'",{{0,3}},{_char}); + check("\"asdasd\"",{{0,8}},{_string}); + check("1 0x1a 0b1 1u 1ull", + {{0,1},{2,6},{7,10},{11,13},{14,18}}, + {_int,_int,_int,_int,_int}); + check("1. 1.0f 1e3", + {{0,2},{3,7},{8,11}}, + {_float,_float,_float}); + auto a = std::make_unique("0xaa"); + ASSERTop(a->as_int(),==,0xaa); + a = std::make_unique("0b11"); + ASSERT(a->as_int()==0b11); + a = std::make_unique("123"); + ASSERT(a->as_int()==123); + a = std::make_unique("1.5"); + ASSERT(a->as_float()==1.5); + a = std::make_unique("2."); + ASSERT(a->as_float()==2.); + a = std::make_unique("1e2"); + ASSERTop(a->as_float(),==,1e2); +} + +JIT_TEST(expr_simplify) { + auto check = [&](const string& a, const string& b) { + LOGv << "test" << a << b; + auto x = std::make_unique(a); + LOGv << *x; + x = x->simplify(); + ASSERTop(x->to_string(1),==,b); + }; + check("1+1","2"); + check("1+1*3+1.5","5.5"); + check("0?1+2:3+4","7"); + check("100/2*a", "50*a"); + check("100/2*a + 1/3.0*b", "50*a+0.333333*b"); + check("1+a+1+1+b+(1+c+1)+1+d+1+1","1+a+2+b+1+c+2+d+2"); + check("1*a*1", "a"); + check("a/1", "a"); + check("1/a", "1/a"); + check("a+0", "a"); + check("a*0", "0"); + // TODO: pass this test + // check("a+1-1", "a"); + + check("0+0+0", "0"); + check("(((0+(loop_cnt*1))*2)-0)","loop_cnt*2"); +} + +JIT_TEST(expr_expand) { + auto check = [&](const string& a, const string& b) { + LOGv << "test" << a << b; + auto x = std::make_unique(a); + x = expand(x.get()); + ASSERTop(x->to_string(1),==,b); + }; + check("-a", "(-1)*a"); + check("a-b", "a+(-1)*b"); + check("(a+b)*c", "a*c+b*c"); + check("(a-b)*c", "a*c+(-1)*b*c"); + check("(a-b)*(c-d)", "a*c+a*(-1)*d+(-1)*b*c+(-1)*b*(-1)*d"); + check("a&&b", "a&&b"); + check("!(a&&b)", "!a||!b"); + check("!(a&&b&&c&&d)", "!a||!b||!c||!d"); + check("!(a||b||c||d)", "!a&&!b&&!c&&!d"); + check("!!a", "a"); + check("!!!!a", "a"); + check("!!!a", "!a"); + check("!(!a&&b)", "a||!b"); + check("!(a>b && c<=d && e!=f)", "a<=b||c>d||e>=f&&e<=f"); + check("a@>b", "!a||b"); + check("a@=b&&a<=b"); + check("!(a!=b)", "a>=b&&a<=b"); + check("0<=i0 && i00 && m>0 \ + && (i0!=i1) @> (i0*m+j0 != i1*m+j0)", + "0>i0||i0>=n||0>i1||i1>=n||0>j1||j1>=m||0>j1||j1>=m||n<=0||m<=0||i0>=i1&&i0<=i1||i0*m+j0i1*m+j0"); + // vector> v{std::make_unique("asd"), std::make_unique("asxd")}; + // check("-a", "-1*a"); + // for i in n: + // for j in m + // write(i*m+j) + // 0<=i0 && i00 && m>0 + // && (i0i1) -> (i0*m+j0 < i1*m+j0 || i0*m+j0 > i1*m+j0) + + // for i in n: + // for j in m: + // (i*m+j)*a+b + // get_coeff([i,j], &coeff, &b) + // match(m*i+j+k, a*i+b*j+c, [a,b,c], [i,j], results) -> bool +} + +JIT_TEST(expr_match) { + auto check = [&](string src, string target, + vector solve_symbols, + vector exclude_symbols, + vector results, + bool ret = true + ) { + auto _src = make(src); + auto _target = make(target); + LOGv << "test" << src << target; + vector> _results; + bool _ret = match(_src.get(), _target.get(), solve_symbols, exclude_symbols, _results); + ASSERT(ret==_ret) << src << target; + if (ret) { + ASSERT(results.size()==_results.size()); + for (uint i=0; ito_string(1)); + } + }; + check("1", "a", {"a"}, {}, {"1"}); + check("1+x", "a", {"a"}, {}, {"1+x"}); + check("y", "x", {}, {}, {}, false); + check("x", "x", {}, {}, {}, true); + check("1-2+x", "a", {"a"}, {}, {"(-1)+x"}); + check("3*i+j-1", "stride*i+j+pad", {"stride", "pad"}, {"i", "j"}, {"3", "(-1)"}); + check("i*2+j*2", "(i+j)*2", {}, {}, {}); + check("i*2+j*2", "(i+j)*a", {"a"}, {}, {"2"}); + check("i*2+j*3", "(i+j)*a", {"a"}, {}, {"2"}, false); + check("3*i+j-1", "i*stride+pad+j", {"stride", "pad"}, {"i", "j"}, {"3", "(-1)"}); + check("1*i+j-1", "i*stride+pad+j", {"stride", "pad"}, {"i", "j"}, {"1", "(-1)"}); + check("1*i+j-1", "i*stride*stride+pad+j", {"stride", "pad"}, {"i", "j"}, {"1", "(-1)"}); + check("1*i+j", "i*stride*stride+pad+j", {"stride", "pad"}, {"i", "j"}, {"1", "0"}); +} + +} // jittor diff --git a/python/jittor/src/test/test_fast_shared_ptr.cc b/python/jittor/src/test/test_fast_shared_ptr.cc new file mode 100644 index 00000000..082528d4 --- /dev/null +++ b/python/jittor/src/test/test_fast_shared_ptr.cc @@ -0,0 +1,24 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "misc/fast_shared_ptr.h" + +namespace jittor { + +JIT_TEST(fast_shared_ptr) { + unordered_map a; + fast_shared_ptr> ap(move(a)); + ASSERT(ap.ptr==0); + ap = {{"a",1}}; + auto bp = ap; + ASSERT(bp.ptr==ap.ptr && bp.ref_cnt()==2); + ap = nullptr; + ASSERT(ap.ptr==nullptr && bp.ref_cnt()==1); + ap = clone(bp.data()); + ASSERT(ap.data().size()==1 && bp.ref_cnt()==1); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/test/test_jit_key.cc b/python/jittor/src/test/test_jit_key.cc new file mode 100644 index 00000000..4f2be904 --- /dev/null +++ b/python/jittor/src/test/test_jit_key.cc @@ -0,0 +1,65 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "jit_key.h" + +namespace jittor { + +JIT_TEST(jit_key) { + JK& jk = get_jk(); + jk.clear(); + for (int i=0; i> k2 = + {{"a","11"},{"b","22"},{"a[3]","b::[x]"},{"x","17"},{"f","0"}}; + ASSERTop(keys,==,k2); + jk.clear();jk << 0x0; + ASSERT(jk.to_string()=="0"); + for (int i=1; i<63; i++) { + jk.clear(); + jk << ((1ll << i)-1); + ASSERT(jk.size==(i-1)/4+1); + jk.clear(); + jk << -((1ll << i)-1); + ASSERT(jk.size==(i-1)/4+2); + } + + jk.clear(); + add_jit_define(jk, "f", 0.01); + add_jit_define(jk, "f", 0.5); + #ifndef _MSC_VER + add_jit_define(jk, "f", 1.0/0); + add_jit_define(jk, "f", -1.0/0); + add_jit_define(jk, "f", 0.0/0); + #endif + keys = parse_jit_keys(jk.to_string()); + k2 = {{"f","0x1.47ae147ae147bp-7"}, + {"f","0x1p-1"}, + {"f","(1.0/0)"}, + {"f","(-1.0/0)"}, + {"f","(0.0/0)"}, + }; + ASSERTop(keys,==,k2); + +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/test/test_kernel_ir.cc b/python/jittor/src/test/test_kernel_ir.cc new file mode 100644 index 00000000..d1b37884 --- /dev/null +++ b/python/jittor/src/test/test_kernel_ir.cc @@ -0,0 +1,367 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "opt/kernel_ir.h" + +namespace jittor { + +string str_diff_detail(const string& a, const string& b) { + if (a.size() != b.size()) { + return "size not match " + S(a.size()) + " vs " + S(b.size()); + } + for (int i=0; i + #define aaa bbb + using namespace std; + void main(float* __restrict__ c) { + // test commment + int n = 1024; + int m = 1024; + int k = 1024; + float* __restrict__ a = new float[n*m]; + float* __restrict__ b = new float[m*k]; + for (int i=0; i(1), aaa(1), main(1), + +// C macro code:"#include " lvalue:"" +#include +// C macro code:"#define aaa bbb" lvalue:"aaa" rvalue:" bbb" +#define aaa bbb +// C code:"using namespace std;" raw:"1" +using namespace std; +// C func dtype:"void" lvalue:"main" +// scope: a(1), b(1), c(1), k(1), m(1), n(1), +void main(float* __restrict__ c) { + // C comment code:"// test commment" + // test commment + // C define dtype:"int" lvalue:"n" raw:"1" rvalue:"1024" + int n = 1024; + // C define dtype:"int" lvalue:"m" raw:"1" rvalue:"1024" + int m = 1024; + // C define dtype:"int" lvalue:"k" raw:"1" rvalue:"1024" + int k = 1024; + // C define dtype:"float* __restrict__" lvalue:"a" raw:"1" rvalue:"new float[n*m]" + float* __restrict__ a = new float[n*m]; + // C define dtype:"float* __restrict__" lvalue:"b" raw:"1" rvalue:"new float[m*k]" + float* __restrict__ b = new float[m*k]; + // C loop raw:"1" + // scope: i(1), + for (int i = 0; i({"a","b","c"})); + CHECK(split("a,b,c,,", ",") == vector({"a","b","c","",""})); + CHECK(split("a,b,c,,", ",", 3) == vector({"a","b","c,,"})); +} + +JIT_TEST(kernel_ir_manip) { + KernelIR ir(R"(for (int i=0; ierase(); + c[0]->erase(); + c.back()->erase(); + c.back()->erase(); + CHECKop(c.size(),==,1); + ir.push_back(ir.clone()); + CHECK(c.back()->type == "loop"); + auto& in_loop = *c.back(); + // test replace + in_loop.replace({{"i", "j"}, {"n", "range0"}}, true, false); + // test swap + ir.swap(in_loop); + CHECK(ir.attrs["lvalue"]=="j"); + in_loop.replace({{"n", "range1"}}, true, false); + // test rename_loop_index + ir.rename_loop_index(); + CHECK(ir.attrs["lvalue"]=="id0"); + CHECK(in_loop.attrs["lvalue"]=="id1"); + // test find_loops + auto a = ir.find_loops("1"); + CHECK(a.size()==1 && a[0]==&in_loop); + // test find_define + auto* b = in_loop.find_define("id0"); + CHECK(b == ir.inner[0].get()) << b; + // test move_loop_back + ir.push_back("a[3]++;"); + ir.push_back("for (int i=0; ichildren.size()==1); + CHECK(c[0]->children[0]->children.size()==2); + // test expand_empty_block + ir.move_out_children(); + ir.push_back("{ T xx=1; xx++; a[xx]=0; }"); + CHECK(ir.scope.count("xx")==0 && c.back()->scope.count("xx")); + ir.expand_empty_block(); + CHECK(c.size()==3 && ir.scope.count("xx")); + // test solve_conflict_define + ir.move_out_children(); + ir.push_back("{ T xx=1; xx++; a[xx]=0; }"); + ir.push_back("{ T xx=1; xx++; a[xx]=0; }"); + ir.expand_empty_block(); + ir.solve_conflict_define(); + CHECK(c.size()==6 && + c[0]->attrs["lvalue"] == "xx" && + c[1]->attrs["code"] == "xx++;" && + c[2]->attrs["code"] == "a[xx]=0;" && + c[3]->attrs["lvalue"] == "xx_" && + c[4]->attrs["code"] == "xx_++;" && + c[5]->attrs["code"] == "a[xx_]=0;" + ); + // test remove_unused + // a <-+- c <-- d (unused) + // +-- b (used) + ir.move_out_children(); + ir.push_back("T a=0;"); + ir.push_back("T b=a;"); + ir.push_back("T c=a;"); + ir.push_back("b++;"); + ir.push_back("T d=c;"); + ir.check_unused(); + CHECK(c.size()==5 && + c[0]->check_attr("used", "1") && + c[1]->check_attr("used", "1") && + c[2]->check_attr("used", "1") && + !c[4]->check_attr("used", "1") + ); + ir.remove_all_unused(); + CHECK(c.size()==3); + // test split_loop 1 + ir.move_out_children(); + ir.push_back("for (int i=0; ito_string() == code); + // test split_loop 3 + ir.move_out_children(); + ir.push_back("for (int i=0; ito_string(),==,code); + // test get_number + ir.move_out_children(); + ir.push_back("T x=1;"); + ir.push_back("T y=n;"); + int num=0; + CHECK(ir.get_number("x", num) && num==1); + CHECK(!ir.get_number("z", num) && num==-1); + CHECK(!ir.get_number("y", num) && num==-2); + // test resplit + ir.move_out_children(); + ir.push_back("for (int i=0; iresplit(); + code = R"(int id1 = 0; +for (id1=0; id1+stride1<=range1; id1+=stride1) { + int range2 = stride1; + for (int id2 = 0; id2to_string(),==,code); +} + +JIT_TEST(kernel_ir_func) { + KernelIR ir(""); + ir.push_back("void func1() {func0(0, 1);}"); + auto func1 = ir.children.back().get(); + func1->push_back("void func0(int a, int b) {}", &func1->before); + auto func0 = func1->before.back().get(); + CHECK(func0->inner.size()==2); + ir.remove_all_unused(); + CHECK(func0->inner.size()==0); + CHECK(func1->children.back()->get_attr("code") == "func0();"); + // test remove_func_call_arg + string s = "func(0,1,2,(1,2),3);"; + expect_error([&]() {remove_func_call_arg(s, 5);}); + remove_func_call_arg(s, 4); + CHECKop(s,==,"func(0,1,2,(1,2));"); + remove_func_call_arg(s, 2); + CHECKop(s,==,"func(0,1,(1,2));"); + remove_func_call_arg(s, 2); + CHECKop(s,==,"func(0,1);"); + remove_func_call_arg(s, 0); + CHECKop(s,==,"func(1);"); + remove_func_call_arg(s, 0); + CHECKop(s,==,"func();"); +} + +JIT_TEST(kernel_ir_swap_scope) { + KernelIR ir(R"( + void func() { + for (int i=0; ichildren.back(); + loop1->swap(*loop2); + CHECK(loop1->scope.count("j")); + CHECK(loop2->scope.count("i")); + CHECK(loop1->scope["j"].size()==1); + CHECK(loop2->scope["i"].size()==1); +} + +JIT_TEST(kernel_ir_remove_intermediate) { + KernelIR ir(R"( + void func() { + int* xp = input; + int* yp = output; + for (int i=0; i<100; i++) { + yp[i] = xp[i]+1; + } + } + )"); + ir.remove_intermediate({"y"}); + string expect = "auto yd = xp[i]+1;\n"; + CHECK(ir.children.at(1)->children.at(0)->to_string()==expect); + ir.remove_all_unused(); + CHECK(ir.children.at(0)->children.size()==0); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/test/test_nano_vector.cc b/python/jittor/src/test/test_nano_vector.cc new file mode 100644 index 00000000..a4edc7d1 --- /dev/null +++ b/python/jittor/src/test/test_nano_vector.cc @@ -0,0 +1,63 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "misc/nano_vector.h" + +namespace jittor { + +JIT_TEST(nano_vector) { + NanoVector nv; + ASSERTop(nv.get_nbits(0),==,1); + ASSERTop(nv.get_nbits(-1),==,1); + + ASSERTop(nv.get_nbits(1),==,2); + ASSERTop(nv.get_nbits(-2),==,2); + + ASSERTop(nv.get_nbits(3),==,3); + ASSERTop(nv.get_nbits(-4),==,3); + + nv.push_back(0); + ASSERT(nv.size()==1 && nv[0]==0 && nv.total_bits()==1) << nv << nv.total_bits() + << nv.size() << nv.offset; + nv.push_back(0); + ASSERT(nv.size()==2 && nv[1]==0 && nv.total_bits()==2); + nv.push_back(-1); + ASSERT(nv.size()==3 && nv[2]==-1 && nv.total_bits()==3) << nv; + nv.push_back(2); + ASSERT(nv.size()==4 && nv[3]==2 && nv.total_bits()==6) << nv << nv.total_bits(); + nv.push_back(3); + ASSERT(nv.size()==5 && nv[4]==3 && nv.total_bits()==9) << nv << nv.total_bits(); + nv.push_back(-3); + ASSERT(nv.size()==6 && nv[5]==-3 && nv.total_bits()==12) + << nv << nv.total_bits() << nv[5] << nv.size(); + nv.push_back(1ull<<40); + ASSERT(nv.size()==7 && nv[6]==(1ull<<40) && nv.total_bits()==54) + << nv << nv.total_bits(); + nv.push_back(-(1<<5)); + ASSERT(nv.size()==8 && nv[7]==(-(1<<5)) && nv.total_bits()==60) + << nv << nv.total_bits(); + nv.push_back(1); + ASSERT(nv.size()==9 && nv[8]==1 && nv.total_bits()==62) + << nv << nv.total_bits(); + nv.push_back(-2); + ASSERT(nv.size()==10 && nv[9]==-2); + + nv.clear(); + nv.reserve(10, 10); + nv.set_data(0, 10); + nv.set_data(9, -10); + nv.set_data(5, 4); + ASSERT(nv.to_string()=="[10,0,0,0,0,4,0,0,0,-10,]") << nv; + + nv.clear(); + nv.reserve(10*8, 8); + nv.set_data(0, 10); + nv.set_data(7, -10); + nv.set_data(5, 4); + ASSERT(nv.to_string()=="[10,0,0,0,0,4,0,-10,]") << nv; +} + +} // jittor diff --git a/python/jittor/src/test/test_op_compiler.cc b/python/jittor/src/test/test_op_compiler.cc new file mode 100644 index 00000000..3fe35cbf --- /dev/null +++ b/python/jittor/src/test/test_op_compiler.cc @@ -0,0 +1,29 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "op_compiler.h" + +namespace jittor { + +JIT_TEST(regex) { + std::string s(R"( + asdas +void adasd +asdads XxxXxxOp::jit_run() { + xxxx +})"); + std::regex e(R"([^]*\s(\S*Op)::jit_run[^]*)"); + std::smatch cm; + + // std::regex_match ( s, cm, e, std::regex_constants::match_default ); + std::regex_match ( s, cm, e); + + CHECK(cm.size()==2); + CHECK(cm[1]=="XxxXxxOp"); +} + +} // jittor diff --git a/python/jittor/src/test/test_op_relay.cc b/python/jittor/src/test/test_op_relay.cc new file mode 100644 index 00000000..717ed72b --- /dev/null +++ b/python/jittor/src/test/test_op_relay.cc @@ -0,0 +1,105 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "op.h" +#include "var.h" +#include "opt/var_relay.h" +#include "ops/op_register.h" +#include "fused_op.h" +#include "graph.h" +#include "op_compiler.h" +#include "mem/allocator.h" + +namespace jittor { + +static auto make_binary_op = get_op_info("binary") + .get_constructor(); +static auto make_broadcast_to_op = get_op_info("broadcast_to") + .get_constructor(); +static auto make_reduce = get_op_info("reduce") + .get_constructor(); + +JIT_TEST(op_register) { + VarPtr a({10,10,1}, "float32"); + VarPtr b({1,10,10}, "float32"); + auto c = make_binary_op(a, b, ns_add); + CHECK(c->size==1000*4); + CHECK(c->input()->name_ex()=="binary.add"); +} + +JIT_TEST(fused_op_relay_matmul) { + JK& jk = get_jk(); + VarPtr a({10,10}, "float32"); + VarPtr b({10,10}, "float32"); + auto aa = make_broadcast_to_op(a, {10,10,10}, {2}); + auto bb = make_broadcast_to_op(b, {10,10,10}, {0}); + auto c = make_binary_op(aa, bb, ns_add); + auto d = make_reduce(c, ns_add, 1, false); + vector s({d->node()}), q; + vector ops; + bfs_backward(s, q, [&](Node *node) -> bool { + node->custom_data=0; + if (!node->is_var()) ops.push_back(node->op()); + return true; + }); + CHECKop(q.size(),==,10); + CHECKop(ops.size(),==,4); + for (auto op : ops) op->do_jit_prepare(jk); + FusedOp fop; + FusedOpContext context; + fop.context = &context; + context.vrm.set_fused_op(&fop); + for (uint i=0; icustom_data = b->custom_data = d->custom_data = 1; + fop.update_ops(); + context.setup(&fop); + if (!has_op("mkl_matmul")) return; + auto make_matmul = get_op_info("mkl_matmul") + .get_constructor(); + auto rvar = make_matmul(a, b, 0, 0); + + fop.context->vrm.add_relay_group({{rvar, d}}); + CHECKop(context.vrm.relay_groups[0].removed_input_vars.size(),==,2); + auto is_op_relayed = context.vrm.get_op_relay_info({1}); + for (auto v : is_op_relayed) CHECK(v.first==0 && v.second==0); + + // test2 + for (Node* node : q) node->custom_data = 0; + // a, b, d can not fuse + a->custom_data = b->custom_data = d->custom_data = 1; + // broadcast(a) can not fused + fop.vars[1].var->custom_data = 1; + fop.update_ops(); + context.setup(&fop); + is_op_relayed = context.vrm.get_op_relay_info({1}); + vector> ans{{-1,-1},{0,0},{0,0},{0,0}}; + CHECKop(is_op_relayed,==,ans); + auto& oprc = context.vrm.relay_groups[0].oprcs[0]; + CHECKop(oprc.op,==,rvar->input()); + // matmul op.x --> a, op.y --> b, op.z --> d + CHECK(oprc.relayed_members[0]==(a->custom_data>>2)); + CHECK(oprc.relayed_members[1]==(b->custom_data>>2)); + CHECK(oprc.relayed_members[2]==(d->custom_data>>2)); + auto src = context.vrm.get_relay_src(0,0); + + auto& loop_options = fop.get_loop_options_tuned(); + loop_options["relay0"] = 1; + OpCompiler oc(&fop); + + auto allocator = get_allocator(); + for (auto& v : fop.vars) + if (v.type!=1) v.var->alloc(allocator); + auto entry = oc.compile("«OP:_fused_op_relay_matmul", oc.src); + for (uint i=0; inum; i++) + a->ptr()[i] = b->ptr()[i] = 1; + entry(&fop); + for (uint i=0; inum; i++) + CHECK(d->ptr()[i]==10); +} + +} // jittor diff --git a/python/jittor/src/test/test_setitem_op.cc b/python/jittor/src/test/test_setitem_op.cc new file mode 100644 index 00000000..f3bf1c68 --- /dev/null +++ b/python/jittor/src/test/test_setitem_op.cc @@ -0,0 +1,41 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "misc/nano_vector.h" + +namespace jittor { + +void cuda_loop_schedule(NanoVector o_shape, int* masks, int* tdims); + +JIT_TEST(cuda_loop_schedule) { + auto check = [&](const vector& shape, const vector& masks, vector tdims={}) { + STACK_ALLOC(int, masks2, shape.size()); + int tdims2[6]; + cuda_loop_schedule(shape, masks2, tdims2); + while (tdims.size() < 6) tdims.push_back(1); + for (int i=0; i(masks2, masks2+shape.size()); + for (int i=0; i<6; i++) + ASSERT(tdims.at(i)==tdims2[i]) << tdims << vector(tdims2, tdims2+6); + }; + check({0}, {129}, {0,1,1,1,1,1}); + check({2,2,2,2}, {8, 4, 2, 1}, {2,2,2,2,1,1}); + check({2048,1024}, {8, 1}, {1024,1,1,2048,1,1}); + check({2048,1025}, {8, 1+(1<<6)}, {1024,1,1,2048,1,1}); + check({2048,3025}, {8, 1+(1<<6)}, {1024,1,1,2048,1,1}); + check({2048,4425}, {16, 1+8+(1<<6)}, {1024,1,1,5,2048,1}); + check({2048, 2048,4425}, {0, 16, 1+8+(1<<6)}, {1024,1,1,5,2048,1}); + check({3,3,3,4425}, {0, 32, 16, 1+8+(1<<6)}, {1024,1,1,5,3,3}); + check({3,3,3,4425, 3,3}, {0, 32, 16, 8+4+(1<<6), 2, 1}, {3,3,64,70,3,3}); + check({3,3,3,12, 9,9}, {32, 16, 8, 4, 2, 1}, {9,9,12,3,3,3}); + check({3,3,3,13, 9,9}, {32, 16, 8, 4+64, 2, 1}, {9,9,12,3,3,3}); + check({3,3,3,13*4, 9,9}, {0, 32, 16, 8+4+64, 2, 1}, {9,9,12,5,3,3}); + check({3,3,3,100, 3,3}, {32, 16, 8, 4+64, 2, 1}, {3,3,64,3,3,3}); + check({3,3,3,400, 3,3}, {0, 32, 16, 8+4+64, 2, 1}, {3,3,64,7,3,3}); +} + +} // jittor diff --git a/python/jittor/src/test/test_sfrl_allocator.cc b/python/jittor/src/test/test_sfrl_allocator.cc new file mode 100644 index 00000000..b60fd5c1 --- /dev/null +++ b/python/jittor/src/test/test_sfrl_allocator.cc @@ -0,0 +1,118 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "opt/kernel_ir.h" +#include "mem/allocator/sfrl_allocator.h" +#include +#include + +namespace jittor { + +struct TestTask { + //alloc [size] for [times2] times and free them all, do this [times1] times + size_t size, times1, times2; + float time_limit; //ms + TestTask(size_t size, size_t times1, size_t times2, float time_limit) : size(size), times1(times1), times2(times2), time_limit(time_limit) {} +}; + +JIT_TEST(sfrl_allocator_time) { + Allocator* allocator = get_allocator(); + constexpr int max_allc_num = 10000; + size_t id[max_allc_num]; + size_t temp[max_allc_num]; + std::vector tasks; + tasks.push_back(TestTask(20000000, 1000, 1000, 400.0)); + tasks.push_back(TestTask(10000, 1000, 1000, 600.0)); + + for (size_t i = 0; i < tasks.size(); ++i) { + auto begin = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()).count(); + for (size_t k = 0; k < tasks[i].times1; ++k) { + for (size_t j = 0; j < tasks[i].times2; ++j) { + temp[j] = j; + allocator->alloc(tasks[i].size, id[j]); + if (j > 0) + std::swap(temp[j], temp[rand() % j]); + } + for (size_t j = 0; j < tasks[i].times2; ++j) { + allocator->free(0, tasks[i].size, id[temp[j]]); + } + } + auto end = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()).count(); + + LOGvv << "Use time " << float(end - begin) / 1000 << "ms\n"; + ASSERTop(float(end - begin) / 1000, <, tasks[i].time_limit); + } +} + +JIT_TEST(sfrl_allocator_share) { + Allocator* allocator = get_allocator(); + constexpr int max_allc_num = 10000; + size_t id[max_allc_num]; + size_t temp[max_allc_num]; + std::vector tasks; + tasks.push_back(TestTask(20000000, 1000, 1000, 400.0)); + tasks.push_back(TestTask(10000, 1000, 1000, 600.0)); + + for (size_t i = 0; i < tasks.size(); ++i) { + auto begin = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()).count(); + for (size_t k = 0; k < tasks[i].times1; ++k) { + for (size_t j = 0; j < tasks[i].times2; ++j) { + temp[j] = j; + if (j > 0) + std::swap(temp[j], temp[rand() % j]); + if (rand() % 10 != 0 && j > 0) { + id[j] = id[rand() % j]; + allocator->share_with(tasks[i].size, id[j]); + } else { + allocator->alloc(tasks[i].size, id[j]); + } + } + for (size_t j = 0; j < tasks[i].times2; ++j) { + allocator->free(0, tasks[i].size, id[temp[j]]); + } + } + auto end = std::chrono::duration_cast( + std::chrono::steady_clock::now().time_since_epoch()).count(); + + LOGvvv << "Use time " << float(end - begin) / 1000 << "ms\n"; + ASSERTop(float(end - begin) / 1000, <, tasks[i].time_limit); + } +} + +JIT_TEST(sfrl_allocator_share_without_size_and_ptr) { + Allocator* allocator = get_allocator(); + constexpr int max_allc_num = 1000; + size_t id[max_allc_num]; + size_t temp[max_allc_num]; + std::vector tasks; + tasks.push_back(TestTask(20000000, 100, 100, 400.0)); + tasks.push_back(TestTask(10000, 100, 100, 600.0)); + + for (size_t i = 0; i < tasks.size(); ++i) { + for (size_t k = 0; k < tasks[i].times1; ++k) { + for (size_t j = 0; j < tasks[i].times2; ++j) { + temp[j] = j; + if (j > 0) + std::swap(temp[j], temp[rand() % j]); + if (rand() % 10 != 0 && j > 0) { + id[j] = id[rand() % j]; + allocator->share_with(0, id[j]); + } else { + allocator->alloc(tasks[i].size, id[j]); + } + } + for (size_t j = 0; j < tasks[i].times2; ++j) { + allocator->free(0, 0, id[temp[j]]); + } + } + } +} + +} // jittor diff --git a/python/jittor/src/type/common_op_type.cc b/python/jittor/src/type/common_op_type.cc new file mode 100644 index 00000000..7d7e8c64 --- /dev/null +++ b/python/jittor/src/type/common_op_type.cc @@ -0,0 +1,173 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "utils/str_utils.h" +#include "ops/op_register.h" + +namespace jittor { + +extern int use_cuda; + +unordered_map common_op_type_cuda_map = { + {"logical_not", "(!($2))"}, + {"bitwise_not", "(~($2))"}, + {"negative", "(-($2))"}, + {"abs", "::abs($2)"}, + {"log", "::logf(($1)($2))"}, + {"exp", "::expf(($1)($2))"}, + {"sqrt", "::sqrtf(($1)($2))"}, + {"round", "(($1) ::roundf(($2)))"}, + {"floor", "(($1) ::floorf(($2)))"}, + {"ceil", "(($1) ::ceilf(($2)))"}, + {"round_int", "(($1) ::roundf(($2)))"}, + {"floor_int", "(($1) ::floorf(($2)))"}, + {"ceil_int", "(($1) ::ceilf(($2)))"}, + {"sin", "(($1) ::sinf(($2)))"}, + {"asin", "(($1) ::asinf(($2)))"}, + {"sinh", "(($1) ::sinhf(($2)))"}, + {"asinh", "(($1) ::asinhf(($2)))"}, + {"cos", "(($1) ::cosf(($2)))"}, + {"acos", "(($1) ::acosf(($2)))"}, + {"cosh", "(($1) ::coshf(($2)))"}, + {"acosh", "(($1) ::acoshf(($2)))"}, + {"tan", "(($1) ::tanf(($2)))"}, + {"atan", "(($1) ::atanf(($2)))"}, + {"tanh", "(($1) ::tanhf(($2)))"}, + {"atanh", "(($1) ::atanhf(($2)))"}, + {"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300))))))))"}, + {"erf", "(($1) ::erff(($2)))"}, + {"erfinv", "(($1) ::erfinvf(($1)($2)))"}, + {"cast", "(($1)($2))"}, +#ifdef _WIN32 + // windows don't have pow(float,int), cause undefined reference, fix it + {"pow", "::pow(($1)($2),($1)($4))"}, +#else + {"pow", "::pow(($2),($4))"}, +#endif + {"maximum", "::max($1($2), $1($4))"}, + {"minimum", "::min($1($2), $1($4))"}, + {"mod", "@if(@strcmp($1,float32)==0,(($2)-::floorf(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-::floor(($2)/($4))*($4)),(($2)%($4))))"}, + {"init_maximum", "::numeric_min<$1>()"}, + {"init_minimum", "::numeric_max<$1>()"}, +}; + +struct CommonOpType : OpByType { + CommonOpType() { + types = { + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + }; + } + + string expand_op(const vector& args) { + for (int i=1; i cpu_map = { + {"logical_not", "(!($2))"}, + {"bitwise_not", "(~($2))"}, + {"negative", "(-($2))"}, + {"abs", "std::abs($2)"}, + {"log", "std::log(($1)($2))"}, + {"exp", "std::exp(($1)($2))"}, + {"sqrt", "std::sqrt(($1)($2))"}, + {"round", "(($1)std::round(($2)))"}, + {"floor", "(($1)std::floor(($2)))"}, + {"ceil", "(($1)std::ceil(($2)))"}, + {"round_int", "(($1)std::round(($2)))"}, + {"floor_int", "(($1)std::floor(($2)))"}, + {"ceil_int", "(($1)std::ceil(($2)))"}, + {"sin", "(($1) std::sin(($2)))"}, + {"asin", "(($1) std::asin(($2)))"}, + {"sinh", "(($1) std::sinh(($2)))"}, + {"asinh", "(($1) std::asinh(($2)))"}, + {"cos", "(($1) std::cos(($2)))"}, + {"acos", "(($1) std::acos(($2)))"}, + {"cosh", "(($1) std::cosh(($2)))"}, + {"acosh", "(($1) std::acosh(($2)))"}, + {"tan", "(($1) std::tan(($2)))"}, + {"atan", "(($1) std::atan(($2)))"}, + {"tanh", "(($1) std::tanh(($2)))"}, + {"atanh", "(($1) std::atanh(($2)))"}, + {"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"}, + {"erf", "(($1) std::erf(($2)))"}, + {"erfinv", "(jittor::_erfinv($2))"}, + {"cast", "(($1)($2))"}, + {"pow", "std::pow(($2),($4))"}, + {"maximum", "std::max($1($2), $1($4))"}, + {"minimum", "std::min($1($2), $1($4))"}, + {"mod", "@if(@strcmp($1,float32)==0,(($2)-std::floor(($2)/($4))*($4)),@if(@strcmp(@Tx,float64)==0,(($2)-std::floor(($2)/($4))*($4)),(($2)%($4))))"}, + {"init_maximum", "std::numeric_limits<$1>::lowest()"}, + {"init_minimum", "std::numeric_limits<$1>::max()"}, + }; + + static unordered_map both_map { + {"void", "($4)"}, + {"add", "(($2)+($4))"}, + {"subtract", "(($2)-($4))"}, + {"multiply", "(($2)*($4))"}, + {"divide", "($1(($1($2))/($1($4))))"}, + {"floor_divide", "($1(($1($2))/($1($4))))"}, + {"less", "(($2)<($4))"}, + {"less_equal", "(($2)<=($4))"}, + {"greater", "(($2)>($4))"}, + {"greater_equal", "(($2)>=($4))"}, + {"equal", "(($2)==($4))"}, + {"not_equal", "(($2)!=($4))"}, + {"left_shift", "(($2)<<($4))"}, + {"right_shift", "(($2)>>($4))"}, + {"logical_and", "(($2)&&($4))"}, + {"logical_or", "(($2)||($4))"}, + {"logical_xor", "((bool($2))!=(bool($4)))"}, + {"bitwise_and", "(($2)&($4))"}, + {"bitwise_or", "(($2)|($4))"}, + {"bitwise_xor", "(($2)^($4))"}, + {"mean", "(($2)+$1($4)*($1(rcount)))"}, + {"init_void", "$1(0)"}, + {"init_add", "$1(0)"}, + {"init_multiply", "$1(1)"}, + {"init_logical_and", "true"}, + {"init_logical_or", "false"}, + {"init_logical_xor", "false"}, + {"init_bitwise_and", "$1(-1)"}, + {"init_bitwise_or", "$1(0)"}, + {"init_bitwise_xor", "$1(0)"}, + {"init_mean", "$1(0)"}, + }; + + string ret; + if (both_map.count(args.at(0))) + ret = both_map[args.at(0)]; + else if (use_cuda) + ret = cuda_map[args.at(0)]; + else + ret = cpu_map[args.at(0)]; + if (args.at(1) == "bool") ret = "((bool)"+ret+")"; + return format(ret, args); + } + + void post_pass(OpCompiler*) { + return; + } +}; + + +static int _ = registe_op_type(new CommonOpType()); + +} \ No newline at end of file diff --git a/python/jittor/src/type/fp16_compute.h b/python/jittor/src/type/fp16_compute.h new file mode 100644 index 00000000..f5c85d3b --- /dev/null +++ b/python/jittor/src/type/fp16_compute.h @@ -0,0 +1,257 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +#if defined(JIT_cuda) && !defined(IS_ACL) + +#include +#include +#ifndef IS_ROCM +#include +#endif + +namespace jittor { + +typedef __half float16; +#ifndef IS_ROCM +typedef __nv_bfloat16 bfloat16; +#endif + + +#if CUDA_ARCH >= 800 +inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); } +inline __device__ float16 min(float16 a, float16 b) { return __hmin(a, b); } +#elif CUDA_ARCH >= 610 +inline __device__ float16 max(float16 a, float16 b) { return a= 800 +inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return __hmax(a, b); } +inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return __hmin(a, b); } +#elif CUDA_ARCH >= 610 +inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return a +__device__ inline +typename std::enable_if::type +vload(T* __restrict__ a, T* __restrict__ b) {} + +template +__device__ inline +typename std::enable_if<0::type +vload(T* __restrict__ a, T* __restrict__ b) { + if (nbyte<=0) return; + if (nbyte>=16) { + auto* __restrict__ aa = (float4* __restrict__)a; + auto* __restrict__ bb = (float4* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if (nbyte>=8) { + auto* __restrict__ aa = (float2* __restrict__)a; + auto* __restrict__ bb = (float2* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if (nbyte>=4) { + auto* __restrict__ aa = (float* __restrict__)a; + auto* __restrict__ bb = (float* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if (nbyte>=2) { + auto* __restrict__ aa = (__half* __restrict__)a; + auto* __restrict__ bb = (__half* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } + if (nbyte>=1) { + auto* __restrict__ aa = (int8_t* __restrict__)a; + auto* __restrict__ bb = (int8_t* __restrict__)b; + aa[0] = bb[0]; + return vload(aa+1, bb+1); + } +} + +template +__device__ inline +typename std::enable_if::type +vfill(T* __restrict__ a) {} + +template +__device__ inline +typename std::enable_if<0::type +vfill(T* __restrict__ a) { + if (nbyte<=0) return; + if (nbyte>=16) { + auto* __restrict__ aa = (int4* __restrict__)a; + aa[0].x = aa[0].y = aa[0].z = aa[0].w = 0; + return vfill(aa+1); + } + if (nbyte>=8) { + auto* __restrict__ aa = (int2* __restrict__)a; + aa[0].x = aa[0].y = 0; + return vfill(aa+1); + } + if (nbyte>=4) { + auto* __restrict__ aa = (int* __restrict__)a; + aa[0] = 0; + return vfill(aa+1); + } + if (nbyte>=2) { + auto* __restrict__ aa = (int16_t* __restrict__)a; + aa[0] = 0; + return vfill(aa+1); + } + if (nbyte>=1) { + auto* __restrict__ aa = (int8_t* __restrict__)a; + aa[0] = 0; + return vfill(aa+1); + } +} + + +} + +using jittor::max; +using jittor::min; +using jittor::pow; + +#else + +namespace jittor { + +struct float16 { + uint16 x; + + inline float16(float32 f) { + unsigned x = *((int*)(void*)(&f)); + unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; + unsigned sign, exponent, mantissa; + + + // Get rid of +NaN/-NaN case first. + if (u > 0x7f800000) { + this->x = 0x7fffU; + return; + } + + sign = ((x >> 16) & 0x8000); + + // Get rid of +Inf/-Inf, +0/-0. + if (u > 0x477fefff) { + this->x = sign | 0x7c00U; + return; + } + if (u < 0x33000001) { + this->x = sign | 0x0000U; + return; + } + + exponent = ((u >> 23) & 0xff); + mantissa = (u & 0x7fffff); + + if (exponent > 0x70) { + shift = 13; + exponent -= 0x70; + } else { + shift = 0x7e - exponent; + exponent = 0; + mantissa |= 0x800000; + } + lsb = (1 << shift); + lsb_s1 = (lsb >> 1); + lsb_m1 = (lsb - 1); + + // Round to nearest even. + remainder = (mantissa & lsb_m1); + mantissa >>= shift; + if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { + ++mantissa; + if (!(mantissa & 0x3ff)) { + ++exponent; + mantissa = 0; + } + } + + this->x = (sign | (exponent << 10) | mantissa); + } + + inline operator float() const { + + unsigned sign = ((x >> 15) & 1); + unsigned exponent = ((x >> 10) & 0x1f); + unsigned mantissa = ((x & 0x3ff) << 13); + + if (exponent == 0x1f) { /* NaN or Inf */ + mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0); + exponent = 0xff; + } else if (!exponent) { /* Denorm or Zero */ + if (mantissa) { + unsigned int msb; + exponent = 0x71; + do { + msb = (mantissa & 0x400000); + mantissa <<= 1; /* normalize */ + --exponent; + } while (!msb); + mantissa &= 0x7fffff; /* 1.mantissa is implicit */ + } + } else { + exponent += 0x70; + } + + int temp = ((sign << 31) | (exponent << 23) | mantissa); + + return reinterpret_cast(temp); + } +}; + +bool operator<(float16 x, float16 y) { return float32(x)(float16 x, float16 y) { return float32(x)>float32(y); } +bool operator==(float16 x, float16 y) { return float32(x)==float32(y); } + + +struct bfloat16 { + uint16 x; + + inline bfloat16(float32 f) { + unsigned x = *((int*)(void*)(&f)); + this->x = x>>16; + } + + inline operator float() const { + int temp = x<<16; + + return reinterpret_cast(temp); + } +}; + +bool operator<(bfloat16 x, bfloat16 y) { return float32(x)(bfloat16 x, bfloat16 y) { return float32(x)>float32(y); } +bool operator==(bfloat16 x, bfloat16 y) { return float32(x)==float32(y); } + + +} + +#endif \ No newline at end of file diff --git a/python/jittor/src/type/fp16_op_type.cc b/python/jittor/src/type/fp16_op_type.cc new file mode 100644 index 00000000..b331f456 --- /dev/null +++ b/python/jittor/src/type/fp16_op_type.cc @@ -0,0 +1,200 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "common.h" +#include "utils/str_utils.h" +#include "ops/op_register.h" +#include "op_compiler.h" + +namespace jittor { + +extern int use_cuda; + +extern unordered_map common_op_type_cuda_map; + +static bool isvar(char x) { return isalnum(x) || x == '_' || x == ':'; } + +struct FP16OpType : OpByType { + FP16OpType() { + types = { + "float16", + "bfloat16", + }; + } + + string expand_op(const vector& args) { + bool found_fp16 = 0; + bool found_bf16 = 0; + for (int i=1; i cuda_map = { + {"logical_not", "(!($2))"}, + {"bitwise_not", "(~($2))"}, + {"negative", "(-($2))"}, + {"abs", "::__habs($2)"}, + {"log", "::hlog(($1)($2))"}, + {"exp", "::hexp(($1)($2))"}, + {"sqrt", "::hsqrt(($1)($2))"}, + {"round", "(($1) ::roundf(($2)))"}, + {"floor", "(($1) ::floorf(($2)))"}, + {"ceil", "(($1) ::ceilf(($2)))"}, + {"round_int", "(($1) ::roundf(($2)))"}, + {"floor_int", "(($1) ::floorf(($2)))"}, + {"ceil_int", "(($1) ::ceilf(($2)))"}, + {"sin", "(($1) ::sinf(($2)))"}, + {"asin", "(($1) ::asinf(($2)))"}, + {"sinh", "(($1) ::sinhf(($2)))"}, + {"asinh", "(($1) ::asinhf(($2)))"}, + {"cos", "(($1) ::cosf(($2)))"}, + {"acos", "(($1) ::acosf(($2)))"}, + {"cosh", "(($1) ::coshf(($2)))"}, + {"acosh", "(($1) ::acoshf(($2)))"}, + {"tan", "(($1) ::tanf(($2)))"}, + {"atan", "(($1) ::atanf(($2)))"}, + {"tanh", "(($1) ::tanhf(($2)))"}, + {"atanh", "(($1) ::atanhf(($2)))"}, + {"sigmoid", "(($1) (1.0f/(1.0f+::expf((::min($1(-($2)), $1(@if(@strcmp($1,float16)==0,30,300))))))))"}, + {"erf", "(($1) ::erff(($2)))"}, + {"erfinv", "(($1) ::erfinvf(($1)($2)))"}, + {"cast", "(($1)($2))"}, + {"pow", "::pow(($2),($4))"}, + {"maximum", "::max($1($2), $1($4))"}, + {"minimum", "::min($1($2), $1($4))"}, + {"mod", "$1(($2)-::hfloor(($2)/($4))*($4))"}, + {"init_maximum", "@if(@strcmp($1,float16)==0,-65000.0f,-1e38)"}, + {"init_minimum", "@if(@strcmp($1,float16)==0,65000.0f,1e38)"}, + {"equal", "(($2)==($4))"}, + }; + + static unordered_map cpu_map = { + {"logical_not", "(!($2))"}, + {"bitwise_not", "(~($2))"}, + {"negative", "(-($2))"}, + {"abs", "std::abs($2)"}, + {"log", "std::log(($1)($2))"}, + {"exp", "std::exp(($1)($2))"}, + {"sqrt", "std::sqrt(($1)($2))"}, + {"round", "(($1)std::round(($2)))"}, + {"floor", "(($1)std::floor(($2)))"}, + {"ceil", "(($1)std::ceil(($2)))"}, + {"round_int", "(($1)std::round(($2)))"}, + {"floor_int", "(($1)std::floor(($2)))"}, + {"ceil_int", "(($1)std::ceil(($2)))"}, + {"sin", "(($1) std::sin(($2)))"}, + {"asin", "(($1) std::asin(($2)))"}, + {"sinh", "(($1) std::sinh(($2)))"}, + {"asinh", "(($1) std::asinh(($2)))"}, + {"cos", "(($1) std::cos(($2)))"}, + {"acos", "(($1) std::acos(($2)))"}, + {"cosh", "(($1) std::cosh(($2)))"}, + {"acosh", "(($1) std::acosh(($2)))"}, + {"tan", "(($1) std::tan(($2)))"}, + {"atan", "(($1) std::atan(($2)))"}, + {"tanh", "(($1) std::tanh(($2)))"}, + {"atanh", "(($1) std::atanh(($2)))"}, + {"sigmoid", "(($1) (1.0f/(1.0f+std::exp(std::min($1(-($2)), $1(@if(@strcmp($1,float32)==0,30,300)))))))"}, + {"erf", "(($1) std::erf(($2)))"}, + {"erfinv", "(jittor::_erfinv($2))"}, + {"cast", "(($1)($2))"}, + {"pow", "std::pow(($2),($4))"}, + {"maximum", "std::max($1($2), $1($4))"}, + {"minimum", "std::min($1($2), $1($4))"}, + {"mod", "$1(($2)-std::floor(($2)/($4))*($4))"}, + {"init_maximum", "-32768.0f"}, + {"init_minimum", "32768.0f"}, + {"equal", "(float($2)==float($4))"}, + }; + + static unordered_map both_map { + {"void", "($4)"}, + {"add", "(($2)+($4))"}, + {"subtract", "(($2)-($4))"}, + {"multiply", "(($2)*($4))"}, + {"divide", "($1(($1($2))/($1($4))))"}, + {"floor_divide", "($1(($1($2))/($1($4))))"}, + {"less", "(($2)<($4))"}, + {"less_equal", "(($2)<=($4))"}, + {"greater", "(($2)>($4))"}, + {"greater_equal", "(($2)>=($4))"}, + {"not_equal", "(($2)!=($4))"}, + {"left_shift", "(($2)<<($4))"}, + {"right_shift", "(($2)>>($4))"}, + {"logical_and", "(($2)&&($4))"}, + {"logical_or", "(($2)||($4))"}, + {"logical_xor", "((bool($2))!=(bool($4)))"}, + {"bitwise_and", "(($2)&($4))"}, + {"bitwise_or", "(($2)|($4))"}, + {"bitwise_xor", "(($2)^($4))"}, + {"mean", "(($2)+($4)*($1(rcount)))"}, + {"init_void", "$1(0)"}, + {"init_add", "$1(0)"}, + {"init_multiply", "$1(1)"}, + {"init_logical_and", "true"}, + {"init_logical_or", "false"}, + {"init_logical_xor", "false"}, + {"init_bitwise_and", "$1(-1)"}, + {"init_bitwise_or", "$1(0)"}, + {"init_bitwise_xor", "$1(0)"}, + {"init_mean", "$1(0)"}, + }; + + string ret; + if (both_map.count(args.at(0))) + ret = both_map[args.at(0)]; + else if (use_cuda) + ret = cuda_map[args.at(0)]; + else + ret = cpu_map[args.at(0)]; + if (use_cuda) { + if (args[1] == "float32" && !both_map.count(args.at(0))) { + ret = common_op_type_cuda_map[args.at(0)]; + } + if (args[1] == "float16" || + args[1] == "bfloat16" || + args[1] == "float32") + { + for (int i=3; isrc; + if (src.find("float16") == string::npos) + return; + int i = src.rfind("#include"); + if (i<0) i=0; + i = src.find('\n', i) + 1; + src = src.substr(0, i) + "#include \"type/fp16_compute.h\"\n" + + src.substr(i); + return; + } +}; + + +static int _ = registe_op_type(new FP16OpType()); + +} \ No newline at end of file diff --git a/python/jittor/src/types.h b/python/jittor/src/types.h new file mode 100644 index 00000000..576a8ce2 --- /dev/null +++ b/python/jittor/src/types.h @@ -0,0 +1,244 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace jittor { + +typedef int8_t int8; +typedef int16_t int16; +typedef int int32; +typedef long long int64; +typedef uint8_t uint8; +typedef uint16_t uint16; +typedef uint32_t uint32; +typedef uint64_t uint64; +typedef float float32; +typedef double float64; +typedef uint32_t uint; + +using string = std::string; +using std::move; +template using vector = std::vector; +template using list = std::list; +template using set = std::set; +template using shared_ptr = std::shared_ptr; +template using unique_ptr = std::unique_ptr; +template using unordered_set = std::unordered_set; +template using pair = std::pair; +template using map = std::map; +template using unordered_map = std::unordered_map; + +struct Node; +struct Var; +struct Op; +struct Allocator; +struct Executor; +struct VarHolder; +struct VarPtr; +struct FusedOp; +struct OpCompiler; +struct PassManager; +struct Pass; +struct TunerManager; +struct Tuner; +struct NanoString; + +typedef map map_string; +typedef map loop_options_t; +typedef map> loop_option_candidates_t; +typedef void (*jit_op_entry_t)(Op* op); + +template +T clone(const T& a) { return a; } + +#define function_alias(A, B) \ +template \ +auto B(Args&&... args) -> decltype(A(std::forward(args)...)) { \ + return A(std::forward(args)...); \ +} + +function_alias(std::to_string, S); + +template +std::ostream& operator<<(std::ostream& os, const pair& p) { + return os << '(' << p.first << ',' << p.second << ')'; +} + +// print tuple function +namespace aux{ +template struct seq{}; + +template +struct gen_seq : gen_seq{}; + +template +struct gen_seq<0, Is...> : seq{}; + +template +void print_tuple(std::basic_ostream& os, Tuple const& t, seq){ + using swallow = int[]; + (void)swallow{0, (void(os << (Is == 0? "" : ",") << std::get(t)), 0)...}; +} +} // aux:: + +template +auto operator<<(std::basic_ostream& os, std::tuple const& t) + -> std::basic_ostream& +{ + os << "["; + aux::print_tuple(os, t, aux::gen_seq()); + return os << "]"; +} + + +template +std::ostream& operator<<(std::ostream& os, unique_ptr& ptr) { + return os << *ptr; +} + +template +std::ostream& operator<<(std::ostream& os, shared_ptr& ptr) { + return os << *ptr; +} + +template +std::ostream& operator<<(std::ostream& os, const unique_ptr& ptr) { + return os << *ptr; +} + +template +std::ostream& operator<<(std::ostream& os, const shared_ptr& ptr) { + return os << *ptr; +} + +template +std::ostream& operator<<(std::ostream& os, vector& input) { + os << '['; + for (auto& i: input) os << i << ","; + return os << ']'; +} + +template +std::ostream& operator<<(std::ostream& os, list& input) { + os << '['; + for (auto& i: input) os << i << ","; + return os << ']'; +} + +template +std::ostream& operator<<(std::ostream& os, map& input) { + os << '{'; + for (auto& i: input) os << i.first << ':' << i.second << ", "; + return os << '}'; +} + +template +std::ostream& operator<<(std::ostream& os, const vector& input) { + os << '['; + for (auto const& i: input) os << i << ","; + return os << ']'; +} + +template +std::ostream& operator<<(std::ostream& os, const list& input) { + os << '['; + for (auto const& i: input) os << i << ","; + return os << ']'; +} + +template +std::ostream& operator<<(std::ostream& os, const set& input) { + os << '['; + for (auto const& i: input) os << i << ","; + return os << ']'; +} + +template +std::istream& operator>>(std::istream& is, vector& out) { + T value; + while (is >> value) + out.push_back(value); + return is; +} + +template +std::ostream& operator<<(std::ostream& os, const map& input) { + os << '{'; + for (auto const& i: input) os << i.first << ':' << i.second << ", "; + return os << '}'; +} + +template +std::istream& operator>>(std::istream& is, map& out) { + Ta key; + Tb value; + while (is >> key >> value) + out[key] = value; + return is; +} + +template +std::istream& operator>>(std::istream& is, unordered_map& out) { + Ta key; + Tb value; + while (is >> key >> value) + out[key] = value; + return is; +} + + +template +std::ostream& operator<<(std::ostream& os, const unordered_map& input) { + os << '{'; + for (auto const& i: input) os << i.first << ':' << i.second << ", "; + return os << '}'; +} + +template +std::ostream& operator<<(std::ostream& os, const unordered_set& input) { + os << '{'; + for (auto const& i: input) os << i << ", "; + return os << '}'; +} + +template +struct Caster { + list *ptr; + Caster(list* ptr) : ptr(ptr) {}; + struct Iter { + typename list::iterator iter, next; + Iter(typename list::iterator iter) + : iter(iter), next(std::next(iter)) {} + T operator*() { return iter->operator T(); } + Iter& operator++() { iter = next++; return *this; } + Iter operator++(int) { auto tmp = *this; ++(*this); return tmp; } + bool operator!=(Iter& other) { return iter != other.iter; } + }; + Iter begin() const { return Iter(ptr->begin()); } + Iter end() const { return Iter(ptr->end()); } + size_t size() { return ptr->size(); } + T front() { return ptr->front().operator T(); } + T back() { return ptr->back().operator T(); } +}; + +template +std::ostream& operator<<(std::ostream& os, const Caster& input) { + os << '['; + for (const T i: input) os << i << ","; + return os << ']'; +} + +#define JPU(x) ; + +} // jittor diff --git a/python/jittor/src/utils/cache_compile.cc b/python/jittor/src/utils/cache_compile.cc new file mode 100644 index 00000000..b2d741d6 --- /dev/null +++ b/python/jittor/src/utils/cache_compile.cc @@ -0,0 +1,439 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#ifdef _WIN32 +#include +#endif +#include "misc/hash.h" +#include "utils/cache_compile.h" +#include "utils/str_utils.h" + +namespace jittor { +namespace jit_compiler { + +#ifndef TEST +string read_all(const string& fname) { + std::ifstream ifs(fname); + if (ifs && ifs.good()) + return string((std::istreambuf_iterator(ifs)), + (std::istreambuf_iterator())); + return ""; +} + +void write(const string& fname, const string& src) { + std::ofstream(fname) << src; +} + +bool file_exist(const string& fname) { + std::ifstream f(fname); + return f && f.good(); +} +#endif + +string join(string a, string b) { + const char sep = '/'; + if (!b.empty() && b.front() == sep) return b; + a.reserve(a.size() + b.size() + 1); + if (!a.empty() && a.back() != sep) a += sep; + a += b; + return a; +} + +void find_names(string cmd, vector& input_names, string& output_name, map>& extra) { + // find space not in str + #define is_quate(x) ((x)=='\'' || (x)=='\"') + auto pass = [&](size_t& j) { + while (j string { + string s; + for (size_t k=i; k& input_names, string& cmd) { + for (size_t i=0; i=src.size()) break; + if (src[i] == '#') { + // #include "a.h" + // i jk l + auto j=i+1; + while (j=src.size()) return; + if (j-i != 8 && j-i != 6 && j-i != 3) continue; + auto k=src[j] == '\"' ? j : j+1; + while (k=src.size()) return; + auto l=k+1; + while (l2 && src[k] == 'J' && src[k+1] == 'T' && (src.substr(i,j-i) == "#ifdef" || src.substr(i,j-i) == "#if")) { + auto inc = strip(src.substr(k, l-k)); + auto env = getenv(inc.c_str()); + if (env && string(env)!="0") { + auto senv = string(env); + string dflag = " -D"+inc+"="+senv; + if (cmd.find(dflag) == string::npos) { + // -D flags should insert before -o flag + #ifdef _MSC_VER + string patt = " -Fo: "; + #else + string patt = " -o "; + #endif + auto cmds = split(cmd, patt, 2); + if (cmds.size() == 2) { + cmd = cmds[0] + dflag + patt + cmds[1]; + } + } + } + } + i=l; + } + } +} + +static inline void check_win_file(const string& name) { +#ifdef _WIN32 + // win32 not allowed so file change when load + // but we can rename it + if (!file_exist(name)) return; + if (!(endswith(name, ".pyd") || endswith(name, ".dll"))) + return; + string new_name = name+".bk"; + LOGv << "move file" << name << "-> " << new_name; + if (file_exist(new_name)) + std::filesystem::remove(new_name); + std::filesystem::rename(name, new_name); +#endif +} + +static inline bool is_full_path(const string& name) { +#ifdef _WIN32 + return name.size()>=2 && (name[1]==':' || (name[0]=='\\' && name[1]=='\\')); +#else + return name.size() && name[0]=='/'; +#endif +} + +bool cache_compile(string cmd, const string& cache_path_, const string& jittor_path_) { + #ifdef _WIN32 + cmd = _to_winstr(cmd); + string cache_path = _to_winstr(cache_path_); + string jittor_path = _to_winstr(jittor_path_); + #else + const string& cache_path = cache_path_; + const string& jittor_path = jittor_path_; + #endif + vector input_names; + map> extra; + string output_name; + find_names(cmd, input_names, output_name, extra); + string output_cache_key; + bool ran = false; + if (file_exist(output_name)) + output_cache_key = read_all(output_name+".key"); + string cache_key; + unordered_set processed; + auto src_path = join(jittor_path, "src"); + const auto& extra_include = extra["I"]; + string tmp_dir =join(cache_path, "obj_files"); + for (size_t i=0; i new_names; + // *.obj, *.o, *.pyd + if (back != 'j' && back != 'o' && back != 'd') + process(src, new_names, cmd); + for (auto& name : new_names) { + string full_name; + if (name.substr(0, 4) == "jit/" || name.substr(0, 4) == "gen/") + full_name = join(cache_path, name); + else if (is_full_path(name)) + full_name = name; + else + full_name = join(src_path, name); + if (!file_exist(full_name)) { + bool found = 0; + for (const auto& inc : extra_include) { + full_name = join(inc, name); + if (file_exist(full_name)) { + found = 1; + break; + } + } + ASSERT(found) << "Include file" << name << "not found in" << extra_include + >> "\nCommands:" << cmd; + LOGvvvv << "Include file found:" << full_name; + } + input_names.push_back(full_name); + } + cache_key += "# "; + cache_key += input_names[i]; + cache_key += ": "; + cache_key += hash; + cache_key += "\n"; + } + cache_key = cmd + "\n" + cache_key; + if (output_cache_key.size() == 0) { + LOGvv << "Cache key of" << output_name << "not found."; + LOGvvv << "Run cmd:" << cmd; + check_win_file(output_name); + system_with_check(cmd.c_str(), tmp_dir.c_str()); + ran = true; + } + if (output_cache_key.size() != 0 && output_cache_key != cache_key) { + LOGvv << "Cache key of" << output_name << "changed."; + LOGvvv << "Run cmd:" << cmd; + check_win_file(output_name); + system_with_check(cmd.c_str(), tmp_dir.c_str()); + ran = true; + } + if (output_cache_key != cache_key) { + LOGvvvv << "Prev cache key" << output_cache_key; + LOGvvvv << "Write cache key" << output_name+".key:\n" >> cache_key; + write(output_name+".key", cache_key); + } + if (!ran) + LOGvvvv << "Command cached:" << cmd; + #ifdef TEST + if (ran) + write(output_name, "..."); + #endif + return ran; +} + +} // jit_compiler +} // jittor + +#ifdef TEST + +#include "test.h" + +static unordered_map files; + +namespace jittor { +namespace jit_compiler { + +string read_all(const string& fname) { + if (files.count(fname)) return files[fname]; + return ""; +} + +void write(const string& fname, const string& src) { + files[fname] = src; +} + +bool file_exist(const string& fname) { + return files.count(fname); +} + +} +} + +void test_find_names(string cmd, vector input_names, string output_name, map> extra={}) { + LOGvv << cmd; + vector inames; + string oname; + map> ename; + jittor::jit_compiler::find_names(cmd, inames, oname, ename); + CHECKop(oname,==,output_name); + CHECKop(inames.size(),==,input_names.size()); + for (size_t i=0; i inames; + string oname; + map> ename; + jittor::jit_compiler::find_names(cmd, inames, oname, ename); + }); +} + +void test_process(string src, vector files) { + vector ifiles; + string cmd; + jittor::jit_compiler::process(src, ifiles, cmd); + CHECK(files.size() == ifiles.size()); + for (size_t i=0; i", {}); + test_process("#include ", {}); + test_process("#include \"asd\"", {"asd"}); + test_process("//#include \"asd\"", {}); + test_process("/*#include \"asd\"*/", {}); + test_process("#include \"asd\"\n#include \"zxc\"", {"asd", "zxc"}); + + files = {{"src/a.h", "xxx"}, {"src/a.cc", "#include \"a.h\"\nxxx"}}; + CHECK(cache_compile("echo src/a.cc -o a.o")); + CHECK(files.count("a.o.key")); + CHECK(!cache_compile("echo src/a.cc -o a.o")); + files["src/a.h"] ="xxxx"; + CHECK(cache_compile("echo src/a.cc -o a.o")); + files["src/a.cc"] ="xxxx"; + CHECK(cache_compile("echo src/a.cc -o a.o")); + CHECK(cache_compile("echo src/a.cc -ff -o a.o")); + + // test include + files = {{"ex/a.h", "xxx"}, {"src/a.cc", "#include \"a.h\"\nxxx"}}; + CHECK(cache_compile("echo src/a.cc -Iex -o a.o")); + CHECK(files.count("a.o.key")); + CHECK(files["a.o.key"].find("ex/a.h") >= 0); + expect_error([&]() { + cache_compile("echo src/a.cc -o a.o"); + }); +} + +#endif diff --git a/python/jittor/src/utils/cache_compile.h b/python/jittor/src/utils/cache_compile.h new file mode 100644 index 00000000..adc6655a --- /dev/null +++ b/python/jittor/src/utils/cache_compile.h @@ -0,0 +1,20 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { +namespace jit_compiler { + +string read_all(const string& fname); +void write(const string& fname, const string& src); +bool file_exist(const string& fname); +string join(string a, string b); +bool cache_compile(string cmd, const string& cache_path="", const string& jittor_path=""); + +} // jit_compiler +} // jittor \ No newline at end of file diff --git a/python/jittor/src/utils/cross_platform.h b/python/jittor/src/utils/cross_platform.h new file mode 100644 index 00000000..c75ed8a5 --- /dev/null +++ b/python/jittor/src/utils/cross_platform.h @@ -0,0 +1,58 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#ifndef _WIN32 +#include +#ifdef __linux__ +#include +#endif +#include +#include +#include +#include +#else +#include +#include +#endif +#ifdef _MSC_VER +#include +#include +#define getpid _getpid +inline void sleep(int s) { Sleep(s*1000); } +#else +#include +#endif + + +#ifdef _MSC_VER + +// typedef struct timeval { +// long tv_sec; +// long tv_usec; +// } timeval; + +inline int gettimeofday(struct timeval * tp, struct timezone * tzp) +{ + // Note: some broken versions only have 8 trailing zero's, the correct epoch has 9 trailing zero's + // This magic number is the number of 100 nanosecond intervals since January 1, 1601 (UTC) + // until 00:00:00 January 1, 1970 + static const uint64_t EPOCH = ((uint64_t) 116444736000000000ULL); + + SYSTEMTIME system_time; + FILETIME file_time; + uint64_t time; + + GetSystemTime( &system_time ); + SystemTimeToFileTime( &system_time, &file_time ); + time = ((uint64_t)file_time.dwLowDateTime ) ; + time += ((uint64_t)file_time.dwHighDateTime) << 32; + + tp->tv_sec = (long) ((time - EPOCH) / 10000000L); + tp->tv_usec = (long) (system_time.wMilliseconds * 1000); + return 0; +} +#endif \ No newline at end of file diff --git a/python/jittor/src/utils/flags.cc b/python/jittor/src/utils/flags.cc new file mode 100644 index 00000000..cdc72069 --- /dev/null +++ b/python/jittor/src/utils/flags.cc @@ -0,0 +1,28 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +/* + +DEFINE_FLAG(string, jittor_path, "", "Source path of jittor"); +DEFINE_FLAG(string, cc_path, "", "Path of C++ compiler"); +DEFINE_FLAG(string, cc_type, "", "Type of C++ compiler(clang, icc, g++)"); +DEFINE_FLAG(string, cc_flags, "", "Flags of C++ compiler"); +DEFINE_FLAG(string, nvcc_path, "", "Path of CUDA C++ compiler"); +DEFINE_FLAG(string, nvcc_flags, "", "Flags of CUDA C++ compiler"); +DEFINE_FLAG(string, python_path, "", "Path of python interpreter"); +DEFINE_FLAG(string, cache_path, "", "Cache path of jittor"); +DEFINE_FLAG(int, rewrite_op, 1, "Rewrite source file of jit operator or not"); + +DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check."); + +DEFINE_FLAG(int, try_use_32bit_index, 0, + "If not overflow, try to use 32 bit type as index type."); + +DEFINE_FLAG(fast_shared_ptr, compile_options, {}, + "Override the default loop transfrom options"); + +*/ \ No newline at end of file diff --git a/python/jittor/src/utils/flags.h b/python/jittor/src/utils/flags.h new file mode 100644 index 00000000..2e6c5eba --- /dev/null +++ b/python/jittor/src/utils/flags.h @@ -0,0 +1,8 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "utils/log.h" \ No newline at end of file diff --git a/python/jittor/src/utils/jit_utils.cc b/python/jittor/src/utils/jit_utils.cc new file mode 100644 index 00000000..5ca64ddd --- /dev/null +++ b/python/jittor/src/utils/jit_utils.cc @@ -0,0 +1,599 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "utils/cache_compile.h" +#include "pyjt/py_converter.h" +#include "pyjt/py_arg_printer.h" +#ifdef __clang__ +// #pragma clang diagnostic ignored "-Wdefaulted-function-deleted" +#endif +#ifdef __GNUC__ +#endif +#ifdef __linux__ +#include +#endif +#include +#include +#include +#include +#include "utils/seh.h" + +namespace jittor { + +bool check_async_executor_error(const std::exception& e, std::ostream& os) { + if (!e.what()) return false; + auto s = string(e.what()); + if (s.find("executor.cc:") == string::npos) + return false; + os << s; + if (getenv("JT_SYNC") && getenv("trace_py_var")) + return true; + if (s.find("[Async Backtrace]: ---") != string::npos) + return true; + os << "\n**********\nAsync error was detected. " + "To locate the async backtrace and get better error report, please rerun your code with " + "two enviroment variables set:\n" + #ifdef _WIN32 + "cmd: \n" + ">>> set JT_SYNC=1\n" + ">>> set trace_py_var=3\n" + "powershell: \n" + ">>> $env:JT_SYNC=1\n" + ">>> $env:trace_py_var=3\n" + #else + ">>> export JT_SYNC=1\n" + ">>> export trace_py_var=3\n" + #endif + ; + return true; +} + +SEH_HOOK; + +void init_subprocess() { +#if defined(__linux__) && defined(PR_SET_PDEATHSIG) + prctl(PR_SET_PDEATHSIG, SIGKILL); +#endif +} + +static void __log( + const std::string& fileline, + char level, + int verbose, + const std::string& s) +{ + // return if verbose level not match + if (level=='i' && !( + jittor::log_vprefix.size() ? + jittor::check_vlog(fileline.c_str(), verbose) : + verbose <= jittor::log_v)) + return; + if (level != 'f') + jittor::LogVoidify() && + jittor::Log(fileline.c_str(), level, verbose) << s; + else + jittor::LogFatalVoidify() && + jittor::Log(fileline.c_str(), level, verbose) << s; +} + +// Buffer that writes to Python instead of C++ +class pythonbuf : public std::streambuf { +private: + using traits_type = std::streambuf::traits_type; + + const size_t buf_size; + std::unique_ptr d_buffer; + PyObject* _pywrite; + PyObject* _pyflush; + + int overflow(int c) override { + if (!traits_type::eq_int_type(c, traits_type::eof())) { + *pptr() = traits_type::to_char_type(c); + pbump(1); + } + return sync() == 0 ? traits_type::not_eof(c) : traits_type::eof(); + } + + // Computes how many bytes at the end of the buffer are part of an + // incomplete sequence of UTF-8 bytes. + // Precondition: pbase() < pptr() + size_t utf8_remainder() const { + const auto rbase = std::reverse_iterator(pbase()); + const auto rpptr = std::reverse_iterator(pptr()); + auto is_ascii = [](char c) { + return (static_cast(c) & 0x80) == 0x00; + }; + auto is_leading = [](char c) { + return (static_cast(c) & 0xC0) == 0xC0; + }; + auto is_leading_2b = [](char c) { + return static_cast(c) <= 0xDF; + }; + auto is_leading_3b = [](char c) { + return static_cast(c) <= 0xEF; + }; + // If the last character is ASCII, there are no incomplete code points + if (is_ascii(*rpptr)) + return 0; + // Otherwise, work back from the end of the buffer and find the first + // UTF-8 leading byte + const auto rpend = rbase - rpptr >= 3 ? rpptr + 3 : rbase; + const auto leading = std::find_if(rpptr, rpend, is_leading); + if (leading == rbase) + return 0; + const auto dist = static_cast(leading - rpptr); + size_t remainder = 0; + + if (dist == 0) + remainder = 1; // 1-byte code point is impossible + else if (dist == 1) + remainder = is_leading_2b(*leading) ? 0 : dist + 1; + else if (dist == 2) + remainder = is_leading_3b(*leading) ? 0 : dist + 1; + // else if (dist >= 3), at least 4 bytes before encountering an UTF-8 + // leading byte, either no remainder or invalid UTF-8. + // Invalid UTF-8 will cause an exception later when converting + // to a Python string, so that's not handled here. + return remainder; + } + + // This function must be non-virtual to be called in a destructor. If the + // rare MSVC test failure shows up with this version, then this should be + // simplified to a fully qualified call. + int _sync() { + if (pbase() != pptr()) { // If buffer is not empty + if (pbase() != pptr()) { // Check again under the lock + // This subtraction cannot be negative, so dropping the sign. + auto size = static_cast(pptr() - pbase()); + size_t remainder = utf8_remainder(); + + if (size > remainder) { + string line(pbase(), size - remainder); + pywrite(line); + pyflush(); + } + + // Copy the remainder at the end of the buffer to the beginning: + if (remainder > 0) + std::memmove(pbase(), pptr() - remainder, remainder); + setp(pbase(), epptr()); + pbump(static_cast(remainder)); + } + } + return 0; + } + + int sync() override { + return _sync(); + } + + void pywrite(const string& s) { + PyObjHolder pys(to_py_object(s)); + PyObjHolder args(PyTuple_New(1)); + PyTuple_SET_ITEM(args.obj, 0, pys.release()); + PyObjHolder ret(PyObject_Call(_pywrite, args.obj, nullptr)); + } + + void pyflush() { + PyObjHolder args(PyTuple_New(0)); + PyObjHolder ret(PyObject_Call(_pyflush, args.obj, nullptr)); + } + +public: + + pythonbuf(PyObject* pyostream, size_t buffer_size = 1024) + : buf_size(buffer_size), + d_buffer(new char[buf_size]) { + + PyObjHolder pywrite(PyObject_GetAttrString(pyostream, "write")); + _pywrite = pywrite.release(); + PyObjHolder pyflush(PyObject_GetAttrString(pyostream, "flush")); + _pyflush = pyflush.release(); + setp(d_buffer.get(), d_buffer.get() + buf_size - 1); + + } + + pythonbuf(pythonbuf&&) = default; + + /// Sync before destroy + ~pythonbuf() override { + _sync(); + } +}; + +static void ostream_redirect(bool _stdout, bool _stderr) { + if (_stdout) { + PyObjHolder a(PyImport_ImportModule("sys")); + PyObjHolder b(PyObject_GetAttrString(a.obj,"stdout")); + auto buf = new pythonbuf(b.obj); + std::cout.rdbuf(buf); + } + if (_stderr) { + PyObjHolder a(PyImport_ImportModule("sys")); + PyObjHolder b(PyObject_GetAttrString(a.obj,"stderr")); + auto buf = new pythonbuf(b.obj); + std::cerr.rdbuf(buf); + } +} + +static void pyjt_def_core(PyObject* m) { + static PyMethodDef defs[] = { + { R""(cache_compile)"", + + (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { + try { + ; + uint64 arg_filled=0; + (void)arg_filled; + + if (n+(kw?Py_SIZE(kw):0)<=3 && n+(kw?Py_SIZE(kw):0)>=1 && is_type(args[0])) { + + ; + string arg0 = from_py_object(args[0]); + + ; + string arg1; + if (n>1) { + CHECK((is_type(args[1]))); + arg1 = from_py_object(args[1]); + arg_filled |= 1ull << 1; + } + + ; + string arg2; + if (n>2) { + CHECK((is_type(args[2]))); + arg2 = from_py_object(args[2]); + arg_filled |= 1ull << 2; + } + + CHECK(!PyErr_Occurred()); + ; + + if (kw) { + auto kw_n = Py_SIZE(kw); + for (int i=0; i(vo))); + arg0 = from_py_object(vo); + arg_filled |= 1ull << 0; + continue; + } + + if (khash == 370544278u) { + // hash match cache_path + CHECK((is_type(vo))); + arg1 = from_py_object(vo); + arg_filled |= 1ull << 1; + continue; + } + + if (khash == 1219769050u) { + // hash match jittor_path + CHECK((is_type(vo))); + arg2 = from_py_object(vo); + arg_filled |= 1ull << 2; + continue; + } + + LOGf << "Not a valid keyword:" << ks; + } + } + + if (!(arg_filled & (1ull<<1))) { + arg1 = ""; + } + + if (!(arg_filled & (1ull<<2))) { + arg2 = ""; + } + ; + return to_py_object((jit_compiler::cache_compile(arg0,arg1,arg2))); + } + + LOGf << "Not a valid call."; + } catch (const std::exception& e) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, e.what()); + } + } + return nullptr; + } + , + METH_FASTCALL | METH_KEYWORDS, + R""(Declaration: +bool cache_compile(const string& cmd, const string& cache_path="", const string& jittor_path="") + +)"" + }, + { R""(log)"", + + (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { + try { + ; + uint64 arg_filled=0; + (void)arg_filled; + + if (n+(kw?Py_SIZE(kw):0)<=4 && n+(kw?Py_SIZE(kw):0)>=4 && is_type(args[0]) && PyUnicode_CheckExact(args[1]) && PyLong_CheckExact(args[2]) && is_type(args[3])) { + + ; + std::string arg0 = from_py_object(args[0]); + + ; + const char* arg1 = PyUnicode_AsUTF8(args[1]); + + ; + int arg2 = PyLong_AsLong(args[2]); + + ; + std::string arg3 = from_py_object(args[3]); + + CHECK(!PyErr_Occurred()); + ; + + if (kw) { + auto kw_n = Py_SIZE(kw); + for (int i=0; i(vo))); + arg0 = from_py_object(vo); + arg_filled |= 1ull << 0; + continue; + } + + if (khash == 1005433988u) { + // hash match level + CHECK((PyUnicode_CheckExact(vo))); + arg1 = PyUnicode_AsUTF8(vo); + arg_filled |= 1ull << 1; + continue; + } + + if (khash == 2796496354u) { + // hash match verbose + CHECK((PyLong_CheckExact(vo))); + arg2 = PyLong_AsLong(vo); + arg_filled |= 1ull << 2; + continue; + } + + if (khash == 115u) { + // hash match s + CHECK((is_type(vo))); + arg3 = from_py_object(vo); + arg_filled |= 1ull << 3; + continue; + } + + LOGf << "Not a valid keyword:" << ks; + } + } + ; + return GET_PY_NONE((__log(arg0,arg1[0],arg2,arg3))); + } + + LOGf << "Not a valid call."; + } catch (const std::exception& e) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, e.what()); + } + } + return nullptr; + } + , + METH_FASTCALL | METH_KEYWORDS, + R""(Declaration: +void log(const std::string& fileline, const char* level, int verbose, const std::string& s) + +)"" + }, + { R""(init_subprocess)"", + + (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { + try { + ; + uint64 arg_filled=0; + (void)arg_filled; + + if (n<=0 && n>=0) { + ; + ; + return GET_PY_NONE((init_subprocess())); + } + + LOGf << "Not a valid call."; + } catch (const std::exception& e) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, e.what()); + } + } + return nullptr; + } + , + METH_FASTCALL | METH_KEYWORDS, + R""(Declaration: +void init_subprocess() + +)"" + }, + { R""(log_capture_start)"", + + (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { + try { + ; + uint64 arg_filled=0; + (void)arg_filled; + + if (n<=0 && n>=0) { + ; + ; + return GET_PY_NONE((log_capture_start())); + } + + LOGf << "Not a valid call."; + } catch (const std::exception& e) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, e.what()); + } + } + return nullptr; + } + , + METH_FASTCALL | METH_KEYWORDS, + R""(Declaration: +void log_capture_start() + +)"" + }, + { R""(log_capture_stop)"", + + (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { + try { + ; + uint64 arg_filled=0; + (void)arg_filled; + + if (n<=0 && n>=0) { + ; + ; + return GET_PY_NONE((log_capture_stop())); + } + + LOGf << "Not a valid call."; + } catch (const std::exception& e) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, e.what()); + } + } + return nullptr; + } + , + METH_FASTCALL | METH_KEYWORDS, + R""(Declaration: +void log_capture_stop() + +)"" + }, + { R""(log_capture_read)"", + + (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { + try { + ; + uint64 arg_filled=0; + (void)arg_filled; + + if (n<=0 && n>=0) { + ; + ; + // return GET_PY_NONE((log_capture_stop())); + auto ret = log_capture_read(); + return to_py_object(move(ret)); + } + + LOGf << "Not a valid call."; + } catch (const std::exception& e) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, e.what()); + } + } + return nullptr; + } + , + METH_FASTCALL | METH_KEYWORDS, + R""(Declaration: +void log_capture_read() + +)"" + }, + { R""(ostream_redirect)"", + + (PyCFunction)(PyObject* (*)(PyObject*,PyObject**,int64,PyObject*))[](PyObject* self, PyObject** args, int64 n, PyObject* kw) -> PyObject* { + try { + ; + uint64 arg_filled=0; + (void)arg_filled; + + if (n+(kw?Py_SIZE(kw):0)<=2 && n+(kw?Py_SIZE(kw):0)>=2 && is_type(args[0]) && is_type(args[1])) { + + ; + bool arg0 = from_py_object(args[0]); + + ; + bool arg1 = from_py_object(args[1]); + + CHECK(!PyErr_Occurred()); + ; + + if (kw) { + auto kw_n = Py_SIZE(kw); + for (int i=0; i(vo))); + arg0 = from_py_object(vo); + arg_filled |= 1ull << 0; + continue; + } + + if (khash == 2600128022u) { + // hash match stderr + CHECK((is_type(vo))); + arg1 = from_py_object(vo); + arg_filled |= 1ull << 1; + continue; + } + + LOGf << "Not a valid keyword:" << ks; + } + } + ; + return GET_PY_NONE((ostream_redirect(arg0,arg1))); + } + + LOGf << "Not a valid call."; + } catch (const std::exception& e) { + if (!PyErr_Occurred()) { + PyErr_Format(PyExc_RuntimeError, e.what()); + } + } + return nullptr; + } + , + METH_FASTCALL | METH_KEYWORDS, + R""(Declaration: +void ostream_redirect(bool stdout, bool stderr) + +)"" + },{0,0,0,0} + }; + ASSERT(PyModule_AddFunctions(m, defs)==0); +} + +} + + +static void init_module(PyModuleDef* mdef, PyObject* m) { + mdef->m_doc = "Inner c++ core of jittor_utils"; + jittor::pyjt_def_core(m); +} +PYJT_MODULE_INIT(jit_utils_core); diff --git a/python/jittor/src/utils/log.cc b/python/jittor/src/utils/log.cc new file mode 100644 index 00000000..2a800c19 --- /dev/null +++ b/python/jittor/src/utils/log.cc @@ -0,0 +1,682 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include +#include +#include +#include "utils/cross_platform.h" +#include "utils/log.h" +#include "utils/mwsr_list.h" +#include "utils/str_utils.h" + +namespace jittor { + +bool peek_logged = 0; +typedef uint32_t uint; +using string = std::string; +using stringstream = std::stringstream; +using std::move; +template using unordered_map = std::unordered_map; + +template<> string get_from_env(const char* name, const string& _default) { + auto s = getenv(name); + if (s == NULL) return _default; + return string(s); +} + +uint32_t get_tid() { + stringstream ss; + ss << std::this_thread::get_id(); + uint32_t id = static_cast(std::stoull(ss.str())); + return id; +} + +static bool supports_color() { + #ifdef _WIN32 + HANDLE hOut = GetStdHandle(STD_OUTPUT_HANDLE); + if (hOut == INVALID_HANDLE_VALUE) return 0; + + DWORD dwMode = 0; + if (!GetConsoleMode(hOut, &dwMode)) return 0; + + dwMode |= ENABLE_VIRTUAL_TERMINAL_PROCESSING; + if (!SetConsoleMode(hOut, dwMode)) return 0; + return 1; + + #endif + bool term_supports_color = false; + const char* const term = getenv("TERM"); + if (term != NULL && term[0] != '\0') { + term_supports_color = + !strcmp(term, "xterm") || + !strcmp(term, "xterm-color") || + !strcmp(term, "xterm-256color") || + !strcmp(term, "screen-256color") || + !strcmp(term, "konsole") || + !strcmp(term, "konsole-16color") || + !strcmp(term, "konsole-256color") || + !strcmp(term, "screen") || + !strcmp(term, "linux") || + !strcmp(term, "cygwin"); + } + return term_supports_color; +} +bool g_supports_color = supports_color(); +string thread_local thread_name; + +struct timeval start_tv; + +struct tm get_start_tm() { + gettimeofday (&start_tv, NULL); + time_t t = start_tv.tv_sec; + return *localtime(&t); +} + +struct tm start_tm = get_start_tm(); + +void print_prefix(std::ostream* out) { + struct timeval tv; + gettimeofday (&tv, NULL); + struct tm lt = start_tm; + auto dt = tv.tv_sec - start_tv.tv_sec; + lt.tm_sec += dt; + lt.tm_min += lt.tm_sec / 60; lt.tm_sec %= 60; + lt.tm_hour += lt.tm_min / 60; lt.tm_min %= 60; + // localtime is slow, cache time call + if (lt.tm_hour >= 24) { + start_tm = get_start_tm(); + tv = start_tv; + lt = start_tm; + } + + auto usecs = tv.tv_usec; + + thread_local uint32_t tid = get_tid()%100; + + #define PRINT_W2(x) \ + char('0'+(x)/10%10) << char('0'+(x)%10) + #define PRINT_W6(x) \ + PRINT_W2((x)/10000) << PRINT_W2((x)/100) << PRINT_W2(x) + + *out << PRINT_W2(1+lt.tm_mon) + << PRINT_W2(lt.tm_mday) << ' ' + << PRINT_W2(lt.tm_hour) << ':' + << PRINT_W2(lt.tm_min) << ':' + << PRINT_W2(lt.tm_sec) << "." + << PRINT_W6(usecs) << ' ' + << PRINT_W2(tid); + if (thread_name.size()) + *out << ":" << thread_name; + *out << ' '; +} + +#ifdef LOG_ASYNC +MWSR_LIST(log, std::ostringstream); +#endif +DECLARE_FLAG(int, log_sync); + +std::mutex sync_log_m; +std::mutex sync_log_capture; +std::vector> logs; +int log_capture_enabled = 0; + +void log_capture(const string& s) { + // find [ and ] + uint i=0; + while (i+2 spaces; + spaces.reserve(5); + for (uint k=i; k log; + log["level"] = s.substr(i+1, spaces[0]-i-1); + log["verbose"] = spaces.size()==4 ? "0" : s.substr(spaces[3]+2, spaces[4]-spaces[3]-2); + // split asdad.cc:asd + uint l = spaces.back(); + while (l lg(sync_log_capture); + logs.emplace_back(std::move(log)); + } +} + +DECLARE_FLAG(int, log_silent); + +void send_log(std::ostringstream&& out, char level, int verbose) { + if (log_capture_enabled) + log_capture(out.str()); + if ((level=='i' || level=='w') && log_silent) return; + if (!log_sync) { + #if LOG_ASYNC + mwsr_list_log::push(move(out)); + #endif + } else { + std::lock_guard lk(sync_log_m); + // std::cerr << "[SYNC]"; + std::cerr << _to_winstr(out.str()); + std::cerr.flush(); + } +} + +void flush_log() { + if (!log_sync) { + #if LOG_ASYNC + mwsr_list_log::flush(); + #endif + } else { + std::cerr.flush(); + } +} + +void log_capture_start() { log_capture_enabled=1; } +void log_capture_stop() { log_capture_enabled=0; } +std::vector> log_capture_read() { + return move(logs); +} + +void log_exiting(); + +bool exited = false; +size_t thread_local protected_page = 0; +int segfault_happen = 0; +static int _pid = getpid(); +vector cleanup_callback; +vector sigquit_callback; +int64 last_q_time; + +string& get_thread_name() { + return thread_name; +} + +#ifdef _WIN32 +void handle_signal(int signal) { + std::cerr << "Caught SIGNAL " << signal << ", quick exit"; + std::cerr.flush(); + abort(); +} +#else +static inline void do_exit() { + #ifdef __APPLE__ + _Exit(1); + #else + std::quick_exit(1); + #endif +} + +void segfault_sigaction(int signal, siginfo_t *si, void *arg) { + if (signal == SIGQUIT) { + if (_pid == getpid()) { + std::cerr << "Caught SIGQUIT" << std::endl; + int64 now = clock(); + if (now > last_q_time && last_q_time+CLOCKS_PER_SEC/10 > now) { + last_q_time = now; + std::cerr << "GDB attach..." << std::endl; + breakpoint(); + } else { + last_q_time = now; + for (auto f : sigquit_callback) + f(); + } + } + return; + } + if (signal == SIGCHLD) { + if (si->si_code != CLD_EXITED && si->si_status != SIGTERM && _pid == getpid()) { + LOGe << "Caught SIGCHLD. Maybe out of memory, please reduce your worker size." + << "si_errno:" << si->si_errno + << "si_code:" << si->si_code + << "si_status:" << si->si_status + << ", quick exit"; + exited = true; + do_exit(); + } + return; + } + if (signal == SIGINT) { + if (_pid == getpid()) { + LOGe << "Caught SIGINT, quick exit"; + } + exited = true; + do_exit(); + } + if (exited) do_exit(); + std::cerr << "Caught segfault at address " << si->si_addr << ", " + << "thread_name: '" << thread_name << "', flush log..." << std::endl; + std::cerr.flush(); + if (protected_page && + si->si_addr>=(void*)protected_page && + si->si_addr<(void*)(protected_page+4*1024)) { + LOGf << "Accessing protect pages, maybe jit_key too long"; + } + if (!exited) { + exited = true; + if (signal == SIGSEGV) { + // only print trace in main thread + if (thread_name.size() == 0) + print_trace(); + std::cerr << "Segfault, exit" << std::endl; + } else { + std::cerr << "Get signal " << signal << ", exit" << std::endl; + } + } + segfault_happen = 1; + exit(1); +} +#endif + +int register_sigaction() { +#ifdef _WIN32 + signal(SIGINT, handle_signal); + signal(SIGTERM, handle_signal); + // signal(SIGABRT, handle_signal); + signal(SIGSEGV, handle_signal); + signal(SIGFPE, handle_signal); +#else + struct sigaction sa; + + memset(&sa, 0, sizeof(struct sigaction)); + sigemptyset(&sa.sa_mask); + sa.sa_sigaction = segfault_sigaction; + sa.sa_flags = SA_SIGINFO; + + sigaction(SIGSEGV, &sa, NULL); + sigaction(SIGKILL, &sa, NULL); + sigaction(SIGSTOP, &sa, NULL); + sigaction(SIGFPE, &sa, NULL); + // jupyter use sigint to interp + if (getenv("JPY_PARENT_PID") == nullptr) + sigaction(SIGINT, &sa, NULL); + sigaction(SIGCHLD, &sa, NULL); + sigaction(SIGILL, &sa, NULL); + sigaction(SIGBUS, &sa, NULL); + sigaction(SIGQUIT, &sa, NULL); + // sigaction(SIGABRT, &sa, NULL); +#endif + return 0; +} + +static int log_init() { + #ifdef _WIN32 + // SetConsoleCP(CP_UTF8); + // SetConsoleOutputCP(CP_UTF8); + #endif + register_sigaction(); + std::atexit(log_exiting); + return 1; +} + +int _log_init = log_init(); + +void log_main() { + #ifdef LOG_ASYNC + mwsr_list_log::reduce([&](const std::ostringstream& out) { + #ifdef TEST_LOG + string s = out.str(); + if (s[8] == 'm') std::cerr << s; + #else + std::cerr << out.str(); + #endif + }, [&]() { + std::cerr.flush(); + }); + #endif +} + +unordered_map vprefix_map; +void stream_hash(uint64_t& hash, char c) { + hash = hash * 257 + (uint8_t)c; +} + +DEFINE_FLAG(int, log_sync, 1, "Set log printed synchronously."); +DEFINE_FLAG(int, log_silent, 0, "The log will be completely silent."); +DEFINE_FLAG(int, log_v, 0, "Verbose level of logging"); +DEFINE_FLAG_WITH_SETTER(string, log_vprefix, "", + "Verbose level of logging prefix\n" + "example: log_vprefix='op=1,node=2,executor.cc:38$=1000'"); +void setter_log_vprefix(string value) { + unordered_map new_map; + auto& s = value; + for (uint i=0; i" << vnum; + uint64_t phash=0; + for (char c : prefix) stream_hash(phash, c); + new_map[phash] = vnum; + i = k; + } + vprefix_map = move(new_map); +} +DEFINE_FLAG_WITH_SETTER(string, log_file, "", + "log to file, mpi env will add $OMPI_COMM_WORLD_RANK suffix\n"); +void setter_log_file(string value) { + if (value.size() == 0) + return; + auto c = getenv("OMPI_COMM_WORLD_RANK"); + if (c) value += string("_") + c; + static std::ofstream out; + out = std::ofstream(value); + std::cerr.rdbuf(out.rdbuf()); +} + +bool check_vlog(const char* fileline, int verbose) { + uint64_t phash=0; + for (int i=0;; i++) { + char c = fileline[i]; + if (!c) c = '$'; + stream_hash(phash, c); + auto iter = vprefix_map.find(phash); + if (iter != vprefix_map.end()) + return verbose <= iter->second; + if (c=='$') break; + } + return verbose <= log_v; +} + +static inline void check_cuda_unsupport_version(const string& output) { + // check error like: + // /usr/include/crt/host_config.h:121:2: error: #error -- unsupported GNU version! gcc versions later than 6 are not supported! + // #error -- unsupported GNU version! gcc versions later than 6 are not supported! + string pat = "crt/host_config.h"; + auto id = output.find(pat); + if (id == string::npos) return; + auto end = id + pat.size(); + while (id>=0 && !(output[id]==' ' || output[id]=='\t' || output[id]=='\n')) + id--; + id ++; + auto fname = output.substr(id, end-id); + LOGw << R"( +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Dear user, your nvcc and gcc version are not match, +but you can hot fix it by this command: +>>> sudo python3 -c 's=open(")" >> fname >> R"(","r").read().replace("#error", "//#error");open(")" >> fname >> R"(","w").write(s)' + )"; +} + +static inline void check_cuda_gcc_version(const string& output) { + /* if such error occur: + error: identifier "__is_assignable" is undefined + this means your gcc version is not match with nvcc, + for example, nvcc 10 support gcc<=7, nvcc 11 support gcc<=9, + + https://gist.github.com/ax3l/9489132 + */ + string pat = "__is_assignable"; + auto id = output.find(pat); + if (id == string::npos) return; + LOGf << output << R"( +!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Dear user, your nvcc and gcc version are still not match +after dirty hack, your should install the correct version of g++ +or nvcc, for example, nvcc 10 support g++<=7, nvcc 11 support g++<=9, +here is the NVCC Compatibility Matrix: + https://gist.github.com/ax3l/9489132 +Please install correct version of gcc, for example: + >>> sudo apt install g++-7 +After your g++ is installed, using enviroment variable `cc_path` to +tell jittor use the correct version of g++, for example: + >>> cc_path='g++-7' python3.7 -m jittor.test.test_core +If you still have problems, please contact us: + https://github.com/Jittor/jittor/issues + )"; +} + +#ifdef _WIN32 + +string GbkToUtf8(const char *src_str) +{ + int len = MultiByteToWideChar(CP_ACP, 0, src_str, -1, NULL, 0); + wchar_t* wstr = new wchar_t[len + 1]; + memset(wstr, 0, len + 1); + MultiByteToWideChar(CP_ACP, 0, src_str, -1, wstr, len); + len = WideCharToMultiByte(CP_UTF8, 0, wstr, -1, NULL, 0, NULL, NULL); + char* str = new char[len + 1]; + memset(str, 0, len + 1); + WideCharToMultiByte(CP_UTF8, 0, wstr, -1, str, len, NULL, NULL); + string strTemp = str; + if (wstr) delete[] wstr; + if (str) delete[] str; + return strTemp; +} + +string Utf8ToGbk(const char *src_str) +{ + int len = MultiByteToWideChar(CP_UTF8, 0, src_str, -1, NULL, 0); + wchar_t* wszGBK = new wchar_t[len + 1]; + memset(wszGBK, 0, len * 2 + 2); + MultiByteToWideChar(CP_UTF8, 0, src_str, -1, wszGBK, len); + len = WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, NULL, 0, NULL, NULL); + char* szGBK = new char[len + 1]; + memset(szGBK, 0, len + 1); + WideCharToMultiByte(CP_ACP, 0, wszGBK, -1, szGBK, len, NULL, NULL); + string strTemp(szGBK); + if (wszGBK) delete[] wszGBK; + if (szGBK) delete[] szGBK; + return strTemp; +} + +int system_popen(const char *cmd, const char* cwd) { + HANDLE g_hChildStd_OUT_Rd = NULL; + HANDLE g_hChildStd_OUT_Wr = NULL; + SECURITY_ATTRIBUTES saAttr; + // Set the bInheritHandle flag so pipe handles are inherited. + + saAttr.nLength = sizeof(SECURITY_ATTRIBUTES); + saAttr.bInheritHandle = TRUE; + saAttr.lpSecurityDescriptor = NULL; + + // Create a pipe for the child process's STDOUT. + if (!CreatePipe(&g_hChildStd_OUT_Rd, &g_hChildStd_OUT_Wr, &saAttr, 0)) + LOGf << "StdoutRd CreatePipe error"; + // Ensure the read handle to the pipe for STDOUT is not inherited. + if (!SetHandleInformation(g_hChildStd_OUT_Rd, HANDLE_FLAG_INHERIT, 0)) + LOGf << "Stdout SetHandleInformation error"; + + // Create the child process. + PROCESS_INFORMATION piProcInfo; + STARTUPINFO siStartInfo; + BOOL bSuccess = FALSE; + // Set up members of the PROCESS_INFORMATION structure. + ZeroMemory(&piProcInfo, sizeof(PROCESS_INFORMATION)); + + // Set up members of the STARTUPINFO structure. + // This structure specifies the STDIN and STDOUT handles for redirection. + ZeroMemory(&siStartInfo, sizeof(STARTUPINFO)); + siStartInfo.cb = sizeof(STARTUPINFO); + siStartInfo.hStdError = g_hChildStd_OUT_Wr; + siStartInfo.hStdOutput = g_hChildStd_OUT_Wr; + siStartInfo.dwFlags |= STARTF_USESTDHANDLES; + + // Create the child process. + bSuccess = CreateProcess(NULL, + (char *)cmd, // command line + NULL, // process security attributes + NULL, // primary thread security attributes + TRUE, // handles are inherited + 0, // creation flags + NULL, // use parent's environment + cwd, // use cwd directory + &siStartInfo, // STARTUPINFO pointer + &piProcInfo); // receives PROCESS_INFORMATION + + // If an error occurs, exit the application. + if (!bSuccess) + LOGf << "CreateProcess error"; + // Close handles to the stdin and stdout pipes no longer needed by the child process. + // If they are not explicitly closed, there is no way to recognize that the child process has ended. + CloseHandle(g_hChildStd_OUT_Wr); + + DWORD dwRead, dwWritten; + CHAR chBuf[BUFSIZ]; + HANDLE hParentStdOut = GetStdHandle(STD_OUTPUT_HANDLE); + + + string output; + for (;;) + { + bSuccess = ReadFile(g_hChildStd_OUT_Rd, chBuf, BUFSIZ, &dwRead, NULL); + if (!bSuccess || dwRead == 0) + break; + output += string(chBuf, dwRead); + + if (log_v) + bSuccess = WriteFile(hParentStdOut, chBuf, + dwRead, &dwWritten, NULL); + if (!bSuccess) + break; + } + WaitForSingleObject(piProcInfo.hProcess, INFINITE); + DWORD ec; + GetExitCodeProcess(piProcInfo.hProcess, &ec); + // Close handles to the child process and its primary thread. + // Some applications might keep these handles to monitor the status + // of the child process, for example. + CloseHandle(piProcInfo.hProcess); + CloseHandle(piProcInfo.hThread); + if (ec && !log_v) + LOGe << output; + + if (ec) { + check_cuda_unsupport_version(output); + check_cuda_gcc_version(output); + } + return ec; +} +#else +int system_popen(const char* cmd, const char* cwd) { + char buf[BUFSIZ]; + string cmd2; + cmd2 = cmd; + cmd2 += " 2>&1 "; + FILE *ptr = popen(cmd2.c_str(), "r"); + if (!ptr) return -1; + string output; + while (fgets(buf, BUFSIZ, ptr) != NULL) { + output += buf; + if (log_v) + std::cerr << buf; + } + if (output.size()) std::cerr.flush(); + auto ret = pclose(ptr); + if (ret && !log_v) + std::cerr << output; + if (output.size()<10 && ret) { + // maybe overcommit + return -1; + } + if (ret) { + check_cuda_unsupport_version(output); + check_cuda_gcc_version(output); + } + return ret; +} +#endif + +void system_with_check(const char* cmd, const char* cwd) { + auto ret = system_popen(cmd, cwd); + CHECK(ret>=0 && ret<=256) << "Run cmd failed:" << cmd << + "\nreturn ">> ret >> ". This might be an overcommit issue or out of memory." + << "Try : sudo sysctl vm.overcommit_memory=1, or set enviroment variable `export DISABLE_MULTIPROCESSING=1`"; + CHECKop(ret,==,0) << "Run cmd failed:" << cmd; +} + +#ifdef LOG_ASYNC +std::thread log_thread(log_main); +#endif + +int log_exit = 0; + +void log_exiting() { + if (log_exit) return; + log_exit = true; + for (auto cb : cleanup_callback) + cb(); + cleanup_callback.clear(); +#ifdef LOG_ASYNC + mwsr_list_log::stop(); + log_thread.join(); +#endif +} + +} // jittor + + +void expect_error(std::function func) { + try { + func(); + } catch (...) { + return; + } + LOGf << "Missing error"; +} + +#ifdef TEST_LOG + +#include +#include +#include +#include "test.h" + + +DEFINE_FLAG (int, nthread, 4, "Number of thread"); + +void test_log_time(std::ostream* out) { + int n = 100000; + auto log_lot = [&]() { + auto start = std::chrono::high_resolution_clock::now(); + for (int i=0; i(finish-start).count(); + LOGi << "total_ns" << total_ns << "each_ns" << total_ns/n; + CHECKop(total_ns/n,<=,6500); + }; + std::list ts; + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include +#include +#include "types.h" + +namespace jittor { + +// define in tracer.cc +void print_trace(); +void breakpoint(); +#ifdef _WIN32 +string GbkToUtf8(const char *src_str); +string Utf8ToGbk(const char *src_str); +#define _to_winstr(x) Utf8ToGbk(x.c_str()) +#define _from_winstr(x) GbkToUtf8(x.c_str()) +#else +#define _to_winstr(x) (x) +#define _from_winstr(x) (x) +#endif + +constexpr int32_t basename_index(const char * const path, const int32_t index = 0, const int32_t slash_index = -1) { + return path[index] + ? ((path[index] == '/' || path[index] == '\\') + ? basename_index (path, index + 1, index) + : basename_index (path, index + 1, slash_index) + ) + : (slash_index + 1); +} + +#define STRINGIZE_DETAIL(x) #x +#define STRINGIZE(x) STRINGIZE_DETAIL(x) + +#define __FILELINE__ \ + (&((__FILE__ ":" STRINGIZE(__LINE__))[jittor::basename_index(__FILE__)])) + +#ifndef _WIN32 +#define PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0)) +#else +#define PREDICT_BRANCH_NOT_TAKEN(x) (x) +#endif + + +#ifdef _MSC_VER +#define STACK_ALLOC(T, a, n) T* a = (T*)_alloca(sizeof(T)*(n)) +#define EXTERN_LIB extern __declspec(dllimport) +#define EXPORT_LIB __declspec(dllimport) +#else +#define STACK_ALLOC(T, a, n) T a[n] +#define EXTERN_LIB extern +#define EXPORT_LIB +#endif + +EXTERN_LIB uint32_t get_tid(); +EXTERN_LIB bool g_supports_color; +EXTERN_LIB void print_prefix(std::ostream* out); + +#ifdef _WIN32 +constexpr char green[] = "\x1b[1;32m"; +constexpr char red[] = "\x1b[1;31m"; +constexpr char yellow[] = "\x1b[1;33m"; + + +inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) { + if (level == 'i' || level == 'I') { + if (verbose == 0) color_begin = "\x1b[1;32m"; else + if (verbose < 10) color_begin = "\x1b[1;32m"; else + if (verbose < 100) color_begin = "\x1b[1;32m"; else + if (verbose < 1000) color_begin = "\x1b[1;32m"; + else color_begin = "\x1b[1;32m"; + } else if (level == 'w') + color_begin = yellow; + else if (level == 'e') + color_begin = red; + else // level == 'f' + color_begin = red; + color_end = "\x1b[m"; +} + +#else +constexpr char green[] = "\033[38;5;2m"; +constexpr char red[] = "\033[38;5;1m"; +constexpr char yellow[] = "\033[38;5;3m"; + +inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) { + if (level == 'i' || level == 'I') { + if (verbose == 0) color_begin = "\033[38;5;2m"; else + if (verbose < 10) color_begin = "\033[38;5;250m"; else + if (verbose < 100) color_begin = "\033[38;5;244m"; else + if (verbose < 1000) color_begin = "\033[38;5;238m"; + else color_begin = "\033[38;5;232m"; + } else if (level == 'w') + color_begin = yellow; + else if (level == 'e') + color_begin = red; + else // level == 'f' + color_begin = red; + color_end = "\033[m"; +} + +#endif + +EXTERN_LIB void send_log(std::ostringstream&& out, char level, int verbose); +EXTERN_LIB void flush_log(); +EXTERN_LIB void log_capture_start(); +EXTERN_LIB void log_capture_stop(); +EXTERN_LIB std::vector> log_capture_read(); +EXTERN_LIB string& get_thread_name(); + +struct Log { + std::ostringstream out; + const char* color_end; + int verbose; + char level; + + inline Log(const char* const fileline, char level, int verbose) { + this->verbose = verbose; + this->level = level; + const char* color_begin; + get_color(level, verbose, color_begin, color_end); + if (g_supports_color) out << color_begin; + out << '[' << level << ' '; + print_prefix(&out); + if (verbose) out << 'v' << verbose << ' '; + out << fileline << ']'; + } + + inline void end() { + if (g_supports_color) out << color_end; + out << '\n'; + send_log(move(out), level, verbose); + } + inline void flush() { flush_log(); } + + template + Log& operator<<(const T& a) { out << ' ' << a; return *this; } + template + Log& operator>>(const T& a) { out << a; return *this; } +}; + +struct LogVoidify { + inline void operator&&(Log& log) { log.end(); } +}; + +struct LogFatalVoidify { + inline void operator&&(Log& log) { + log.flush(); + if (g_supports_color) log.out << log.color_end; + throw std::runtime_error(log.out.str()); + } +}; + +#define _LOGi(v) jittor::LogVoidify() && jittor::Log(__FILELINE__, 'i', v) +#define _LOGw(v) jittor::LogVoidify() && jittor::Log(__FILELINE__, 'w', v) +#define _LOGe(v) jittor::LogVoidify() && jittor::Log(__FILELINE__, 'e', v) +#define _LOGf(v) jittor::LogFatalVoidify() && jittor::Log(__FILELINE__, 'f', v) +#define LOGi _LOGi(0) +#define LOGw _LOGw(0) +#define LOGe _LOGe(0) +#define LOGf _LOGf(0) + +#define _LOG(level, v) _LOG ## level(v) +#define LOG(level) _LOG(level, 0) + +#define CHECK(cond) \ + LOG_IF(f, PREDICT_BRANCH_NOT_TAKEN(!(cond))) \ + << "Check failed: " #cond " " + +#define _LOG_IF(level, cond, v) \ + !(cond) ? (void) 0 : _LOG(level, v) +#define LOG_IF(level, cond) _LOG_IF(level, cond, 0) + +template T get_from_env(const char* name,const T& _default) { + auto ss = getenv(name); + if (ss == NULL) return _default; + string s = ss; + std::istringstream is(s); + T env; + if (is >> env) { + is.peek(); + if (!is) { + return env; + } + } + if (s.size() && is.eof()) + return env; + LOGw << "Load" << name << "from env(" << s << ") failed, use default" << _default; + return _default; +} + +template<> std::string get_from_env(const char* name, const std::string& _default); + +#define DECLARE_FLAG(type, name) \ +EXTERN_LIB type name; \ +EXTERN_LIB std::string doc_ ## name; \ +EXTERN_LIB void set_ ## name (const type&); + + +#ifdef JIT + +#define DEFINE_FLAG(type, name, default, doc) \ + DECLARE_FLAG(type, name) +#define DEFINE_FLAG_WITH_SETTER(type, name, default, doc, setter) \ + DECLARE_FLAG(type, name) + +#else + +#define DEFINE_FLAG(type, name, default, doc) \ + DECLARE_FLAG(type, name) \ + type name; \ + std::string doc_ ## name = doc; \ + void set_ ## name (const type& value) { \ + name = value; \ + }; \ + void init_ ## name (const type& value) { \ + name = value; \ + if (getenv(#name)) LOGi << "Load " #name":" << value; \ + }; \ + int caller_ ## name = (init_ ## name (jittor::get_from_env(#name, default)), 0); + +#define DEFINE_FLAG_WITH_SETTER(type, name, default, doc) \ + DECLARE_FLAG(type, name) \ + type name; \ + std::string doc_ ## name = doc; \ + void setter_ ## name (type value); \ + void set_ ## name (const type& value) { \ + setter_ ## name (value); \ + name = value; \ + }; \ + void init_ ## name (const type& value) { \ + setter_ ## name (value); \ + name = value; \ + if (getenv(#name)) LOGi << "Load " #name":" << value; \ + }; \ + int caller_ ## name = (init_ ## name (jittor::get_from_env(#name, default)), 0); + +#endif + +DECLARE_FLAG(int, log_v); +DECLARE_FLAG(std::string, log_vprefix); +bool check_vlog(const char* fileline, int verbose); + +#define V_ON(v) PREDICT_BRANCH_NOT_TAKEN(jittor::log_vprefix.size() ? \ + jittor::check_vlog(__FILELINE__, v) : \ + (v) <= jittor::log_v) + +#define LOGV(v) \ + _LOG_IF(i, jittor::log_vprefix.size() ? \ + jittor::check_vlog(__FILELINE__, v) : \ + (v) <= jittor::log_v, v) + +#define LOGv LOGV(1) +#define LOGvv LOGV(10) +#define LOGvvv LOGV(100) +#define LOGvvvv LOGV(1000) +#define CHECKop(a, op, b) LOG_IF(f, !((a) op (b))) \ + << "Check failed" \ + << #a "(" >> a >> ") " #op " " #b"(" >> b >> ")" + +#define ASSERT(s) CHECK(s) << "Something wrong... Could you please report this issue?\n" +#define ASSERTop(a, op, b) CHECKop(a, op, b) << "Something wrong ... Could you please report this issue?\n" + +#define LOGg LOGv >> jittor::green +#define LOGr LOGv >> jittor::red +#define LOGy LOGv >> jittor::yellow +#define LOGgg LOGvv >> jittor::green +#define LOGrr LOGvv >> jittor::red +#define LOGyy LOGvv >> jittor::yellow +#define LOGggg LOGvvv >> jittor::green +#define LOGrrr LOGvvv >> jittor::red +#define LOGyyy LOGvvv >> jittor::yellow +#define LOGgggg LOGvvvv >> jittor::green +#define LOGrrrr LOGvvvv >> jittor::red +#define LOGyyyy LOGvvvv >> jittor::yellow + +#define LOGI jittor::LogVoidify() && jittor::Log(__FILELINE__, 'I', 0) +#define LOGir LOGI >> jittor::red +#define LOGig LOGI >> jittor::green +#define LOGiy LOGI >> jittor::yellow + +void system_with_check(const char* cmd, const char* cwd=nullptr); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/utils/mwsr_list.cc b/python/jittor/src/utils/mwsr_list.cc new file mode 100644 index 00000000..a53d45d0 --- /dev/null +++ b/python/jittor/src/utils/mwsr_list.cc @@ -0,0 +1,64 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** + +#include "utils/mwsr_list.h" + +#ifdef TEST + +#include +#include +#include + +using namespace std; + +MWSR_LIST(test, int64_t); + +int n, m, tnum; + +void reduce() { + int64_t sum=0; + mwsr_list_test::reduce([&](const int64_t& s) { + sum += s; + }, [](){}); + + int64_t expect = int64_t(m)*(m-1)/2*n*tnum; + cout << "get sum " << sum << ' ' << sum - expect << endl; + assert(expect == sum); +} + +void add() { + for (int i=0; i ts; + thread checker(reduce); + for (int i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include +#include +#include +#include + +// fast multi writer single reader list +#define MWSR_LIST(name, T) \ +namespace mwsr_list_ ## name { \ + using std::list; \ + using std::vector; \ + using std::function; \ + \ + typedef T mylist_t; \ + list> glist; \ + list::iterator> glist_iter; \ + std::mutex glist_mutex; \ + std::condition_variable cv; \ + std::mutex mm; \ + bool _stop; \ + bool _flush; \ + \ + void clear() { \ + std::lock_guard lk(glist_mutex); \ + glist.clear(); \ + glist_iter.clear(); \ + _stop = false; \ + _flush = false; \ + } \ + \ + void flush() { \ + { \ + std::lock_guard lk(mm); \ + _flush = true; \ + } \ + cv.notify_one(); \ + } \ + \ + void stop() { \ + { \ + std::lock_guard lk(mm); \ + _stop = true; \ + } \ + cv.notify_one(); \ + } \ + \ + void init() { \ + std::lock_guard lk(glist_mutex); \ + _stop = false; \ + _flush = false; \ + auto titer = glist_iter.begin(); \ + for (auto& tlist : glist) { \ + tlist.clear(); \ + *titer = tlist.end(); \ + titer ++; \ + } \ + } \ + \ + list* create_tlist() { \ + std::lock_guard lk(glist_mutex); \ + glist.emplace_back(); \ + auto tlist = &glist.back(); \ + glist_iter.push_back(tlist->end()); \ + return tlist; \ + } \ + \ + thread_local list* tlist = create_tlist(); \ + \ + void push(mylist_t &&s) { \ + tlist->emplace_back(move(s)); \ + cv.notify_one(); \ + } \ + \ + void reduce(function func, function flush_func) { \ + thread_local vector*> gvlist; \ + thread_local vector::iterator*> gvlist_iter; \ + gvlist.clear(); \ + gvlist_iter.clear(); \ + int stop2=0; \ + int flush2=0; \ + while (1) { \ + int found = 0; \ + if (gvlist.size() != glist.size()) { \ + std::lock_guard lk(glist_mutex); \ + gvlist.clear(); \ + gvlist_iter.clear(); \ + for (auto &tlist : glist) \ + gvlist.push_back(&tlist); \ + for (auto &tlist_iter : glist_iter) \ + gvlist_iter.push_back(&tlist_iter); \ + } \ + \ + auto list_iter = gvlist_iter.begin(); \ + for (auto tlist : gvlist) { \ + auto& last = **list_iter; \ + if (last == tlist->end()) { \ + last = tlist->begin(); \ + if (last != tlist->end()) { \ + func(*last); \ + found++; \ + } \ + } \ + while (last != tlist->end()) { \ + auto nlast = next(last); \ + if (nlast != tlist->end()) { \ + func(*nlast); \ + last = nlast; \ + tlist->pop_front(); \ + found++; \ + } else break; \ + } \ + list_iter ++; \ + } \ + if (!found) { \ + std::unique_lock lk(mm); \ + if (_flush) { \ + _flush = false; \ + flush2 = 1; \ + lk.unlock(); \ + continue; \ + } \ + if (flush2) { \ + flush2 = 0; \ + flush_func(); \ + } \ + if (_stop) { \ + if (stop2>0) { \ + lk.unlock(); \ + break; \ + } else { \ + stop2 ++; \ + lk.unlock(); \ + continue; \ + } \ + } \ + cv.wait(lk); \ + lk.unlock(); \ + } \ + } \ + init(); \ + } \ +} // mwsr_list diff --git a/python/jittor/src/utils/seh.h b/python/jittor/src/utils/seh.h new file mode 100644 index 00000000..0e460596 --- /dev/null +++ b/python/jittor/src/utils/seh.h @@ -0,0 +1,194 @@ + +#pragma once +#ifdef _WIN32 +#include +#include +#include +#include +#include "common.h" + +namespace jittor { + +using std::stringstream; + +inline void raise_win_error(int ierr) { + DWORD err = (DWORD)ierr; + WCHAR *s_buf = NULL; /* Free via LocalFree */ + stringstream message; + + if (err==0) { + err = GetLastError(); + } + + auto len = FormatMessageW( + /* Error API error */ + FORMAT_MESSAGE_ALLOCATE_BUFFER | + FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, /* no message source */ + err, + MAKELANGID(LANG_NEUTRAL, + SUBLANG_DEFAULT), /* Default language */ + (LPWSTR) &s_buf, + 0, /* size not used */ + NULL); /* no args */ + + if (len==0) { + /* Only seen this in out of mem situations */ + message << "Windows Error " << err; + s_buf = NULL; + } else { + /* remove trailing cr/lf and dots */ + while (len > 0 && (s_buf[len-1] <= L' ' || s_buf[len-1] == L'.')) + s_buf[--len] = L'\0'; + message << s_buf; + } + if (s_buf) + LocalFree(s_buf); + throw std::runtime_error(message.str()); +} + +inline void raise_cxx_exception(unsigned int code, _EXCEPTION_POINTERS* pExp) { + std::cerr << "raise_cxx_exception " << code << std::endl; + EXCEPTION_RECORD* pr = pExp->ExceptionRecord; + + /* The 'code' is a normal win32 error code so it could be handled by + raise_win_error(). However, for some errors, we have additional + information not included in the error code. We handle those here and + delegate all others to the generic function. */ + stringstream message; + switch (code) { + case EXCEPTION_ACCESS_VIOLATION: + /* The thread attempted to read from or write + to a virtual address for which it does not + have the appropriate access. */ + if (pr->ExceptionInformation[0] == 0) + message << "exception: access violation reading " << (void*)pr->ExceptionInformation[1]; + else + message << "exception: access violation writing " << (void*)pr->ExceptionInformation[1]; + break; + + case EXCEPTION_BREAKPOINT: + /* A breakpoint was encountered. */ + message << "exception: breakpoint encountered"; + break; + + case EXCEPTION_DATATYPE_MISALIGNMENT: + /* The thread attempted to read or write data that is + misaligned on hardware that does not provide + alignment. For example, 16-bit values must be + aligned on 2-byte boundaries, 32-bit values on + 4-byte boundaries, and so on. */ + message << "exception: datatype misalignment"; + break; + + case EXCEPTION_SINGLE_STEP: + /* A trace trap or other single-instruction mechanism + signaled that one instruction has been executed. */ + message << "exception: single step"; + break; + + case EXCEPTION_ARRAY_BOUNDS_EXCEEDED: + /* The thread attempted to access an array element + that is out of bounds, and the underlying hardware + supports bounds checking. */ + message << "exception: array bounds exceeded"; + break; + + case EXCEPTION_FLT_DENORMAL_OPERAND: + /* One of the operands in a floating-point operation + is denormal. A denormal value is one that is too + small to represent as a standard floating-point + value. */ + message << "exception: floating-point operand denormal"; + break; + + case EXCEPTION_FLT_DIVIDE_BY_ZERO: + /* The thread attempted to divide a floating-point + value by a floating-point divisor of zero. */ + message << "exception: float divide by zero"; + break; + + case EXCEPTION_FLT_INEXACT_RESULT: + /* The result of a floating-point operation cannot be + represented exactly as a decimal fraction. */ + message << "exception: float inexact"; + break; + + case EXCEPTION_FLT_INVALID_OPERATION: + /* This exception represents any floating-point + exception not included in this list. */ + message << "exception: float invalid operation"; + break; + + case EXCEPTION_FLT_OVERFLOW: + /* The exponent of a floating-point operation is + greater than the magnitude allowed by the + corresponding type. */ + message << "exception: float overflow"; + break; + + case EXCEPTION_FLT_STACK_CHECK: + /* The stack overflowed or underflowed as the result + of a floating-point operation. */ + message << "exception: stack over/underflow"; + break; + + case EXCEPTION_STACK_OVERFLOW: + /* The stack overflowed or underflowed as the result + of a floating-point operation. */ + message << "exception: stack overflow"; + break; + + case EXCEPTION_FLT_UNDERFLOW: + /* The exponent of a floating-point operation is less + than the magnitude allowed by the corresponding + type. */ + message << "exception: float underflow"; + break; + + case EXCEPTION_INT_DIVIDE_BY_ZERO: + /* The thread attempted to divide an integer value by + an integer divisor of zero. */ + message << "exception: integer divide by zero"; + break; + + case EXCEPTION_INT_OVERFLOW: + /* The result of an integer operation caused a carry + out of the most significant bit of the result. */ + message << "exception: integer overflow"; + break; + + case EXCEPTION_PRIV_INSTRUCTION: + /* The thread attempted to execute an instruction + whose operation is not allowed in the current + machine mode. */ + message << "exception: privileged instruction"; + break; + + case EXCEPTION_NONCONTINUABLE_EXCEPTION: + /* The thread attempted to continue execution after a + noncontinuable exception occurred. */ + message << "exception: nocontinuable"; + break; + + case 0xE06D7363: + /* magic number(0xE06D7363) of c++ exception: + https://devblogs.microsoft.com/oldnewthing/20100730-00/?p=13273 + */ + message << "Error c++ exception"; + break; + + default: + raise_win_error(code); + break; + } + // std::cout << message.str() << std::endl; + throw std::runtime_error(message.str()); +} + +} +#define SEH_HOOK int _seh_hook = (_set_se_translator(raise_cxx_exception), 0) +#else +#define SEH_HOOK +#endif \ No newline at end of file diff --git a/python/jittor/src/utils/str_utils.cc b/python/jittor/src/utils/str_utils.cc new file mode 100644 index 00000000..34995168 --- /dev/null +++ b/python/jittor/src/utils/str_utils.cc @@ -0,0 +1,241 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include "utils/str_utils.h" + +namespace jittor { + + +bool startswith(const string& a, const string& b, uint start, bool equal, uint end) { + if (!end) end = a.size(); + if (b.size()+start > end) return false; + if (equal && b.size()+start != end) return false; + for (uint i=0; i split(const string& s, const string& sep, int max_split) { + vector ret; + int pos = 0, pos_next; + while (1) { + pos_next = s.find(sep, pos); + if (pos_next == (int)string::npos || (int)ret.size() == max_split-1) { + ret.push_back(s.substr(pos)); + return ret; + } + ret.push_back(s.substr(pos, pos_next-pos)); + pos = pos_next + sep.size(); + } + ASSERT(max_split==0); + return ret; +} + +string strip(const string& s) { + int i=0; + while (ii && (s[j-1]==' ' || s[j-1]=='\t' || s[j-1]=='\n' || s[j-1]=='\r')) j--; + return s.substr(i,j-i); +} + +string format(const string& s, const vector& v) { + string ss; + for (int i=0; i& vs, const string& x) { + string s; + for (int i=0; i token_split(const string& s, bool exclude_comments) { + vector ss; + if (!s.size()) return ss; + ss.push_back(""); + for (int i = 0; i < s.size(); i++) { + if (exclude_comments) { + if (s[i] == '/' && s[i+1] == '/') { + i = s.find('\n', i); + if (i == string::npos) + return ss; + } + if (s[i] == '/' && s[i+1] == '*') { + i = s.find("*/", i); + if (i == string::npos) + return ss; + i += 1; + continue; + } + } + if (i && (isvar(s[i]) != isvar(s[i-1]))) + ss.push_back(""); + ss.back() += s[i]; + } + return ss; +} + +static void parse_reg(const string& src, + vector& patterns, + vector& arg_id, + bool match_whitespace=true) { + patterns.clear(); + arg_id.clear(); + patterns.push_back(""); + for (int j=0; j& tokens, int i, const string& src, const string& dst, bool match_whitespace) { + if (!(src.at(0) != '$' && src.at(src.size()-1) != '$' && + src.at(src.size()-2) != '$')) { + LOGe << "illegal src:" << src; + LOGf << "illegal src:" << src; + } + ASSERT(src.at(0) != '$' && src.at(src.size()-1) != '$' && + src.at(src.size()-2) != '$') << "illegal src:" << src; + vector patterns; + vector arg_id; + vector patterns2; + vector arg_id2; + unordered_map args; + parse_reg(src, patterns, arg_id, match_whitespace); + parse_reg(dst, patterns2, arg_id2); + + int start_i, start_pos, end_i, end_pos; + int c_i = i, c_pos = 0; + int match_i, match_pos; + string c_arg; + + auto next = [&tokens](int &c_i, int &c_pos) { + c_pos ++; + if (c_pos >= tokens[c_i].size()) { + c_pos = 0; + c_i ++; + if (c_i >= tokens.size()) + return false; + } + return true; + }; + + auto match = [&](int c_i, int c_pos, const string& pat) -> bool { + for (int i=0; i ss{s}; + token_replace(ss, 0, src, dst, match_whitespace); + return join(ss, ""); +} + +string token_replace_all(const string& s, const string& src, const string& dst) { + auto ss = token_split(s); + int pos = 0; + while (pos < ss.size()) { + try { + pos = token_replace(ss, pos, src, dst) + 1; + } + catch(const std::exception& e) { + return join(ss, ""); + } + } + return join(ss, ""); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/utils/str_utils.h b/python/jittor/src/utils/str_utils.h new file mode 100644 index 00000000..916750c4 --- /dev/null +++ b/python/jittor/src/utils/str_utils.h @@ -0,0 +1,43 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +// a: main string +// b: pattern string +// start: start index(include) +// equal: match reach the end +// end: end index(exclude) +bool startswith(const string& a, const string& b, uint start=0, bool equal=false, uint end=0); + +// a: main string +// b: pattern string +bool endswith(const string& a, const string& b); + +// s: main string +// sep: pattern string for split +// max_split: maximun split number(include) +vector split(const string& s, const string& sep, int max_split=0); + +string strip(const string& s); + +string format(const string& s, const vector& v); + +string replace(const string& a, const string& b, const string& c); + +string join(const vector& vs, const string& x); + +vector token_split(const string& s, bool exclude_comments=false); + +int token_replace(vector& tokens, int i, const string& src, const string& dst, bool match_whitespace=true); + +string token_replace(const string& s, const string& src, const string& dst); + +string token_replace_all(const string& s, const string& src, const string& dst); +} // jittor \ No newline at end of file diff --git a/python/jittor/src/utils/tracer.cc b/python/jittor/src/utils/tracer.cc new file mode 100644 index 00000000..0aba2f9f --- /dev/null +++ b/python/jittor/src/utils/tracer.cc @@ -0,0 +1,223 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#include "utils/cross_platform.h" +#include "utils/tracer.h" + +namespace jittor { + +DEFINE_FLAG_WITH_SETTER(string, gdb_path, "", "Path of GDB."); +DEFINE_FLAG(string, addr2line_path, "", "Path of addr2line."); +DEFINE_FLAG(string, extra_gdb_cmd, "", "Extra command pass to GDB, seperate by(;) ."); +DEFINE_FLAG(int, has_pybt, 0, "GDB has pybt or not."); +DEFINE_FLAG(int, trace_depth, 10, "trace depth for GDB."); +DEFINE_FLAG_WITH_SETTER(int, gdb_attach, 0, "gdb attach self process."); + +string _extra_gdb_cmd; + +int system_popen(const char* cmd, const char* cwd=nullptr); + +#ifdef _WIN32 +string get_cmds(const vector& argv) { + auto cmds = gdb_path; + for (auto p : argv) { + if (!p) continue; + string cmd = p; + cmds += " "; + if (cmd.find(' ') != string::npos && cmd[0] != '"') + cmds += '"' + cmd + '"'; + else + cmds += cmd; + } + return cmds; +} +#endif + +void setter_gdb_attach(int v) { + if (v && gdb_path.size()) { + static int gdb_attached = 0; + if (gdb_attached) return; + gdb_attached = 1; + // using gdb to print the stack trace + char pid_buf[30]; + sprintf(pid_buf, "%d", getpid()); + + vector argv{ + gdb_path.c_str(), + "-ex", "catch throw" + }; + if (auto n = extra_gdb_cmd.size()) { + _extra_gdb_cmd = extra_gdb_cmd; + _extra_gdb_cmd += '\0'; + argv.push_back("-ex"); + argv.push_back(&_extra_gdb_cmd[0]); + for (uint i=0; i. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" + +namespace jittor { + +void print_trace(); +void breakpoint(); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/utils/vdp b/python/jittor/src/utils/vdp new file mode 100644 index 00000000..ebfa310b --- /dev/null +++ b/python/jittor/src/utils/vdp @@ -0,0 +1 @@ +#define _P(...) \ No newline at end of file diff --git a/python/jittor/src/var.cc b/python/jittor/src/var.cc new file mode 100644 index 00000000..133f00a2 --- /dev/null +++ b/python/jittor/src/var.cc @@ -0,0 +1,170 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include + +#include "var.h" +#include "op.h" +#include "mem/allocator.h" +#include "pybind/py_var_tracer.h" +#include "mem/swap.h" + +namespace jittor { + +int64 Var::number_of_lived_vars = 0; + +DEFINE_FLAG(fast_shared_ptr, compile_options, {}, + "Override the default loop transfrom options"); +DEFINE_FLAG(bool, no_grad, 0, + "No grad for all jittor Var creation"); +DEFINE_FLAG(bool, no_fuse, 0, + "No fusion optimization for all jittor Var creation"); +DEFINE_FLAG(uint8, node_order, 0, "id prior"); +DEFINE_FLAG(uint8, th_mode, 0, "th mode"); +// TODO: fuse multiple flags +DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too; bit 5, reduce16 intermediate not use 32"); + +DEFINE_FLAG_WITH_SETTER(int, auto_mixed_precision_level, 0, "Auto mixed-precision optimization level, 0: not use fp16, 1-3: preserve level, not use fp16 for now; 4: perfer fp16, but some ops use fp32 e.g. sum,exp; 5: simular with 4, and array op will automatically convert to fp16; 6: all ops prefer fp16"); + +void (*_var_free_hook)(Var*); + +void free_var(Var* v) { + if (PREDICT_BRANCH_NOT_TAKEN((bool)_var_free_hook)) _var_free_hook(v); + Var::number_of_lived_vars--; + if (save_mem) + free_with_swap(v); + else + if (v->mem_ptr != nullptr) { + auto mem_ptr = v->mem_ptr; + auto allocation = v->allocation; + auto allocator = v->allocator; + v->mem_ptr = nullptr; + v->allocator = nullptr; + v->allocation = 0; + allocator->free(mem_ptr, v->size, allocation); + } +} + +void free_var_mem(Var* v) { + if (save_mem) + free_with_swap(v); + else { + auto mem_ptr = v->mem_ptr; + auto allocation = v->allocation; + auto allocator = v->allocator; + v->mem_ptr = nullptr; + v->allocator = nullptr; + v->allocation = 0; + allocator->free(mem_ptr, v->size, allocation); + } +} + +void setter_auto_mixed_precision_level(int value) { + if (value <= 2) amp_reg = 0; else + if (value == 3) amp_reg = amp_keep_reduce | amp_keep_white; else + if (value == 4) amp_reg = amp_prefer16; else + if (value == 5) amp_reg = amp_prefer16 | amp_array_prefer; else + if (value == 6) amp_reg = amp_prefer16 | amp_array_prefer | amp_keep_reduce | amp_keep_white; +} + +Var::Var(NanoVector shape, NanoString dtype) + : shape(shape), + loop_options(compile_options) { + flags.set(NodeFlags::_var, 1); + flags.set(NodeFlags::_stop_grad, !dtype.is_float() || no_grad); + flags.set(NodeFlags::_stop_fuse, no_fuse); + ns = dtype; + ASSERT(ns.is_dtype()); + number_of_lived_vars++; + numel(); + if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this); +} + +string Var::to_string() { + string s = dtype().to_cstring(); + s += shape.to_string(); + return s; +} + +int64 Var::numel() { + bool negtive = 0; + num=1; + for (auto k : shape) { + if (k<0) { + negtive = 1; + num *= -k; + } else { + num *= k; + } + } + size = num * dsize(); + if (negtive) num = -num; + if (shape.size() == 0) {shape.push_back(1);} + return num; +} + +void Var::set_shape(NanoVector shape) { + this->shape = shape; + numel(); +} + +bool Var::alloc(Allocator* allocator) { + if (mem_ptr) return true; + if (auto* x = (Var*)(this->allocator)) { + if (x->allocator->share_with(size, x->allocation)) { + mem_ptr = ((char*) x->mem_ptr) + allocation; + allocation = x->allocation; + this->allocator = x->allocator; + return true; + } + } + mem_ptr = allocator->alloc(size, allocation); + this->allocator = allocator; + return mem_ptr; +} + +VarPtr clone(Var* x); +void VarPtr::set_stop_grad(bool stop_grad) { + if (stop_grad == ptr->is_stop_grad()) return; + if (stop_grad) + ptr->set_stop_grad(); + else { + bool no_grad_bk = no_grad; + auto th_mode_bk = th_mode; + no_grad = 0; + th_mode = 0; + *this = clone(ptr); + no_grad = no_grad_bk; + th_mode = th_mode_bk; + } +} + +std::ostream& operator<<(std::ostream& os, const Var& var) { + os << "Var" << '(' << var.id + << ':' << var.forward_liveness + << ':' << var.backward_liveness + << ':' << var.pending_liveness + << ":i" << var._inputs.size() + << ":o" << var._outputs.size() + << ":s" << var.is_finished() + << ":n" << var.flags.get(NodeFlags::_needed_by_backward) + << ":g" << !var.is_stop_grad() + << ',' + << var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr << std::dec + << ')' << var.shape; + if (trace_py_var) { + os << '{'; + print_node_trace(&var, os); + os << '}'; + } + return os; +} +std::ostream& operator<<(std::ostream& os, const Var* var) { + return os << *var; +} +std::ostream& operator<<(std::ostream& os, const VarPtr& v) { return os << v.ptr; } + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/var.h b/python/jittor/src/var.h new file mode 100644 index 00000000..cfd1c81a --- /dev/null +++ b/python/jittor/src/var.h @@ -0,0 +1,107 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "node.h" +#include "misc/cstr.h" +#include "misc/fast_shared_ptr.h" + +namespace jittor { + +constexpr size_t alignment = 32; +struct VarHolder; + +struct Var : Node { + NanoVector shape; + cstr name; + fast_shared_ptr loop_options; + static int64 number_of_lived_vars; + + // this var will be generated after alloc. + void* mem_ptr = nullptr; + Allocator* allocator = nullptr; + size_t allocation; + int64 size, num; + VarHolder* holder = nullptr; + inline bool is_float() const { CHECK_EXIST; return ns.is_float(); } + inline int dsize() const { CHECK_EXIST; return ns.dsize(); } + inline NanoString dtype() const { CHECK_EXIST; return ns; } + inline NanoString& dtype() { CHECK_EXIST; return ns; } + template + inline T* ptr() { CHECK_EXIST; return (T*)mem_ptr; } + inline Op* input() { CHECK_EXIST; return _inputs.size() ? (Op*)_inputs.front() : (Op*)nullptr; } + inline Caster outputs() { CHECK_EXIST; return &_outputs; } + inline Caster outputs_with_index() { CHECK_EXIST; return &_outputs; } + inline Op* input(uint i) { return Node::input(i)->op(); } + inline Op* output(uint i) { return Node::output(i)->op(); } + + Var(NanoVector shape, NanoString dtype); + + string to_string(); + int64 numel(); + void set_shape(NanoVector shape); + bool alloc(Allocator* allocator); + inline void share_with(Var* x, size_t offset = 0) { CHECK_EXIST; allocator = (Allocator*)x; allocation = offset; } +}; + +struct VarPtr { + Var* ptr; + + inline + VarPtr(Var* ptr=nullptr) : ptr(ptr) { + if (ptr) { + ptr->own_both_liveness(); + } + } + + inline + VarPtr(VarPtr&& other) { + ptr = other.ptr; + other.ptr = nullptr; + } + + inline + VarPtr(const VarPtr& other) : VarPtr(other.ptr) { + } + + inline + VarPtr(NanoVector shape, NanoString dtype) { + ptr = new Var(shape, dtype); + ptr->own_both_liveness(); + } + + inline + ~VarPtr() { free_liveness(); } + + inline + void free_liveness() { + if (ptr) { + auto tmp = ptr; + ptr = nullptr; + tmp->release_both_liveness(); + } + } + + inline Var* operator->() { return ptr; } + inline operator Var*() { return ptr; } + inline operator bool() { return ptr; } + + inline VarPtr& operator=(VarPtr&& other) { + free_liveness(); + ptr = other.ptr; + other.ptr = nullptr; + return *this; + } + + void set_stop_grad(bool stop_grad); +}; + +std::ostream& operator<<(std::ostream& os, const Var& var); +std::ostream& operator<<(std::ostream& os, const Var* var); +std::ostream& operator<<(std::ostream& os, const VarPtr& v); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/var_holder.cc b/python/jittor/src/var_holder.cc new file mode 100644 index 00000000..ed4d4c0a --- /dev/null +++ b/python/jittor/src/var_holder.cc @@ -0,0 +1,381 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#endif +#include "var_holder.h" +#include "var.h" +#include "executor.h" +#include "graph.h" +#include "mem/allocator/cuda_dual_allocator.h" +#include "ops/op_register.h" +#include "ops/getitem_op.h" +#include "ops/setitem_op.h" +#include "type/fp16_compute.h" +#include "mem/swap.h" + +namespace jittor { + +DEFINE_FLAG(int, lazy_execution, 1, "Default enabled, if disable, use immediately eager execution rather than lazy execution, This flag makes error message and traceback infomation better. But this flag will raise memory consumption and lower the performance."); + +list hold_vars; +list::iterator sync_ptr = hold_vars.end(); + +void add_hold_vars(VarHolder* self) { + hold_vars.push_front(self); + self->iter = hold_vars.begin(); + if (lazy_execution && Op::number_of_lived_ops < 100000) return; + auto v = self->var; + for (int i=0; i<5; i++) { + auto op = v->input(); + if (!op) break; + if (i==0 && op->name() == string("tape")) return; + if (op->type() == OpType::other) break; + if (op->type() == OpType::reduce) break; + if (op->inputs().size() == 0) + break; + if (op->type() == OpType::broadcast) + return; + v = op->inputs().front(); + } + self->sync(true); +} + +VarHolder::VarHolder(Var* v) : var(v) { + // Var holder has both forward and backward liveness + own_holder(); + var->own_both_liveness(); + add_hold_vars(this); +} + +VarHolder::VarHolder(VarPtr&& v) { + var = v.ptr; + v.ptr = nullptr; + own_holder(); + add_hold_vars(this); +} + +VarHolder::VarHolder(VarHolder* v) : var(v->var) { + own_holder(); + iter = v->iter; + *iter = this; + // free memory without calling deconstructor + operator delete(v); +} + +void VarHolder::release_from_holders() { + if (PREDICT_BRANCH_NOT_TAKEN(!var)) return; + if (iter == sync_ptr) + sync_ptr = std::next(sync_ptr); + if (iter != hold_vars.end()) { + hold_vars.erase(iter); + release_holder(); + } + iter = hold_vars.end(); +} + +static auto make_array_from_pyobj = get_op_info("array") + .get_constructor(); +static auto make_unary = get_op_info("unary") + .get_constructor(); + +VarHolder::VarHolder(PyObject* obj, NanoString dtype) { + auto vp = make_array_from_pyobj(obj); + if (dtype != ns_void) + vp = make_unary(vp, dtype); + var = vp.ptr; + vp.ptr = nullptr; + own_holder(); + add_hold_vars(this); +} + + +VarHolder::~VarHolder() { + if (PREDICT_BRANCH_NOT_TAKEN(!var)) return; + if (iter == sync_ptr) + sync_ptr = std::next(sync_ptr); + if (iter != hold_vars.end()) + hold_vars.erase(iter); + release_holder(); + var->release_both_liveness(); +} + +// assign attributes of b to a +static inline void assign_var(Var* a, Var* b) { + a->name = move(b->name); + if (b->is_stop_grad()) + a->set_stop_grad(); + if (b->flags.get(NodeFlags::_stop_fuse)) + a->flags.set(NodeFlags::_stop_fuse); + if (b->flags.get(NodeFlags::_th_require_grad)) + a->flags.set(NodeFlags::_th_require_grad); +} + +extern uint8 th_mode; +void VarHolder::operator=(VarPtr&& v) { + if (th_mode) { + if (var->is_stop_grad() != v->is_stop_grad()) + v.set_stop_grad(var->is_stop_grad()); + if (var->flags.get(NodeFlags::_th_require_grad)) + v.ptr->flags.set(NodeFlags::_th_require_grad); + } + assign_var(v.ptr, var); + release_holder(); + var->release_both_liveness(); + var = v.ptr; + own_holder(); + v.ptr = nullptr; +} + +extern bool no_grad; +void VarHolder::set_requires_grad(bool flag) { + if (flag != get_requires_grad()) { + if (flag) { + start_grad(); + } else + stop_grad(); + } + return; +} + +VarHolder* VarHolder::start_grad() { + if (!var->dtype().is_float()) + LOGw << "cannot enable grad of a non-float value:" << var; + bool no_grad_bk = no_grad; + auto th_mode_bk = th_mode; + no_grad = 0; + th_mode = 0; + auto dvar = jittor::detach(var); + std::swap(dvar.ptr, var); + no_grad = no_grad_bk; + th_mode = th_mode_bk; + var->flags.set(NodeFlags::_th_require_grad); + return this; +} + +string VarHolder::to_string() { + return var->to_string(); +} + +VarHolder* VarHolder::assign(VarHolder* v) { + if (th_mode) { + v->set_requires_grad(get_requires_grad()); + } + assign_var(v->var, var); + release_holder(); + v->var->own_both_liveness(); + var->release_both_liveness(); + var = v->var; + own_holder(); + return this; +} + +VarHolder* VarHolder::update(VarHolder* v) { + v->var->flags.set(NodeFlags::_out_hint); + return assign(v); +} + +VarHolder* VarHolder::_update(VarHolder* v) { + release_holder(); + v->var->own_both_liveness(); + var->release_both_liveness(); + var = v->var; + own_holder(); + var->flags.set(NodeFlags::_out_hint); + return this; +} + +EXTERN_LIB Executor exe; + +VarHolder* VarHolder::sync(bool device_sync, bool weak_sync) { + jittor::sync({this}, device_sync, weak_sync); + return this; +} + +ArrayArgs VarHolder::fetch_sync() { + if (!(var->mem_ptr && !var->allocator->is_cuda())) { + sync(true); + if (save_mem || _HAS_CUDA) + migrate_to_cpu(var, exe.allocator); + } + // this will casuse save wrong. + // if (var->flags.get(NodeFlags::_is_scalar)) + // return {var->mem_ptr, {}, var->dtype()}; + return {var->mem_ptr, var->shape, var->dtype()}; +} + +inline static void cast_item_data(ItemData& data) { + if (data.dtype == ns_float16) { + auto* fp16 = (float16*)&data; + auto* fp32 = (float32*)&data; + fp32[0] = float32(fp16[0]); + } + #ifndef IS_ROCM + else if (data.dtype == ns_bfloat16) { + auto* bf16 = (bfloat16*)&data; + auto* fp32 = (float32*)&data; + fp32[0] = float32(bf16[0]); + } + #endif + data.dtype = ns_float32; +} + +ItemData VarHolder::item() { + CHECK(var->num==1) << "Item var size should be 1, but got" << var->num; + ItemData data; + data.dtype = var->dtype(); + auto dsize = data.dtype.dsize(); + if (!(var->mem_ptr && !var->allocator->is_cuda())) { + sync(); + if (save_mem || _HAS_CUDA) + migrate_to_cpu(var, exe.allocator); + } + #ifdef HAS_CUDA + if (var->allocator->is_cuda()) { + checkCudaErrors(cudaMemcpy(&data.data, var->mem_ptr, dsize, cudaMemcpyDeviceToHost)); + } else + #endif + { + std::memcpy(&data.data, var->mem_ptr, dsize); + } + if (data.dtype == ns_float16 || data.dtype == ns_bfloat16) + cast_item_data(data); + return data; +} + +// from fetch_op.cc +EXTERN_LIB list fetcher; + +void sync_all(bool device_sync) { + vector vars; + vars.reserve(hold_vars.size()); + for (auto v : hold_vars) { + if (!v->var->_outputs.size()) + vars.push_back(v->var); + } + for (auto& v :fetcher) + vars.push_back(v.ptr); + graph_check(); + exe.run_sync(vars, device_sync); //need sync at last + graph_check(); +} + +void sync(const vector& vh, bool device_sync, bool weak_sync) { + vector vars; + vars.reserve(vh.size()); + for (auto v : vh) vars.push_back(v->var); + graph_check(); + exe.run_sync(vars, device_sync, weak_sync); //need sync at last + graph_check(); +} + +vector fetch_sync(const vector& vh) { + vector ret(vh.size()); + sync(vh, true); + for (uint i=0; ivar, exe.allocator); + ret[i].ptr = vh[i]->var->mem_ptr; + ret[i].shape = vh[i]->var->shape; + ret[i].dtype = vh[i]->var->dtype(); + } + return ret; +} + +string VarHolder::debug_msg() { + std::stringstream ss; + ss << var; + return ss.str(); +} + +int VarHolder::grad() { + LOGf << R""(Jittor Var doesn't have this interface, please change +your code as below:: + + model = Model() + optimizer = SGD(model.parameters()) + ... + optimizer.backward(loss) + + for p in model.parameters(): + # prev code: + # grad = p.grad + + # change to: + grad = p.opt_grad(optimizer) +)""; + return 0; +} + + +static auto make_ternary = get_op_info("ternary") + .get_constructor(); + +extern bool no_grad; + +VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y) { + if (!no_grad) + cond->var->flags.set(NodeFlags::_out_hint); + return new VarHolder(make_ternary(cond->var, x->var, y->var)); +} + +void migrate_all_to_cpu() { + sync_all(true); + if (save_mem || _HAS_CUDA) + for (auto vh : hold_vars) { + auto v = vh->var; + // if (v->_outputs.size()) continue; + if (v->allocator && v->mem_ptr && !v->allocator->is_cuda()) + migrate_to_cpu(v, cpu_allocator); + } +} + +static auto make_setitem = get_op_info("setitem") + .get_constructor(); + +inline static bool fast_strcmp(const char* a, const char* b) { + return ((const uint64*)a)[0] == ((const uint64*)b)[0]; +} + +VarHolder* VarHolder::check_cascade_setitem(VarHolder* out) { + // return this; + auto v = var; + int n=0; + int64 slices[10]; + while (n<10) { + Op* iop = v->input(); + if (!iop) break; + if (!fast_strcmp(iop->name(), "getitem")) break; + v = iop->inputs().front(); + GetitemOp* gop = (GetitemOp*)iop; + if (gop->vs.n == 1 && gop->vs.slices[0].is_int()) { + slices[n++] = gop->vs.slices[0].i; + } else break; + if (v->holder) { + // found holder var: v + // v[a][b][c][d] = y + // ^ + auto* prev_op = (SetitemOp*)out->var->input(); + VarSlices& old_slices = prev_op->vs; + Var* y = prev_op->input(1); + VarSlices new_slices(n+old_slices.n); + for (int i=n-1; i>=0; i--) + new_slices.slices[n-1-i].set_int(slices[i]); + for (int i=0; i v[a,b,c,d] = y + (*v->holder) = make_setitem(v, move(new_slices), y, ns_void); + break; + } + } + return assign(out); +} + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/var_holder.h b/python/jittor/src/var_holder.h new file mode 100644 index 00000000..662b8cc4 --- /dev/null +++ b/python/jittor/src/var_holder.h @@ -0,0 +1,433 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "var.h" +#include "ops/array_op.h" +#include "executor.h" +#include "mem/allocator/cuda_dual_allocator.h" + +namespace jittor { + +struct VarHolder; +VarPtr detach(Var* x); + +struct DataView { + VarHolder* vh; + void* ptr; + NanoVector shape; + NanoString dtype; +}; + +struct ItemData { + int64 data; + NanoString dtype; +}; + +typedef struct _object PyObject; + +EXTERN_LIB list hold_vars; +EXTERN_LIB list::iterator sync_ptr; +extern uint8 th_mode; + +// @pyjt(Var) +// @attrs(heaptype) +struct VarHolder { + Var* var; + list::iterator iter; + VarHolder(Var* v); + VarHolder(VarPtr&& v); + // will move and delete v + VarHolder(VarHolder* v); + // @pyjt(__init__) + VarHolder(PyObject* v, NanoString dtype=ns_void); + // @pyjt(__dealloc__) + ~VarHolder(); + string to_string(); + // @pyjt(sync) + // @attrs(return_self) + VarHolder* sync(bool device_sync = false, bool weak_sync = true); + + /** + * Returns a numpy array copy of the Var. + */ + // @pyjt(fetch_sync,numpy) + ArrayArgs fetch_sync(); + + inline void release_holder() {var->holder = nullptr;} + inline void own_holder() {var->holder = this;} + + /** + * assign the data from another Var. + */ + // @pyjt(assign) + // @attrs(return_self) + VarHolder* assign(VarHolder* v); + + /** + * update parameter and global variable, + * different from assign, it will + * stop grad between origin var and assigned var, and + * will update in the background + */ + // @pyjt(update) + // @attrs(return_self) + VarHolder* update(VarHolder* v); + + /** + * update parameter without set attribute. + */ + // @pyjt(_update) + // @attrs(return_self) + VarHolder* _update(VarHolder* v); + + /** + * swap the data with another Var. + */ + // @pyjt(swap) + // @attrs(return_self) + inline VarHolder* swap(VarHolder* v) { + std::swap(var, v->var); + own_holder(); v->own_holder(); + return this; + }; + + // @pyjt(location) + inline string location() { + if (var->flags.get(NodeFlags::_is_swapped)) + return "disk"; + if (var->mem_ptr == nullptr) + return "none"; + if (var->allocator->is_cuda()) + return "device"; + return "cpu"; + } + + void operator=(VarPtr&& v); + + + /** + * set the name of the Var. + */ + // @pyjt(name) + // @attrs(return_self) + inline VarHolder* name(const char* s) { + var->name = s; + return this; + } + + /** + * return the name of the Var. + */ + // @pyjt(name) + inline const char* name() { + return var->name.c_str(); + } + + /** + * return the number of elements in the Var. + */ + // @pyjt(numel) + inline int64 numel() { + return var->num; + } + + /** + * return the number of bytes of this Var. + */ + // @pyjt(__get__nbytes) + inline int64 nbytes() { + return var->num * var->dsize(); + } + + /** + * return id of this Var. + */ + // @pyjt(__get__id) + inline int64 id() { + return var->id; + } + + // @pyjt(__get__var_ptr) + inline int64 var_ptr() { + return (int64)var; + } + + // @pyjt(__get__flags) + inline int32 flags() { + return (int32)(var->flags.flags); + } + + /** + * disable the gradient calculation for the Var. + */ + // @pyjt(stop_grad) + // @attrs(return_self) + inline VarHolder* stop_grad() { + var->set_stop_grad(); + return this; + } + + /** + * return True if the gradient is stopped. + */ + // @pyjt(is_stop_grad) + inline bool is_stop_grad() { + return var->is_stop_grad(); + } + + /* detach the grad */ + // @pyjt(detach) + inline VarHolder* detach() { + return new VarHolder(jittor::detach(var)); + } + + + /** + * stop operator fusion. + */ + // @pyjt(stop_fuse) + // @attrs(return_self) + inline VarHolder* stop_fuse() { + var->flags.set(NodeFlags::_stop_fuse); + return this; + } + + /** + * return True if operator fusion is stopped. + */ + // @pyjt(is_stop_fuse) + inline bool is_stop_fuse() { + return var->flags.get(NodeFlags::_stop_fuse); + } + + /** + * output hint for training optimization + */ + // @pyjt(out_hint) + // @attrs(return_self) + inline VarHolder* out_hint() { + var->flags.set(NodeFlags::_out_hint); + return this; + } + + /** + * return the shape of the Var. + */ + // @pyjt(__get__shape) + inline NanoVector shape() { + return var->shape; + } + + // @pyjt(release_from_holders) + void release_from_holders(); + + /** + * return True if the Var requires gradient calculation. + * @see is_stop_grad + */ + // @pyjt(__get__requires_grad) + inline bool get_requires_grad() { + return !var->is_stop_grad(); + } + + /** + * enable or disable gradient calculation. + * @see stop_grad + */ + // @pyjt(__set__requires_grad) + void set_requires_grad(bool flag); + + /** + * enable the gradient calculation for the Var. + */ + // @pyjt(start_grad) + // @attrs(return_self) + VarHolder* start_grad(); + + // @pyjt(__get__uncertain_shape) + inline NanoVector uncertain_shape() { + return var->shape; + } + + /** + * return the data type of the Var. + */ + // @pyjt(__get__dtype) + inline NanoString dtype() { + return var->dtype(); + } + + // @pyjt(__get__compile_options) + inline loop_options_t compile_options() { + return var->loop_options; + } + + // @pyjt(__set__compile_options) + inline void set_compile_options(loop_options_t&& options) { + var->loop_options = move(options); + } + + /** + * get a numpy array which shares the data with the Var. + */ + // @pyjt(__get__data) + inline DataView data() { + if (!(var->mem_ptr && !var->allocator->is_cuda())) { + sync(true, false); + #ifdef HAS_CUDA + migrate_to_cpu(var, exe.allocator); + #endif + } + // this will cause state_dict only has one element + // if (var->flags.get(NodeFlags::_is_scalar)) + // return {this, var->mem_ptr, {}, var->dtype()}; + return {this, var->mem_ptr, var->shape, var->dtype()}; + } + + // @pyjt(__get__raw_ptr) + inline uint64 raw_ptr() { + sync(true, false); + #ifdef HAS_CUDA + migrate_to_cpu(var, exe.allocator); + #endif + return (uint64)var->mem_ptr; + } + + /** + * returns the Python number if the Var contains only one element. + * For other cases, see data(). + */ + // @pyjt(item) + ItemData item(); + + /** + * return the number of dimensions. + */ + // @pyjt(__get__ndim, dim) + inline int ndim() { + return var->shape.size(); + } + + // @pyjt(__set__data) + inline void set_data(ArrayArgs&& array) { + sync(true); + CHECK(array.dtype.dsize() == var->dtype().dsize() + && array.dtype.is_int() == var->dtype().is_int()); + int64 size = array.dtype.dsize(); + for (int i=0; isize); + #ifdef HAS_CUDA + migrate_to_cpu(var, exe.allocator); + #endif + std::memcpy(var->mem_ptr, array.ptr, size); + } + + // @pyjt(share_with) + // @attrs(return_self) + inline VarHolder* share_with(VarHolder* other) { + CHECK(!var->allocator) << "This var is already executed or shared."; + var->allocator = (Allocator*)(other->var); + return this; + } + + /** + * print the information of the Var to debug. + */ + // @pyjt(debug_msg) + string debug_msg(); + + /* Jittor Var doesn't have this interface, please change your code as below:: + + model = Model() + optimizer = SGD(model.parameters()) + ... + optimizer.backward(loss) + + for p in model.parameters(): + # prev code: + # grad = p.grad + + # change to: + grad = p.opt_grad(optimizer) + */ + // @pyjt(__get__grad) + int grad(); + + // @pyjt(_input) + inline VarHolder* _input(int i) { + CHECK(!var->is_finished()); + return new VarHolder(var->input()->input(i)); + } + + /* Add dependency, make var computed after vars + */ + // @pyjt(_add_dependency) + // @attrs(return_self) + inline VarHolder* _add_dependency(vector&& vars) { + vector b(vars.size()); + for (int i=0; ivar; + CHECK(!var->is_finished()); + auto a = var->input(); + var->input()->add_inputs(b); + auto edge = a->_inputs.end(); + for (int i=0; iback->index = -1; + } + return this; + } + + /* check a[x][y] = c + */ + // @pyjt(check_cascade_setitem) + // @attrs(return_self) + VarHolder* check_cascade_setitem(VarHolder* out); +}; + +// @pyjt(sync) +void sync(const vector& vh=vector(), bool device_sync=false, bool weak_sync=true); +// @pyjt(fetch_sync) +vector fetch_sync(const vector& vh); + +// @pyjt(sync_all) +void sync_all(bool device_sync=false); + +inline vector convert(const vector& vhs) { + vector v; + v.reserve(vhs.size()); + for (uint i=0; ivar); + return v; +} + +inline vector make_vh_vector(vector&& vps) { + vector a; + a.reserve(vps.size()); + for (auto& vp : vps) + // a.emplace_back(move(vp)); + a.emplace_back(new VarHolder(move(vp))); + return a; +} + +// @pyjt(ternary_out_hint) +VarHolder* ternary_out_hint(VarHolder* cond, VarHolder* x, VarHolder* y); + +// @pyjt(migrate_all_to_cpu) +void migrate_all_to_cpu(); + +// @pyjt(wrap_var_addr) +inline VarHolder* wrap_var_addr(int64 addr) { + return new VarHolder((Var*)addr); +} + +// @pyjt(reuse_np_array) +VarHolder* reuse_np_array(PyObject* obj); + +} // jittor \ No newline at end of file diff --git a/python/jittor/src/var_slices.cc b/python/jittor/src/var_slices.cc new file mode 100644 index 00000000..d94c0468 --- /dev/null +++ b/python/jittor/src/var_slices.cc @@ -0,0 +1,37 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include "var_slices.h" +#include "var.h" + +namespace jittor { + +std::ostream& operator<<(std::ostream& os, const VarSlices& vs) { + os << '['; + for (int i=0; idtype() << s.var->shape; + if (s.is_ellipsis()) return os << "..."; + if (s.is_slice()) return os << s.slice; + if (s.is_int()) return os << s.i; + if (s.is_str()) return os << (const char*)&s; + return os << "-"; +} + +std::ostream& operator<<(std::ostream& os, const Slice& s) { + if (!(s.mask & 1)) os << s.start; + os << ':'; + if (!(s.mask & 2)) os << s.stop; + os << ':'; + if (!(s.mask & 4)) os << s.step; + return os; +} + +} // jittor diff --git a/python/jittor/src/var_slices.h b/python/jittor/src/var_slices.h new file mode 100644 index 00000000..09cbfc56 --- /dev/null +++ b/python/jittor/src/var_slices.h @@ -0,0 +1,85 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "common.h" +#include "misc/nano_vector.h" +#include "var.h" + +namespace jittor { + +struct Slice; + +union VarSlice { + Slice slice; + Var* var; + int64 i; + inline bool is_var() const { return slice.mask == -1; } + inline bool is_ellipsis() const { return slice.mask == -2; } + inline bool is_none() const { return slice.mask == -3; } + inline bool is_int() const { return slice.mask == -4; } + inline bool is_str() const { return slice.mask == -5; } + inline bool is_slice() const { return slice.mask >= 0; } + inline void set_var(Var* v) { slice.mask = -1; var = v; } + inline void set_ellipsis() { slice.mask = -2; } + inline void set_none() { slice.mask = -3; } + inline void set_int(int64 v) { slice.mask = -4; i = v; } + inline void set_str(const string& s) { + slice.mask = -5; + CHECK(s.size() < 16) << "String slice too long" << s; + auto v = (int64*)s.c_str(); + slice.start = v[0]; + slice.stop = v[1]; + slice.step = s.size(); + } + inline char* get_str() {return (char*)this;} +}; + +struct VarSlices { + VarSlice* slices; + int n; + inline VarSlices() : slices(nullptr) {} + inline VarSlices(int n) : slices(new VarSlice[n]), n(n) {} + inline ~VarSlices() {if (slices) delete[] slices;} + inline VarSlices(VarSlices&& other) : slices(other.slices), n(other.n) { + other.slices = nullptr; + } + inline VarSlices(const VarSlices& other, bool negtive_set_none=false) : slices(new VarSlice[other.n]), n(other.n) { + for (int i=0; i. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +if __name__ == "__main__": + import unittest, os + unittest.TestLoader.sortTestMethodsUsing = None + + suffix = "__main__.py" + assert __file__.endswith(suffix) + test_dir = __file__[:-len(suffix)] + + skip_l = int(os.environ.get("test_skip_l", "0")) + skip_r = int(os.environ.get("test_skip_r", "1000000")) + skip = os.environ.get("test_skip", "").split(",") + test_only = None + if "test_only" in os.environ: + test_only = set(os.environ.get("test_only").split(",")) + + test_files = os.listdir(test_dir) + test_files = sorted(test_files) + suite = unittest.TestSuite() + test_names = [] + seperate_test = os.environ.get("seperate_test", "1") == "1" + + for _, test_file in enumerate(test_files): + test_name = test_file.split(".")[0] + tests = unittest.defaultTestLoader.loadTestsFromName( + "jittor.test."+test_name) + + if not test_file.startswith("test_"): + continue + if _ < skip_l or _ > skip_r: + continue + if test_only and test_name not in test_only: + continue + for s in skip: + if s in test_name: + continue + + print("Add Test", _, test_name) + if seperate_test: + test_names.append("jittor.test."+test_name) + else: + suite.addTest(tests) + + if seperate_test: + import subprocess as sp + import sys + import time + import jittor_utils + start = time.time() + errors = "" + f = open(jittor_utils.home()+"/.cache/jittor/test.log", "w") + for i,test_name in enumerate(test_names): + progress = f"{i}/{len(test_names)}" + print(f"[RUN TEST {progress}]", test_name) + r = sp.run(" ".join([sys.executable, '-m', test_name, '-v']), stdout=sp.PIPE, stderr=sp.STDOUT, timeout=60*10, shell=True) + out = r.stdout.decode('utf8') + sys.stdout.write(out) + f.write(out) + msg = f"[RUN TEST {progress} OK]" + if r.returncode: + msg = f"[RUN TEST {progress} FAILED]" + msg = msg + f" {test_name} {time.time()-start:.1f}\n" + if r.returncode: + errors += msg + sys.stdout.write(msg) + f.write(msg) + sys.stdout.write(errors) + f.write(errors) + f.close() + + result = unittest.TextTestRunner(verbosity=3).run(suite) + if len(result.errors) or len(result.failures): + exit(1) diff --git a/python/jittor/test/misc/superglue.py b/python/jittor/test/misc/superglue.py new file mode 100644 index 00000000..44af14fd --- /dev/null +++ b/python/jittor/test/misc/superglue.py @@ -0,0 +1,374 @@ +from copy import deepcopy +from pathlib import Path +import jittor as jt +import jittor.nn as nn +import numpy as np +import os + +split_size = 1000000 + +conv_opt = int(os.environ.get("conv_opt", "0")) + +if conv_opt: + Conv1d_sp = nn.Conv1d_sp +else: + Conv1d_sp = nn.Conv1d + + +def MLP(channels: list, do_bn=True): + """ Multi-layer perceptron """ + n = len(channels) + layers = [] + for i in range(1, n): + layers.append(Conv1d_sp(channels[i - 1], channels[i], kernel_size=1, bias=True)) + if i < (n - 1): + if do_bn: + layers.append(nn.BatchNorm(channels[i])) + # layers.append(nn.InstanceNorm1d(channels[i])) + # layers.append(nn.LayerNorm(channels[i])) + layers.append(nn.ReLU()) + return nn.Sequential(*layers) + + +def normalize_keypoints(kpts, image_shape): + size = image_shape.flip(1) # shape=(b,2) ;h w -> w, h + center = size / 2 + scaling = size.float32().max(1, keepdims=True) * 0.7 + return (kpts - center[:, None, :]) / scaling[:, None, :] + + +class KeypointEncoder(nn.Module): + """ Joint encoding of visual appearance and location using MLPs""" + def __init__(self, feature_dim, layers, keypoint_position_dim=2): + super().__init__() + # self.keypoint_position_dim = keypoint_position_dim + self.encoder = MLP([keypoint_position_dim + 1] + layers + [feature_dim]) + nn.init.constant_(self.encoder[-1].bias, 0.0) + + def execute(self, kpts, scores): + inputs = jt.concat([kpts.t(), scores.unsqueeze(1)], dim=1) + return self.encoder(inputs) + +cnt = 0 + +def attention(query, key, value): + global cnt + cnt += 1 + b, d, h, n = query.shape + # print("attention", b,d,h,n, cnt) + dim_factor = (1.0 / d)**0.5 + query = query.transpose(0, 2, 3, 1).reshape(b * h, -1, d) * dim_factor + key = key.transpose(0, 2, 1, 3).reshape(b * h, d, -1) + value = value.transpose(0, 2, 3, 1).reshape(b * h, -1, d) + # print("attention", query.shape, key.shape, value.shape) + + data = [] + for i in range(0, query.shape[0], split_size): + end = min(i + split_size, query.shape[0]) + tmp1 = nn.bmm(query[i:end], key[i:end]) + tmp2 = nn.softmax(tmp1, dim=-1) + tmp3 = nn.bmm(tmp2, value[i:end]) + tmp3.sync() + data.append(tmp3) + tmp3 = jt.concat(data) + + # for i in range(0, query.shape[0], split_size): + # end = min(i + split_size, query.shape[0]) + # tmp1 = nn.bmm(query[:,i:end], key[:,i:end]) + # tmp2 = nn.softmax(tmp1, dim=-1) + # tmp3 = nn.bmm(tmp2, value[:,i:end]) + # tmp3.sync() + # data.append(tmp3) + # tmp3 = jt.concat(data, dim=1) + + # tmp1 = nn.bmm(query, key) + # print(tmp1.shape) + # tmp2 = nn.softmax(tmp1, dim=-1) + # print(tmp2.shape) + # tmp3 = nn.bmm(tmp2, value) + # print(tmp3.shape) + return tmp3.reshape(b, h, -1, d).transpose(0, 3, 1, 2) + return nn.bmm(nn.softmax(nn.bmm(query, key), dim=-1), value).reshape(b, h, -1, d).transpose(0, 3, 1, 2) + + +class MultiHeadedAttention(nn.Module): + """ Multi-head attention to increase model expressivitiy """ + def __init__(self, num_heads: int, d_model: int): + super().__init__() + assert d_model % num_heads == 0 + self.dim = d_model // num_heads + self.num_heads = num_heads + self.merge = Conv1d_sp(d_model, d_model, kernel_size=1) + self.proj = nn.ModuleList([deepcopy(self.merge) for _ in range(3)]) + + def execute(self, query, key, value): + batch_dim = query.size(0) + query, key, value = [l(x).reshape(batch_dim, self.dim, self.num_heads, -1) for l, x in zip(self.proj, (query, key, value))] + x = attention(query, key, value) + # x = attention_chunk(query, key, value) + return self.merge(x.reshape(batch_dim, self.dim * self.num_heads, -1)) + + +class AttentionalPropagation(nn.Module): + def __init__(self, feature_dim: int, num_heads: int): + super().__init__() + self.attn = MultiHeadedAttention(num_heads, feature_dim) + self.mlp = MLP([feature_dim * 2, feature_dim * 2, feature_dim]) + nn.init.constant_(self.mlp[-1].bias, 0.0) + + def execute(self, x, source): + message = self.attn(x, source, source) + return self.mlp(jt.concat([x, message], dim=1)) + + +class AttentionalGNN(nn.Module): + def __init__(self, feature_dim: int, layer_names: list): + super().__init__() + self.layers = nn.ModuleList([AttentionalPropagation(feature_dim, 4) for _ in range(len(layer_names))]) + self.is_cross = [x == 'cross' for x in layer_names] + + def execute(self, desc0, desc1): + for layer, is_cross in zip(self.layers, self.is_cross): + layer.attn.prob = [] + if is_cross: + src0, src1 = desc1, desc0 + else: # if name == 'self': + src0, src1 = desc0, desc1 + # delta0, delta1 = layer(desc0, src0), layer(desc1, src1) + + delta0 = layer(desc0, src0) + # print(delta0.numel()*4) + # breakpoint() + jt.sync_all() + delta1 = layer(desc1, src1) + jt.sync_all() + desc0, desc1 = (desc0 + delta0), (desc1 + delta1) + jt.sync_all() + return desc0, desc1 + + +def log_sinkhorn_iterations(Z, log_mu, log_nu, iters: int): + """ Perform Sinkhorn Normalization in Log-space for stability""" + u, v = jt.zeros_like(log_mu), jt.zeros_like(log_nu) + for _ in range(iters): + u = log_mu - (Z + v.unsqueeze(1)).exp().sum(dim=2).log() + v = log_nu - (Z + u.unsqueeze(2)).exp().sum(dim=1).log() + return Z + u.unsqueeze(2) + v.unsqueeze(1) + + +def log_optimal_transport(scores, alpha, iters: int): + """ Perform Differentiable Optimal Transport in Log-space for stability""" + b, m, n = scores.shape + ms, ns = jt.float(m, requires_grad=False), jt.float(n, requires_grad=False) + + bins0 = alpha.broadcast([b, m, 1]) + bins1 = alpha.broadcast([b, 1, n]) + alpha = alpha.broadcast([b, 1, 1]) + + couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1) + + norm = -(ms + ns).log() + log_mu = jt.concat([norm.broadcast([m]), ns.log() + norm]) + log_nu = jt.concat([norm.broadcast([n]), ms.log() + norm]) + log_mu, log_nu = log_mu[None].broadcast([b, m + 1]), log_nu[None].broadcast([b, n + 1]) + + Z = log_sinkhorn_iterations(couplings, log_mu, log_nu, iters) + Z = Z - norm # multiply probabilities by M+N + return Z + + +def arange_like(x, dim: int): + return jt.ones(x.shape[dim], dtype=x.dtype)[None].cumsum()[0] - 1 # traceable in 1.1 + + +default_config = { + 'descriptor_dim': 256, # SuperPoint + 'weights': 'indoor', + 'keypoint_encoder': [32, 64, 128, 256], # SuperPoint + 'GNN_layers': ['self', 'cross'] * 9, + 'sinkhorn_iterations': 100, + 'match_threshold': 0.2, +} + + +def get_weighted_loss_batch(scores, all_matches): + matches0, matches1 = all_matches.chunk(chunks=2, dim=2) + batchIdx = jt.arange(all_matches.shape[0]).unsqueeze(1).repeat(1, all_matches.shape[1]) + batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1) + valid_index0, valid_index1 = matches0 >= 0, matches1 >= 0 + valid_match = jt.logical_and(valid_index0, valid_index1) + valid_unmatch = jt.logical_xor(valid_index0, valid_index1) + num_match = valid_match.sum().maximum(1e-9) + num_unmatch = valid_unmatch.sum().maximum(1e-9) + + + + score_ = scores[batchIdx, matches0, matches1] + score_match_ = (score_*valid_match).float32().sum() / num_match + score_umatch_ = (score_*valid_unmatch).float32().sum() / num_unmatch + return -(num_unmatch * score_match_ + num_match * score_umatch_) / (num_match + num_unmatch) + # print(score_umatch_, score_match_) + # return -(score_match + score_umatch) / (num_match + num_unmatch) + + score_match = scores[(batchIdx[valid_match], matches0[valid_match], matches1[valid_match])].float32().mean() if num_match > 0 else 0 + score_umatch = scores[(batchIdx[valid_unmatch], matches0[valid_unmatch], matches1[valid_unmatch])].float32().mean() if num_unmatch > 0 else 0 + # print(score_match, score_umatch) + return -(num_unmatch * score_match + num_match * score_umatch) / (num_match + num_unmatch) + + +def add_dustbin(scores, alpha): + b, m, n = scores.shape + bins0 = jt.broadcast(alpha, (b, m, 1)) + bins1 = jt.broadcast(alpha, (b, 1, n)) + alpha = jt.broadcast(alpha, (b, 1, 1)) + couplings = jt.concat([jt.concat([scores, bins0], -1), jt.concat([bins1, alpha], -1)], 1) + return couplings + + +class SuperGlue(nn.Module): + def __init__(self, config): + super().__init__() + config = {**default_config, **config} + self.descriptor_dim = config['descriptor_dim'] + self.keypoint_encoder = config['keypoint_encoder'] + self.GNN_layers = config['GNN_layers'] + self.sinkhorn_iterations = config['sinkhorn_iterations'] + self.match_threshold = config['match_threshold'] + self.keypoint_position_dim = config['keypoint_position_dim'] + self.use_dual_softmax = config['use_dual_softmax'] + self.scale = jt.float(self.descriptor_dim**-0.5).stop_grad() + # self.scale.requires_grad = False + + # self.des_extend = MLP([128, 256]) + + self.kenc = KeypointEncoder(self.descriptor_dim, self.keypoint_encoder, keypoint_position_dim=self.keypoint_position_dim) + + self.gnn = AttentionalGNN(self.descriptor_dim, self.GNN_layers) + + self.final_proj = Conv1d_sp(self.descriptor_dim, self.descriptor_dim, kernel_size=1, bias=True) + + self.bin_score = jt.float(1.0) + + def execute(self, data): + """Run SuperGlue on a pair of keypoints and descriptors""" + + kpts0, kpts1 = data['keypoints0'], data['keypoints1'] + desc0, desc1 = data['descriptors0'], data['descriptors1'] + all_matches = data['all_matches'] + # match_num = data['match_num'] + + if kpts0.shape[1] == 0 or kpts1.shape[1] == 0 or all_matches.shape[1] == 0: # no keypoints or no matches/unmatches + shape0, shape1 = kpts0.shape[:-1], kpts1.shape[:-1] + return { + 'matches0': jt.ones(shape0, dtype=jt.int), + 'matches1': jt.ones(shape1, dtype=jt.int), + 'matching_scores0': jt.zeros(shape0, dtype=jt.float), + 'matching_scores1': jt.zeros(shape1, dtype=jt.float), + 'skip_train': True + } + + # Keypoint normalization. + kpts0 = normalize_keypoints(kpts0, data['shape0']) + kpts1 = normalize_keypoints(kpts1, data['shape1']) + + # Keypoint MLP encoder. + # desc0 = self.des_extend(desc0) + self.kenc(kpts0, data['scores0']) + # desc1 = self.des_extend(desc1) + self.kenc(kpts1, data['scores1']) + desc0 = desc0 + self.kenc(kpts0, data['scores0']) + desc1 = desc1 + self.kenc(kpts1, data['scores1']) + + # Multi-layer Transformer network. + desc0, desc1 = self.gnn(desc0, desc1) + + # Final MLP projection. + desc0, desc1 = self.final_proj(desc0), self.final_proj(desc1) + desc0_t = desc0.t() + losses = [] + + for i in range(0, desc1.shape[0], split_size): + end = min(desc1.shape[0], i + split_size) + + # Compute matching descriptor distance. + scores = nn.bmm(desc0_t[i:end], desc1[i:end]) * self.scale # 457.76 MB + scores.sync() + + # Run the optimal transport. + if self.use_dual_softmax: + scores = add_dustbin(scores, self.bin_score) # 458.68 MB + scores.sync() + dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2) + scores = dual_softmax0 + dual_softmax1 # 458.22 MB + scores.sync() + else: + scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations']) + + # loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])]) + + loss = get_weighted_loss_batch(scores, all_matches[i:end]) + loss.sync() + losses.append(loss) + loss = jt.concat(losses) + ''' + # Compute matching descriptor distance. + scores = nn.bmm(desc0.t(), desc1) * self.scale # 457.76 MB + scores.sync() + + # Run the optimal transport. + if self.use_dual_softmax: + scores = add_dustbin(scores, self.bin_score) # 458.68 MB + scores.sync() + dual_softmax0, dual_softmax1 = nn.log_softmax(scores, 1), nn.log_softmax(scores, 2) + scores = dual_softmax0 + dual_softmax1 # 458.22 MB + scores.sync() + else: + scores = log_optimal_transport(scores, self.bin_score, iters=self.config['sinkhorn_iterations']) + + # loss = torch.stack([get_match_score(scores[b], all_matches[b]) for b in range(all_matches.shape[0])]) + + loss = get_weighted_loss_batch(scores, all_matches) + # print(scores.shape, all_matches.shape, loss.shape) + ''' + + # matches0, matches1 = all_matches.chunk(chunks=2, dim=2) + # batchIdx = jt.arange(0, b).unsqueeze(1).repeat(1, num) + # batchIdx, matches0, matches1 = batchIdx.view(-1), matches0.view(-1), matches1.view(-1) + # validmatch = (matches0 >= 0) | (matches1 >= 0) + # batchIdx, matches0, matches1 = batchIdx[validmatch], matches0[validmatch], matches1[validmatch] + # matches0[matches0 == -1] = n + # matches1[matches1 == -1] = m + # loss_mean = -scores[(batchIdx, matches0, matches1)].mean() + # loss_mean = nn.l1_loss(loss_mean, jt.float(0.0)) + + if not data['return_match']: + return {'loss': loss} + + with jt.no_grad(): + b, n, m = scores.shape + # Get the matches with score above "match_threshold". + indices0, max0 = scores[:, :-1, :-1].argmax(2) + indices1, max1 = scores[:, :-1, :-1].argmax(1) + mutual0 = jt.arange(0, n)[None] == indices1.gather(1, indices0) + mutual1 = jt.arange(0, m)[None] == indices0.gather(1, indices1) + # zero = scores.new_tensor(0) + # mscores0 = torch.where(mutual0, max0.values.exp(), zero) + mscores0 = max0.exp() + mscores0[mutual0.logical_not()] = 0 + # mscores1 = torch.where(mutual1, mscores0.gather(1, indices1), zero) + mscores1 = mscores0.gather(1, indices1) + mscores1[mutual1.logical_not()] = 0 + valid0 = mutual0 & (mscores0 > self.match_threshold) + valid1 = mutual1 & valid0.gather(1, indices1) + # indices0 = torch.where(valid0, indices0, indices0.new_tensor(-1)) + # indices1 = torch.where(valid1, indices1, indices1.new_tensor(-1)) + indices0[valid0.logical_not()] = -1 + indices1[valid1.logical_not()] = -1 + + return { + 'matches0': indices0, # use -1 for invalid match + 'matches1': indices1, # use -1 for invalid match + 'matching_scores0': mscores0, + 'matching_scores1': mscores1, + 'loss': loss, + } + + # scores big value or small value means confidence? log can't take neg value \ No newline at end of file diff --git a/python/jittor/test/perf/perf.py b/python/jittor/test/perf/perf.py new file mode 100644 index 00000000..93e77ef1 --- /dev/null +++ b/python/jittor/test/perf/perf.py @@ -0,0 +1,225 @@ +import sys, os + +suffix = "" + +import jittor as jt +import time +import jittor_utils as jit_utils +home_path = jit_utils.home() +perf_path = os.path.join(home_path, ".cache", "jittor_perf") + +def main(): + os.makedirs(perf_path+"/src/jittor", exist_ok=True) + os.makedirs(perf_path+"/src/jittor_utils", exist_ok=True) + os.system(f"cp -rL {jt.flags.jittor_path} {perf_path+'/src/'}") + os.system(f"cp -rL {jt.flags.jittor_path}/../jittor_utils {perf_path+'/src/'}") + use_torch_1_4 = os.environ.get("use_torch_1_4", "0") == "1" + dockerfile_src = r""" +FROM nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04 + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install + numpy \ + tqdm \ + pillow \ + astunparse + +RUN pip3 install torch torchvision +""" + global suffix + if use_torch_1_4: + suffix = "_1_4" + dockerfile_src = dockerfile_src.replace("torch ", "torch==1.4.0 ") + dockerfile_src = dockerfile_src.replace("torchvision", "torchvision==0.5.0") + with open("/tmp/perf_dockerfile", 'w') as f: + f.write(dockerfile_src) + assert os.system("sudo nvidia-smi -lgc 1500") == 0 + + # if the docker image is not built + if os.system(f"sudo docker image inspect jittor/jittor-perf{suffix}"): + assert os.system(f"sudo docker build --tag jittor/jittor-perf{suffix} -f /tmp/perf_dockerfile .") == 0 + + # run once for compile source + jt_fps = test_main("jittor", "resnet50", 1) + + logs = "" + # resnext50_32x4d with bs=8 cannot pass this test + #### inference test + for model_name in ["resnet50", "wide_resnet50_2", # "resnext50_32x4d", + "resnet152", "wide_resnet101_2", "resnext101_32x8d", + "alexnet", "vgg11", "squeezenet1_1", "mobilenet_v2", + "densenet121", "densenet169", "densenet201", + "res2net50", "res2net101"]: + for bs in [1, 2, 4, 8, 16, 32, 64, 128]: + jt_fps = test_main("jittor", model_name, bs) + logs += f"jittor-{model_name}-{bs} {jt_fps}\n" + tc_fps = test_main("torch", model_name, bs) + logs += f"torch-{model_name}-{bs} {tc_fps}\n" + logs += f"compare-{model_name}-{bs} {jt_fps/tc_fps}\n" + print(logs) + #### train test + for model_name in ["train_resnet50", "train_resnet101" + ]: + for bs in [1, 2, 4, 8, 16, 32, 64, 128]: + jt_fps = test_main("jittor", model_name, bs) + logs += f"jittor-{model_name}-{bs} {jt_fps}\n" + tc_fps = test_main("torch", model_name, bs) + logs += f"torch-{model_name}-{bs} {tc_fps}\n" + logs += f"compare-{model_name}-{bs} {jt_fps/tc_fps}\n" + print(logs) + with open(f"{perf_path}/jittor-perf{suffix}-latest.txt", "w") as f: + f.write(logs) + from datetime import datetime + with open(f"{perf_path}/jittor-perf{suffix}-{datetime.now()}.txt", "w") as f: + f.write(logs) + +def test_main(name, model_name, bs): + cmd = f"sudo docker run --gpus all --rm -v {perf_path}:/root/.cache/jittor --network host jittor/jittor-perf{suffix} bash -c 'PYTHONPATH=/root/.cache/jittor/src python3.7 /root/.cache/jittor/src/jittor/test/perf/perf.py {name} {model_name} {bs}'" + fps = -1 + try: + print("run cmd:", cmd) + if os.system(cmd) == 0: + with open(f"{perf_path}/{name}-{model_name}-{bs}.txt", 'r') as f: + fps = float(f.read().split()[3]) + except: + pass + return fps + +def time_iter(duration=2, min_iter=5): + start = time.time() + for i in range(10000000): + yield i + end = time.time() + if end-start>duration and i>=min_iter: + return + +def test(name, model_name, bs): + print("hello", name, model_name, bs) + import numpy as np + import time + is_train = False + _model_name = model_name + if model_name.startswith("train_"): + is_train = True + model_name = model_name[6:] + if name == "torch": + import torch + import torchvision.models as tcmodels + from torch import optim + from torch import nn + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + model = tcmodels.__dict__[model_name]() + model = model.cuda() + else: + import jittor as jt + from jittor import optim + from jittor import nn + jt.flags.use_cuda = 1 + jt.cudnn.set_algorithm_cache_size(10000) + import jittor.models as jtmodels + model = jtmodels.__dict__[model_name]() + if (model == "resnet152" or model == "resnet101") and bs == 128 and is_train: + jt.cudnn.set_max_workspace_ratio(0.05) + if is_train: + model.train() + else: + model.eval() + img_size = 224 + if model_name == "inception_v3": + img_size = 300 + test_img = np.random.random((bs, 3, img_size, img_size)).astype("float32") + if is_train: + label = (np.random.random((bs,)) * 1000).astype("int32") + if name == "torch": + test_img = torch.Tensor(test_img).cuda() + if is_train: + label = torch.LongTensor(label).cuda() + opt = optim.SGD(model.parameters(), 0.001) + sync = lambda: torch.cuda.synchronize() + jt = torch + else: + test_img = jt.array(test_img).stop_grad() + if is_train: + label = jt.array(label).stop_grad() + opt = optim.SGD(model.parameters(), 0.001) + sync = lambda: jt.sync_all(True) + + sync() + use_profiler = os.environ.get("use_profiler", "0") == "1" + if hasattr(jt, "nograd"): + ng = jt.no_grad() + ng.__enter__() + def iter(): + x = model(test_img) + if isinstance(x, tuple): + x = x[0] + if is_train: + loss = nn.CrossEntropyLoss()(x, label) + if name == "jittor": + opt.step(loss) + else: + opt.zero_grad() + loss.backward() + opt.step() + else: + if name == "jittor": + x.sync() + sync() + for i in time_iter(): + iter() + sync() + for i in time_iter(): + iter() + sync() + if use_profiler: + if name == "torch": + prof = torch.autograd.profiler.profile(use_cuda=True) + else: + prof = jt.profile_scope() + prof.__enter__() + if name == "jittor": + if hasattr(jt.flags, "use_parallel_op_compiler"): + jt.flags.use_parallel_op_compiler = 0 + start = time.time() + for i in time_iter(10): + iter() + sync() + end = time.time() + if use_profiler: + prof.__exit__(None,None,None) + if name == "torch": + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=30)) + total_iter = i+1 + print("duration:", end-start, "FPS:", total_iter*bs/(end-start)) + fpath = f"{home_path}/.cache/jittor/{name}-{_model_name}-{bs}.txt" + with open(fpath, 'w') as f: + f.write(f"duration: {end-start} FPS: {total_iter*bs/(end-start)}") + os.chmod(fpath, 0x666) + +if len(sys.argv) <= 1: + main() +else: + name, model, bs = sys.argv[1:] + bs = int(bs) + test(name, model, bs) \ No newline at end of file diff --git a/python/jittor/test/system/test_all.sh b/python/jittor/test/system/test_all.sh new file mode 100644 index 00000000..b4a2ddec --- /dev/null +++ b/python/jittor/test/system/test_all.sh @@ -0,0 +1,6 @@ +bash python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh +bash python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh +bash python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh +bash python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh +bash python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh +bash python/jittor/test/system/test_nocuda_ubuntu18.04.sh diff --git a/python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh b/python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh new file mode 100644 index 00000000..bb1de201 --- /dev/null +++ b/python/jittor/test/system/test_cuda10.0_ubuntu16.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda10.0-ubuntu16.04.dockerfile <<\EOF +FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install ./jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:10.0-16.04 -f /tmp/cuda10.0-ubuntu16.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:10.0-18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh b/python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh new file mode 100644 index 00000000..cde38606 --- /dev/null +++ b/python/jittor/test/system/test_cuda10.0_ubuntu18.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda10.0-ubuntu18.04.dockerfile <<\EOF +FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install ./jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:10.0-18.04 -f /tmp/cuda10.0-ubuntu18.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:10.0-18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh b/python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh new file mode 100644 index 00000000..0cdbf2f6 --- /dev/null +++ b/python/jittor/test/system/test_cuda11.1_ubuntu16.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda11.1-ubuntu16.04.dockerfile <<\EOF +FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu16.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install ./jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:11.1-16.04 -f /tmp/cuda11.1-ubuntu16.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-16.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh b/python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh new file mode 100644 index 00000000..6c8c409f --- /dev/null +++ b/python/jittor/test/system/test_cuda11.1_ubuntu18.04.sh @@ -0,0 +1,41 @@ +cat > /tmp/cuda11.1-ubuntu18.04.dockerfile <<\EOF +FROM nvidia/cuda:11.1-cudnn8-devel-ubuntu18.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install ./jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:11.1-18.04 -f /tmp/cuda11.1-ubuntu18.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_resnet && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh b/python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh new file mode 100644 index 00000000..6cc05742 --- /dev/null +++ b/python/jittor/test/system/test_cuda11.1_ubuntu20.04.sh @@ -0,0 +1,39 @@ +cat > /tmp/cuda11.1-ubuntu20.04.dockerfile <<\EOF +FROM nvidia/cuda:11.1-devel-ubuntu20.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ focal-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update || true +RUN apt install g++ build-essential libomp-dev python3-dev python3-pip wget -y +RUN python3 -m pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple +WORKDIR /usr/src/ + +RUN wget https://developer.download.nvidia.cn/compute/cuda/repos/ubuntu2004/x86_64/libcudnn8_8.0.5.39-1+cuda11.1_amd64.deb && \ + wget https://developer.download.nvidia.cn/compute/cuda/repos/ubuntu2004/x86_64/libcudnn8-dev_8.0.5.39-1+cuda11.1_amd64.deb && \ + dpkg -i ./libcudnn8_8.0.5.39-1+cuda11.1_amd64.deb ./libcudnn8-dev_8.0.5.39-1+cuda11.1_amd64.deb && \ + rm *.deb +RUN ls + + +RUN pip3 install jittor --timeout 100 && python3 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3 -m pip install ./jittor +RUN python3 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor-cuda:11.1-20.04 -f /tmp/cuda11.1-ubuntu20.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor-cuda:11.1-20.04 bash -c \ +"python3 -m jittor.test.test_example && \ +python3 -m jittor.test.test_resnet && \ +python3 -m jittor.test.test_parallel_pass && \ +python3 -m jittor.test.test_atomic_tuner && \ +python3 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/system/test_nocuda_ubuntu18.04.sh b/python/jittor/test/system/test_nocuda_ubuntu18.04.sh new file mode 100644 index 00000000..02b83332 --- /dev/null +++ b/python/jittor/test/system/test_nocuda_ubuntu18.04.sh @@ -0,0 +1,40 @@ +cat > /tmp/ubuntu18.04.dockerfile <<\EOF +FROM ubuntu:18.04 + +RUN apt update && apt install ca-certificates -y + +RUN echo \ +"deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-updates main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-backports main restricted universe multiverse\n\ +deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security main restricted universe multiverse" > /etc/apt/sources.list + +# RUN rm -rf /var/lib/apt/lists/* +RUN apt update + +RUN apt install wget \ + python3.7 python3.7-dev \ + g++ build-essential -y + +WORKDIR /usr/src + +RUN apt download python3-distutils && dpkg-deb -x ./python3-distutils* / \ + && wget -O - https://bootstrap.pypa.io/get-pip.py | python3.7 + +# change tsinghua mirror +RUN pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple + +RUN pip3 install jittor --timeout 100 && python3.7 -m jittor.test.test_example +RUN pip3 uninstall jittor -y + +COPY . jittor +RUN python3.7 -m pip install ./jittor +RUN python3.7 -m jittor.test.test_core +EOF + +sudo docker build --tag jittor/jittor:18.04 -f /tmp/ubuntu18.04.dockerfile . +sudo docker run --gpus all --rm jittor/jittor:18.04 bash -c \ +"python3.7 -m jittor.test.test_example && \ +python3.7 -m jittor.test.test_parallel_pass && \ +python3.7 -m jittor.test.test_atomic_tuner && \ +python3.7 -m jittor.test.test_where_op" \ No newline at end of file diff --git a/python/jittor/test/test.h b/python/jittor/test/test.h new file mode 100644 index 00000000..a03dfb26 --- /dev/null +++ b/python/jittor/test/test.h @@ -0,0 +1,23 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: Dun Liang . +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include + +using namespace std; + +void test_main(); + +void expect_error(function func); + +int main() { + try { + test_main(); + } catch (const std::exception& e) { + std::cout << e.what() << std::endl; + return 1; + } +} \ No newline at end of file diff --git a/python/jittor/test/test_acl.py b/python/jittor/test/test_acl.py new file mode 100644 index 00000000..6551ddba --- /dev/null +++ b/python/jittor/test/test_acl.py @@ -0,0 +1,225 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from .test_core import expect_error +import numpy as np +from jittor import init, Module +import numpy as np + + +@unittest.skipIf(not jt.compiler.has_acl, "No ACL found") +class TestACL(unittest.TestCase): + + @jt.flag_scope(use_acl=1) + def test_array(self): + print("use_acl", jt.flags.use_acl) + a = jt.array([1,2,3]) + np.testing.assert_allclose(a.numpy(), [1,2,3]) + print('test_array pass') + + @jt.flag_scope(use_acl=1) + def test_add(self): + a = jt.array([1,2,3]) + b = a+a + np.testing.assert_allclose(b.numpy(), [2,4,6]) + print('test_add pass') + + @jt.flag_scope(use_acl=1) + def test_add_float(self): + a = jt.array([1.0,2.0,3.0]) + b = a+a + np.testing.assert_allclose(b.numpy(), [2,4,6]) + print('test_add_float pass') + + @jt.flag_scope(use_acl=1) + def test_array_cast(self): + # this test cannot pass because cast error + x = np.random.rand(10) + y = jt.float32(x) + np.testing.assert_allclose(x, y.numpy()) + print('test_array_cast pass') + + @jt.flag_scope(use_acl=1) + def test_array_cast_half(self): + # this test cannot pass because cast error + x = np.random.rand(10).astype("float32") + y = jt.float16(x) + np.testing.assert_allclose(x.astype("float16"), y.numpy()) + print('test_array_cast_half pass') + + @jt.flag_scope(use_acl=1) + def test_rand(self): + a = jt.rand(10) + b = a*10 + b.sync() + print(b) + + def test_meminfo(self): + jt.display_memory_info() + print('test_meminfo pass') + + @jt.flag_scope(use_acl=1) + def test_conv(self): + x = jt.rand(10, 3, 50, 50) + w = jt.rand(4,3,3,3) + # x = jt.rand(2, 2, 1, 1) + # w = jt.rand(2,2,1,1) + y = jt.nn.conv2d(x, w) + y.sync(True) + y1 = y.data + mask = jt.rand_like(y) + dx, dw = jt.grad((y*mask).sum(), [x, w]) + dx1, dw1 = dx.data, dw.data + # dw, = jt.grad((y*mask).sum(), [w]) + # dw1 = dw.data + with jt.flag_scope(use_acl=0): + y = jt.nn.conv2d(x, w) + y2 = y.data + dx, dw = jt.grad((y*mask).sum(), [x, w]) + dx2, dw2 = dx.data, dw.data + # dw, = jt.grad((y*mask).sum(), [w]) + # dw2 = dw.data + np.testing.assert_allclose(y1, y2) + np.testing.assert_allclose(dx1, dx2) + np.testing.assert_allclose(dw1, dw2) + print('test_conv pass') + + @jt.flag_scope(use_acl=1) + def test_matmul(self): + # x = jt.rand(10, 3, 50, 50) + # w = jt.rand(4,3,3,3) + x = jt.rand(10,10) + w = jt.rand(10,10) + y = jt.matmul(x, w) + ny = np.matmul(x.numpy(), w.numpy()) + np.testing.assert_allclose(y.numpy(), ny, atol=1e-3, rtol=1e-3) + print('test_matmul pass') + + @jt.flag_scope(use_acl=1) + def test_max(self): + x = jt.rand(3,3) + y = x.max(1).data + ny = x.data.max(1) + np.testing.assert_allclose(y, ny) + print('test_max pass') + + @jt.flag_scope(use_acl=1) + def test_sum(self): + x = jt.rand(3,3).float16() + print(x) + # return + y = x.sum(1).data + print(y) + print(x) + ny = x.data.sum(1) + np.testing.assert_allclose(y, ny) + print('test_sum pass') + + @jt.flag_scope(use_acl=1) + def test_broadcast(self): + x = jt.rand(3) + # print(x) + y = x.broadcast([3,3]).data + ny = np.broadcast_arrays(x.data, y)[0] + np.testing.assert_allclose(y, ny) + print(x, y) + # y = x.broadcast([3,3], dims=[1]).data + y = jt.broadcast(x, shape=(3,3), dims=[1]).data + with jt.flag_scope(use_acl=0): + ny = jt.broadcast(x, shape=(3,3), dims=[1]).data + # ny = np.broadcast_arrays(x.data, y)[0] + np.testing.assert_allclose(y, ny) + print(x, y) + print('test_broadcast pass') + + @jt.flag_scope(use_acl=1) + def test_resnet(self): + from jittor.models import resnet50 + net = resnet50() + x = jt.rand(2,3,224,224) + y = net(x) + y.sync() + + + +def matmul(a, b): + (n, m), k = a.shape, b.shape[-1] + a = a.broadcast([n,m,k], dims=[2]) + b = b.broadcast([n,m,k], dims=[0]) + return (a*b).sum(dim=1) + +class Linear(Module): + def __init__(self, in_features, out_features, bias=True): + self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5 + self.b = jt.random((out_features,))-0.5 if bias else None + def execute(self, x): + x = matmul(x, self.w) + if self.b is not None: + return x+self.b + return x + +def relu(x): + return jt.maximum(x, 0.0) +Relu = jt.make_module(relu) + +class Model(Module): + def __init__(self, input_size): + self.linear1 = Linear(input_size, 10) + self.relu1 = Relu() + self.linear2 = Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + +@unittest.skipIf(not jt.compiler.has_acl, "No ACL found") +class TestExample(unittest.TestCase): + @jt.flag_scope(use_acl=1) + def test1(self): + np.random.seed(0) + jt.set_seed(3) + n = 1000 + batch_size = 50 + lr = 0.05 + + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1).astype("float32") + y = x*x + yield jt.float32(x), jt.float32(y) + + model = Model(input_size=1) + ps = model.parameters() + + for i,(x,y) in enumerate(get_data(n)): + jt.sync_all(True) + pred_y = model(x).name("pred_y") + loss = ((pred_y - y).sqr()).name("loss") + loss_mean = loss.mean() + + gs = jt.grad(loss_mean, ps) + for p, g in zip(ps, gs): + p -= g * lr + + if i>2: + assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}" + prev = jt.liveness_info() + print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}") + + possible_results = [ + 0.0009948202641680837, + 0.001381353591568768, + 0.00110957445576787, + ] + loss_mean = loss_mean.data + assert any(abs(loss_mean - r) < 1e-6 for r in possible_results) + + jt.clean() + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_aclop.py b/python/jittor/test/test_aclop.py new file mode 100644 index 00000000..196d2576 --- /dev/null +++ b/python/jittor/test/test_aclop.py @@ -0,0 +1,314 @@ +import unittest +import jittor as jt +from .test_core import expect_error +import numpy as np +from jittor import init, Module +import numpy as np + + +@unittest.skipIf(not jt.compiler.has_acl, "No ACL found") +class TestACL(unittest.TestCase): + + @jt.flag_scope(use_acl=1) + def test_getitem(self): + a = jt.ones(100, 2) + b = a[0:2, 0:2] + np.testing.assert_allclose(b.numpy(), [[1, 1], [1, 1]]) + print("test getitem success") + + @jt.flag_scope(use_acl=1) + def test_setitem(self): + a = jt.ones(2, 2) + b = jt.Var(0) + a[0:1, 0:1] = b + np.testing.assert_allclose(a.numpy(), [[0, 1], [1, 1]]) + print("test setitem success") + + @jt.flag_scope(use_acl=1) + def test_getitem_grad(self): + a = jt.ones(2, 2) + b = a[0:1, 0:1] + optimizer = jt.optim.SGD([a], 0.1) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[1, 0], [0, 0]]) + print("test getitem grad success") + + @jt.flag_scope(use_acl=1) + def test_setitem_grad(self): + a = jt.ones(3, 3) + b = jt.ones(2, 2) + a[0:2, 0:2] = b * 2 + optimizer = jt.optim.SGD([a, b], 0.1) + loss = a.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), + [[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]]) + print("test setitem grad success") + + @jt.flag_scope(use_acl=1) + def test_concat(self): + a = jt.ones(2, 2) + b = jt.ones(2, 2) + c = jt.concat([a, b], 0) + np.testing.assert_allclose(c.numpy(), [[1, 1], [1, 1], [1, 1], [1, 1]]) + print("test concat success") + + @jt.flag_scope(use_acl=1) + def test_maxpool_grad(self): + a = jt.ones(1, 1, 4, 4) + max_pool = jt.nn.Pool(2, op='maximum') + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[1, 0, 1, 0], [0, 0, 0, 0], [1, 0, 1, 0], [0, 0, 0, 0]]]]) + print("test maxpool grad success") + + @jt.flag_scope(use_acl=1) + def test_triu(self): + a = jt.ones(3, 3) + b = jt.triu_(a, 0) + c = jt.triu_(a, 1) + np.testing.assert_allclose(b.numpy(), + [[1, 1, 1], [0, 1, 1], [0, 0, 1]]) + np.testing.assert_allclose(c.numpy(), + [[0, 1, 1], [0, 0, 1], [0, 0, 0]]) + print("test triu success") + + @jt.flag_scope(use_acl=1) + def test_bmm(self): + a = jt.ones(3, 2, 2).float32() + b = jt.bmm(a, a) + np.testing.assert_allclose( + b.numpy(), [[[2, 2], [2, 2]], [[2, 2], [2, 2]], [[2, 2], [2, 2]]]) + print("test bmm success") + + @jt.flag_scope(use_acl=1) + def test_matmul(self): + a = jt.ones(1, 4, 4) + b = jt.ones(4, 2) + c = jt.matmul(a, b) + np.testing.assert_allclose(c.numpy(), + [[[4, 4], [4, 4], [4, 4], [4, 4]]]) + print("test matmul success") + + @jt.flag_scope(use_acl=1) + def test_maxpool(self): + a = jt.ones(1, 1, 4, 4) + max_pool = jt.nn.Pool(2, op='maximum') + np.testing.assert_allclose(max_pool(a).numpy(), [[[[1, 1], [1, 1]]]]) + print("test maxpool success") + + @jt.flag_scope(use_acl=1) + def test_transpose(self): + a = jt.ones(1, 2, 2) + b = a.transpose(0, 2) + np.testing.assert_allclose(b.numpy(), [[[1], [1]], [[1], [1]]]) + print("test transpose success") + + @jt.flag_scope(use_acl=1) + def test_matmul_grad(self): + a = jt.ones(1, 2, 2) + b = jt.ones(2, 2) + optimizer = jt.optim.SGD([a, b], 0.1) + loss = jt.matmul(a, b).sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[[2, 2], [2, 2]]]) + np.testing.assert_allclose(res_b.numpy(), [[2, 2], [2, 2]]) + print("test matmul grad success") + + @jt.flag_scope(use_acl=1) + def test_bmm_grad(self): + a = jt.ones(3, 2, 2).float32() + optimizer = jt.optim.SGD([a], 0.1) + c = jt.bmm(a, a) + loss = c.sum() + + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[4, 4], [4, 4]], [[4, 4], [4, 4]], [[4, 4], [4, 4]]]) + print("test bmm grad success") + + @jt.flag_scope(use_acl=1) + def test_avgpool(self): + a = jt.ones(1, 1, 4, 4) + avg_pool = jt.nn.Pool(2, op='mean') + b = avg_pool(a) + np.testing.assert_allclose(b.numpy(), [[[[1, 1], [1, 1]]]]) + print("test avgpool success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + pool = jt.nn.AdaptiveMaxPool2d((2, 2)) + b = pool(a) + np.testing.assert_allclose(b.numpy(), [[[[6, 8], [14, 16]]]]) + print("test adaptive_maxpool2d success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_maxpool2d_grad(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + max_pool = jt.nn.AdaptiveMaxPool2d((2, 2)) + optimizer = jt.optim.SGD([a], 0.1) + b = max_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0, 0, 0, 0], [0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]]]]) + print("test adaptive_maxpool2d grad success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + pool = jt.nn.AdaptiveAvgPool2d((2, 2)) + b = pool(a) + np.testing.assert_allclose(b.numpy(), [[[[3.5, 5.5], [11.5, 13.5]]]]) + print("test adaptive_avgpool2d success") + + @jt.flag_scope(use_acl=1) + def test_adaptive_avgpool2d_grad(self): + a = jt.float32([[[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [13, 14, 15, 16]]]]) + avg_pool = jt.nn.AdaptiveAvgPool2d((2, 2)) + optimizer = jt.optim.SGD([a], 0.1) + b = avg_pool(a) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose( + res.numpy(), + [[[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25], + [0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]]]]) + print("test adaptive_avgpool2d grad success") + + @jt.flag_scope(use_acl=1) + def test_index(self): + a = jt.ones(2, 3) + [s1, s2] = jt.index(a.shape) + np.testing.assert_allclose(s1.numpy(), [[0, 0, 0], [1, 1, 1]]) + np.testing.assert_allclose(s2.numpy(), [[0, 1, 2], [0, 1, 2]]) + print("test index success") + + @jt.flag_scope(use_acl=1) + def test_gather(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) + np.testing.assert_allclose(b.numpy(), [[1, 1], [4, 3]]) + print("test gather success") + + @jt.flag_scope(use_acl=1) + def test_gather_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + optimizer = jt.optim.SGD([a], 0.1) + b = jt.gather(a, 1, jt.array([[0, 0], [1, 0]])) + loss = b.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[2, 0], [1, 1]]) + print("test gather grad success") + + @jt.flag_scope(use_acl=1) + def test_scatter(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.array([[0, 0], [0, 0]]) + b = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") + np.testing.assert_allclose(b.numpy(), [[3, 0], [4, 3]]) + print("test scatter success") + + @jt.flag_scope(use_acl=1) + def test_scatter_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.float32([[0, 0], [0, 0]]) + optimizer = jt.optim.SGD([a, b], 0.1) + c = jt.scatter(b, 1, jt.array([[0, 0], [1, 0]]), a, reduce="add") + loss = c.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [0, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[0, 0], [1, 0]]) + print("test scatter grad success") + + @jt.flag_scope(use_acl=1) + def test_where(self): + a = jt.array([[1, 2], [3, 4]]) + b = jt.ones(2, 2) + c = jt.where(a > 2, a, b) + np.testing.assert_allclose(c.numpy(), [[1, 1], [3, 4]]) + print("test where success") + + @jt.flag_scope(use_acl=1) + def test_where_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + b = jt.array([[2., 2.], [2., 2.]]) + c = jt.where(a > 2, a, b) + optimizer = jt.optim.SGD([a, b], 0.1) + loss = c.sum() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res_a = a.opt_grad(optimizer) + res_b = b.opt_grad(optimizer) + np.testing.assert_allclose(res_a.numpy(), [[0, 0], [1, 1]]) + np.testing.assert_allclose(res_b.numpy(), [[1, 1], [0, 0]]) + print("test where grad success") + + @jt.flag_scope(use_acl=1) + def test_flip(self): + a = jt.array([[1., 2.], [3., 4.]]) + b = a.flip((0, 1)) + np.testing.assert_allclose(b.numpy(), [[4, 3], [2, 1]]) + print("test flip success") + + @jt.flag_scope(use_acl=1) + def test_flip_grad(self): + a = jt.float32([[1, 2], [3, 4]]) + optimizer = jt.optim.SGD([a], 0.1) + b = a.flip((0, 1)) + loss = b.max() + optimizer.zero_grad() + optimizer.backward(loss) + optimizer.step() + res = a.opt_grad(optimizer) + np.testing.assert_allclose(res.numpy(), [[0, 0], [0, 1]]) + print("test flip grad success") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_adamw.py b/python/jittor/test/test_adamw.py new file mode 100644 index 00000000..78269acf --- /dev/null +++ b/python/jittor/test/test_adamw.py @@ -0,0 +1,89 @@ + +import jittor as jt +import random +import numpy as np +import unittest + + +class TestAdamw(unittest.TestCase): + def test(self): + import torch + + LR = 0.01 + BATCH_SIZE = 32 + EPOCH = 12 + WD = 0.1 + N = 1024 + + # data + x = [] + y = [] + for i in range(N): + x.append(-1 + i * 2 / N) + random.shuffle(x) + x = np.array(x) + y = x * x + np.random.randn(N) * 0.1 + + class NetTorch(torch.nn.Module): + def __init__(self): + super(NetTorch, self).__init__() + self.hidden = torch.nn.Linear(1, 20) # hidden layer + self.predict = torch.nn.Linear(20, 1) # output layer + + def forward(self, x): + x = torch.nn.functional.relu(self.hidden(x)) # activation function for hidden layer + x = self.predict(x) # linear output + return x + + class NetJittor(jt.Module): + def __init__(self): + super(NetJittor, self).__init__() + self.hidden = jt.nn.Linear(1, 20) # hidden layer + self.predict = jt.nn.Linear(20, 1) # output layer + + def execute(self, x): + x = jt.nn.relu(self.hidden(x)) # activation function for hidden layer + x = self.predict(x) # linear output + return x + + net_torch = NetTorch() + optim_torch = torch.optim.AdamW(net_torch.parameters(), lr=LR, betas=(0.9, 0.99), weight_decay = WD) + Loss_torch = torch.nn.MSELoss() + + net_jittor = NetJittor() + net_jittor.hidden.weight = jt.array(net_torch.hidden.weight.detach().numpy()) + net_jittor.hidden.bias = jt.array(net_torch.hidden.bias.detach().numpy()) + net_jittor.predict.weight = jt.array(net_torch.predict.weight.detach().numpy()) + net_jittor.predict.bias = jt.array(net_torch.predict.bias.detach().numpy()) + optim_jittor = jt.optim.AdamW(net_jittor.parameters(), lr=LR, betas=(0.9, 0.99), weight_decay = WD) + Loss_jittor = jt.nn.MSELoss() + + for epoch in range(EPOCH): + # print('Epoch: ', epoch) + + for i in range(N // BATCH_SIZE): + bx = x[i * BATCH_SIZE : (i + 1) * BATCH_SIZE, np.newaxis] + by = y[i * BATCH_SIZE : (i + 1) * BATCH_SIZE, np.newaxis] + + bx_torch = torch.Tensor(bx) + by_torch = torch.Tensor(by) + output_torch = net_torch(bx_torch) + loss_torch = Loss_torch(output_torch, by_torch) + optim_torch.zero_grad() + loss_torch.backward() + optim_torch.step() + + bx_jittor = jt.array(bx) + by_jittor = jt.array(by) + output_jittor = net_jittor(bx_jittor) + loss_jittor = Loss_jittor(output_jittor, by_jittor) + optim_jittor.step(loss_jittor) + + lt = float(loss_torch.detach().numpy()) + lj = float(loss_jittor.data) + # print(abs(lt - lj)) + assert abs(lt - lj) < 1e-5 + +if __name__ == "__main__": + unittest.main() + \ No newline at end of file diff --git a/python/jittor/test/test_affine_grid.py b/python/jittor/test/test_affine_grid.py new file mode 100644 index 00000000..291651f6 --- /dev/null +++ b/python/jittor/test/test_affine_grid.py @@ -0,0 +1,74 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor.nn import affine_grid,grid_sample + + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch.nn.functional as F + import torch +except: + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestAffineGrid(unittest.TestCase): + def test_affine_grid_2d(self): + N = 8 + C = 3 + H = 256 + W = 128 + theta = np.random.randn(N,2,3).astype(np.float32) + features = np.random.randint(256,size=(N,C,H,W)).astype(np.float32) + + torch_theta = torch.Tensor(theta) + torch_features = torch.Tensor(features) + torch_grid = F.affine_grid(torch_theta,size=(N,C,H,W),align_corners=False) + torch_sample = F.grid_sample(torch_features,torch_grid,mode='bilinear',padding_mode='zeros',align_corners=False) + + jt_theta = jt.array(theta) + jt_features = jt.array(features) + jt_grid = affine_grid(jt_theta,size=(N,C,H,W),align_corners=False) + jt_sample = grid_sample(jt_features,jt_grid,mode='bilinear',padding_mode='zeros',align_corners=False) + + assert np.allclose(jt_theta.numpy(),torch_theta.numpy()) + assert np.allclose(jt_features.numpy(),torch_features.numpy()) + assert np.allclose(jt_grid.numpy(),torch_grid.numpy(),atol=1e-05) + assert np.allclose(torch_sample.numpy(),jt_sample.numpy(),atol=1e-01) + + + def test_affine_grid_3d(self): + N = 8 + C = 3 + D = 64 + H = 256 + W = 128 + theta = np.random.randn(N,3,4).astype(np.float32) + features = np.random.randint(256,size=(N,C,D,H,W)).astype(np.float32) + + torch_theta = torch.Tensor(theta) + torch_features = torch.Tensor(features) + torch_grid = F.affine_grid(torch_theta,size=(N,C,D,H,W),align_corners=False) + torch_sample = F.grid_sample(torch_features,torch_grid,mode='bilinear',padding_mode='zeros',align_corners=False) + + jt_theta = jt.array(theta) + jt_features = jt.array(features) + jt_grid = affine_grid(jt_theta,size=(N,C,D,H,W),align_corners=False) + jt_sample = grid_sample(jt_features,jt_grid,mode='bilinear',padding_mode='zeros',align_corners=False) + + assert np.allclose(jt_theta.numpy(),torch_theta.numpy()) + assert np.allclose(jt_features.numpy(),torch_features.numpy()) + assert np.allclose(jt_grid.numpy(),torch_grid.numpy(),atol=1e-05) + assert np.allclose(torch_sample.numpy(),jt_sample.numpy(),atol=1e-01) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_allocator.py b/python/jittor/test/test_allocator.py new file mode 100644 index 00000000..9bbc9f4d --- /dev/null +++ b/python/jittor/test/test_allocator.py @@ -0,0 +1,38 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import gc + +class TestAllocator(unittest.TestCase): + def test_stat(self): + jt.clean() + with jt.flag_scope(use_stat_allocator=1, use_sfrl_allocator = 0): + a = jt.random([10,10]) + b = a+a + c = a*b + c.data + del a,b,c + gc.collect() + assert jt.flags.stat_allocator_total_alloc_call == 2 + assert jt.flags.stat_allocator_total_alloc_byte == 800 + assert jt.flags.stat_allocator_total_free_call == 2 + assert jt.flags.stat_allocator_total_free_byte == 800 + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1, use_cuda_managed_allocator=0) + def test_device_allocator(self): + a = jt.array([1,2,3,4,5]) + b = a + 1 + c = jt.code(a.shape, a.dtype, [b], cpu_src=""" + for (int i=0; i +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import gc + +def test(h, w, total_alloc_call, total_alloc_byte, total_free_call = 0, total_free_byte = 0): + jt.clean() + jt.gc() + with jt.flag_scope(use_stat_allocator=1): + a = jt.random([h,w]) + b = a+a + c = a*b + c.data + del a,b,c + gc.collect() + x = ( + jt.flags.stat_allocator_total_alloc_call, + jt.flags.stat_allocator_total_alloc_byte, + jt.flags.stat_allocator_total_free_call, + jt.flags.stat_allocator_total_free_byte + ) + y = (total_alloc_call, total_alloc_byte, total_free_call, total_free_byte) + assert x==y, (x, y) + + +class TestAllocator2(unittest.TestCase): + def test_stat(self): + #small_block + test(10, 10, 1, 1048576) #800 + #small_block + test(100, 100, 1, 1048576) #80000 + #large_block + test(1000, 1000, 1, 20971520) #8000000 + #large_block2 + test(8000, 1000, 2, 67108864) #64000000 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_arg_pool_op.py b/python/jittor/test/test_arg_pool_op.py new file mode 100644 index 00000000..4ccb692f --- /dev/null +++ b/python/jittor/test/test_arg_pool_op.py @@ -0,0 +1,290 @@ +# *************************************************************** +# Copyright (c) 2019 Dun Liang . All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor.nn import Pool, pool, AvgPool2d, avg_pool2d +from jittor.nn import MaxPool2d as j_MaxPool2d +from jittor.nn import max_pool2d as j_max_pool2d +import numpy as np +from .test_core import expect_error +from .test_grad import ngrad +from itertools import permutations +from jittor import compile_extern, Module +from .test_log import find_log_with_re +import random +import pickle as pk + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch.nn import MaxPool2d, Sequential +except: + skip_this_test = True + +class OldPool(Module): + def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, count_include_pad=True, op="maximum"): + assert dilation == None + assert return_indices == None + self.kernel_size = kernel_size + self.op = op + self.stride = stride if stride else kernel_size + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad and padding != 0 + + def execute(self, x): + N,C,H,W = x.shape + if self.ceil_mode == False: + h = (H+self.padding*2-self.kernel_size)//self.stride+1 + w = (W+self.padding*2-self.kernel_size)//self.stride+1 + else: + h = (H+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1 + w = (W+self.padding*2-self.kernel_size + self.stride - 1)//self.stride+1 + + # TODO: backward + xx = x.reindex([N,C,h,w,self.kernel_size,self.kernel_size], [ + "i0", # Nid + "i1", # Cid + f"i2*{self.stride}-{self.padding}+i4", # Hid + f"i3*{self.stride}-{self.padding}+i5", # Wid + ]) + return xx.reduce(self.op, [4,5]) + + +def check(jt_model, torch_model, shape, near_data): + if (near_data): + assert shape[0] * shape[1] * shape[2] * shape[3] % 8 == 0 + data = list(range(8)) * int((shape[0] * shape[1] * shape[2] * shape[3]) / 8) + random.shuffle(data) + x = jt.array(data).float32().reshape(shape) + else: + x = jt.random(shape) + y = jt_model(x) + g = jt.grad(y.sum(), x) + + x_ = torch.Tensor(x.data) + x_.requires_grad = True + y_ = torch_model(x_) + y_.sum().backward() + y__ = y_.detach().numpy() + g__ = x_.grad.detach().numpy() + assert np.allclose(y.data, y__) + assert np.allclose(g.data, g__) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestArgPoolOp(unittest.TestCase): + @unittest.skipIf(not jt.compiler.has_cuda, "No cuda found") + @jt.flag_scope(use_cuda=1) + def test_cuda(self): + jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1)) + torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1)) + shape = [2, 3, 300, 300] + check(jt_model, torch_model, shape, False) + shape = [2, 3, 157, 300] + check(jt_model, torch_model, shape, False) + for i in range(10): + check(jt_model, torch_model, [1,1,300,300], True) + + @unittest.skipIf(not jt.compiler.has_cuda, "No cuda found") + @jt.flag_scope(use_cuda=1) + def test_cuda_tuple(self): + jt_model = jt.nn.Sequential(Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1), ceil_mode=True), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool(3, 1, 1)) + torch_model = Sequential(MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1), ceil_mode=True), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d(3, 1, 1)) + shape = [2, 3, 300, 300] + check(jt_model, torch_model, shape, False) + shape = [2, 3, 157, 300] + check(jt_model, torch_model, shape, False) + for i in range(10): + check(jt_model, torch_model, [1,1,300,300], True) + + @unittest.skipIf(True, "TODO: cannot pass this test, fix me") + @unittest.skipIf(not jt.compiler.has_cuda, "No cuda found") + @jt.flag_scope(use_cuda=1) + def test_cuda_old_pool(self): + from torch.nn import AvgPool2d + jt_model = OldPool(3, 1, 1, op="mean") + torch_model = AvgPool2d(3, 1, 1) + shape = [64, 64, 300, 300] + check(jt_model, torch_model, shape, False) + shape = [32, 128, 157, 300] + check(jt_model, torch_model, shape, False) + for i in range(10): + check(jt_model, torch_model, [1,1,300,300], True) + + def test_cpu_(self): + # x = jt.random([32, 128, 157, 300]) + x = jt.random([4, 128, 157, 300]) + x = jt.nn.pool(x, 2, "maximum", 0, 2) + + def test_cpu(self): + jt_model = jt.nn.Sequential(Pool(2, 2, 0), Pool(2, 2, 0), Pool(2, 2, 0, ceil_mode=True), Pool(2, 2, 0), Pool(2, 2, 0), Pool(3, 1, 1)) + torch_model = Sequential(MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0, ceil_mode=True), MaxPool2d(2, 2, 0), MaxPool2d(2, 2, 0), MaxPool2d(3, 1, 1)) + # shape = [64, 64, 300, 300] + shape = [4, 64, 300, 300] + check(jt_model, torch_model, shape, False) + # shape = [32, 128, 157, 300] + shape = [4, 128, 157, 300] + check(jt_model, torch_model, shape, False) + for i in range(10): + check(jt_model, torch_model, [1,1,300,300], True) + + def test_cpu_tuple(self): + jt_model = jt.nn.Sequential(Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1), ceil_mode=True), Pool((2,3), (2,3), (1,1)), Pool((2,3), (2,3), (1,1)), Pool(3, 1, 1)) + torch_model = Sequential(MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1), ceil_mode=True), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d((2,3), (2,3), (1,1)), MaxPool2d(3, 1, 1)) + shape = [2, 3, 300, 300] + check(jt_model, torch_model, shape, False) + shape = [2, 3, 157, 300] + check(jt_model, torch_model, shape, False) + for i in range(10): + check(jt_model, torch_model, [1,1,300,300], True) + + def test_index_pool(self): + pool = jt.nn.Pool(2, return_indices=True) + a = jt.randn([10,3,100,100]) + b, idx = pool(a) + idx.sync() + + def test_index_pool2(self): + pool = jt.nn.Pool(2, return_indices=True) + a = jt.array([1,0,0,1, + 0,0,0,0, + 0,0,0,0, + 1,0,0,1]).reshape((1,1,4,4)) + b, idx = pool(a) + assert (idx.data.reshape((4,)) == [0,3,12,15]).all() + + def test_unpool(self): + from jittor import nn + pool = nn.MaxPool2d(2, stride=2, return_indices=True) + unpool = nn.MaxUnpool2d(2, stride=2) + input = jt.array([[[[ 1., 2, 3, 4,0], + [ 5, 6, 7, 8,0], + [ 9, 10, 11, 12,0], + [13, 14, 15, 16,0], + [0, 0, 0, 0, 0]]]]) + output, indices = pool(input) + assert (indices == jt.array([[6,8],[16,18]])).all() + out = unpool(output, indices, output_size=input.shape) + assert (out == jt.array([[[[ 0., 0., 0., 0., 0.], + [ 0., 6., 0., 8., 0.], + [ 0., 0., 0., 0., 0.], + [ 0., 14., 0., 16., 0.], + [ 0., 0., 0., 0., 0.]]]])).all() + + def test_unpool_diff_kernel_stride(self): + from jittor import nn + pool = nn.MaxPool2d(3, stride=2, return_indices=True) + unpool = nn.MaxUnpool2d(3, stride=2) + input = jt.array([[[[ 1., 2, 3, 4, 0], + [ 5, 6, 7, 8, 0], + [ 9, 10, 11, 12,0], + [13, 14, 16, 15,0], + [0, 0, 0, 0, 0]]]]) + output, indices = pool(input) + out = unpool(output, indices, output_size=input.shape) + assert (out == jt.array([[[ + [ 0., 0., 0., 0., 0.,], + [ 0., 0., 0., 0., 0.,], + [ 0., 0., 11., 12., 0.,], + [ 0., 0., 32., 0., 0.,], + [ 0., 0., 0., 0., 0.,]]]])).all() + + + + @unittest.skipIf(not jt.compiler.has_cuda, "No cuda found") + @jt.flag_scope(use_cuda=1) + def test_cuda_avg_pool(self): + self.test_cpu_avg_pool() + + def test_cpu_avg_pool(self): + from torch.nn import AvgPool2d + jt_model = Pool(2, 2, 0, op="mean", ceil_mode=True) + torch_model = AvgPool2d(2, 2, 0, ceil_mode=True) + shape = (2, 16, 33, 33) + check(jt_model, torch_model, shape, False) + + def test_cpu_avg_pool2(self): + from torch.nn import AvgPool2d + jt_model = Pool(3, 1, 1, op="mean", ceil_mode=True) + torch_model = AvgPool2d(3, 1, 1, ceil_mode=True) + shape = (2, 16, 33, 33) + check(jt_model, torch_model, shape, False) + + def test_AvgPool2d(self): + from torch.nn import AvgPool2d as t_AvgPool2d + jt_model = AvgPool2d(3, 1, 1, ceil_mode=True) + torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True) + shape = (2, 16, 33, 33) + check(jt_model, torch_model, shape, False) + + jt_model = AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False) + torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False) + shape = (2, 16, 100, 100) + check(jt_model, torch_model, shape, False) + print('finish') + + def test_avg_pool2d(self): + from torch.nn.functional import avg_pool2d as t_avg_pool2d + arr = np.random.random((2, 16, 33, 33)) + jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True) + torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True) + assert np.allclose(jt_model.numpy(), torch_model.numpy()) + + jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False) + torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False) + assert np.allclose(jt_model.numpy(), torch_model.numpy()) + print('finish') + + def test_MaxPool2d(self): + from torch.nn import MaxPool2d + jt_model = j_MaxPool2d(3, 1, 1, ceil_mode=True) + torch_model = MaxPool2d(3, 1, 1, ceil_mode=True) + shape = (2, 16, 33, 33) + check(jt_model, torch_model, shape, False) + print('finish') + + def test_max_pool2d(self): + from torch.nn.functional import max_pool2d + arr = np.random.random((2, 16, 33, 33)) + jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True) + torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True) + assert np.allclose(jt_model.numpy(), torch_model.numpy()) + + jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1) + torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1) + assert np.allclose(jt_model.numpy(), torch_model.numpy()) + + def test_pool_3d(self): + from torch.nn.functional import max_pool2d + arr = np.random.random((2, 16, 20, 20, 20)).astype("float32") + # arr = np.random.random((1, 1, 1, 5, 5)).astype("float32") + jin = jt.array(arr) + tin = torch.Tensor(arr) + tin.requires_grad = True + jt_model = jt.nn.Pool3d(3,1,1)(jin) + torch_model = torch.nn.MaxPool3d(3,1,1)(tin) + assert np.allclose(jt_model.numpy(), torch_model.detach().numpy()) + + + nout = np.random.random(tuple(jt_model.shape)).astype("float32") + jout = jt_model * nout + tout = torch_model * torch.Tensor(nout) + dj = jt.grad(jout, jin) + + tout.sum().backward() + dt = tin.grad + assert np.allclose(dj.numpy(), dt.numpy()) + + @unittest.skipIf(not jt.compiler.has_cuda, "No cuda found") + @jt.flag_scope(use_cuda=1) + def test_cuda_pool_3d(self): + self.test_pool_3d() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_arg_reduce_op.py b/python/jittor/test/test_arg_reduce_op.py new file mode 100644 index 00000000..357557f1 --- /dev/null +++ b/python/jittor/test/test_arg_reduce_op.py @@ -0,0 +1,131 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import compile_extern +from .test_log import find_log_with_re +import copy +if jt.has_cuda: + from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops +else: + cublas_ops = cudnn_ops = cub_ops = None + +def check_reduce(shape, op, dim, keepdims, is_cuda = False): + with jt.log_capture_scope( + log_silent=1, + log_v=0, log_vprefix="op.cc=100" + ) as raw_log: + x = jt.random(shape) + key, v = jt.arg_reduce(x, op, dim, keepdims) + x_ = x.data + key_ = key.data + v_ = v.data + if (is_cuda): + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "cub_arg_reduce" + ".*)") + assert len(logs)==1 + if op == 'max': + key__ = np.argmax(x_, axis=dim) + v__ = np.max(x_, axis=dim) + else: + key__ = np.argmin(x_, axis=dim) + v__ = np.min(x_, axis=dim) + + if keepdims: + key__ = np.expand_dims(key__, axis=dim) + v__ = np.expand_dims(v__, axis=dim) + assert np.allclose(key_, key__) + assert np.allclose(v_, v__) + +def check_backward(shape, op, dim, keepdims): + x = jt.random(shape) + v, key = jt.arg_reduce(x, op, dim, keepdims) + loss = (key * key).sum() + gs = jt.grad(loss, x) / 2 + assert np.allclose((gs * x).data, (gs * gs).data) + +class TestArgReduceOp(unittest.TestCase): + def test_backward(self): + check_backward([5,5,5], 'min', 0, True) + check_backward([5,5,5], 'min', 2, True) + check_backward([5,5,5], 'min', 1, True) + check_backward([5,], 'min', 0, True) + check_backward([20,20,20,20], 'max', 0, True) + check_backward([20,20,20,20], 'max', 2, True) + check_backward([20,20,20,20], 'max', 1, True) + check_backward([20,20,20,20], 'max', 3, True) + check_backward([5,5,5], 'min', 0, False) + check_backward([5,5,5], 'min', 2, False) + check_backward([5,5,5], 'min', 1, False) + check_backward([5,], 'min', 0, False) + check_backward([20,20,20,20], 'max', 0, False) + check_backward([20,20,20,20], 'max', 2, False) + check_backward([20,20,20,20], 'max', 1, False) + check_backward([20,20,20,20], 'max', 3, False) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_backward_cuda(self): + check_backward([5,5,5], 'min', 0, True) + check_backward([5,5,5], 'min', 2, True) + check_backward([5,5,5], 'min', 1, True) + check_backward([5,], 'min', 0, True) + check_backward([20,20,20,20], 'max', 0, True) + check_backward([20,20,20,20], 'max', 2, True) + check_backward([20,20,20,20], 'max', 1, True) + check_backward([20,20,20,20], 'max', 3, True) + check_backward([5,5,5], 'min', 0, False) + check_backward([5,5,5], 'min', 2, False) + check_backward([5,5,5], 'min', 1, False) + check_backward([5,], 'min', 0, False) + check_backward([20,20,20,20], 'max', 0, False) + check_backward([20,20,20,20], 'max', 2, False) + check_backward([20,20,20,20], 'max', 1, False) + check_backward([20,20,20,20], 'max', 3, False) + + def test(self): + check_reduce([5,5,5], 'min', 0, True) + check_reduce([5,5,5], 'min', 2, True) + check_reduce([5,5,5], 'min', 1, True) + check_reduce([5], 'min', 0, True) + check_reduce([20,20,20,20], 'max', 0, True) + check_reduce([20,20,20,20], 'max', 2, True) + check_reduce([20,20,20,20], 'max', 1, True) + check_reduce([20,20,20,20], 'max', 3, True) + check_reduce([5,5,5], 'min', 0, False) + check_reduce([5,5,5], 'min', 2, False) + check_reduce([5,5,5], 'min', 1, False) + check_reduce([5], 'min', 0, False) + check_reduce([20,20,20,20], 'max', 0, False) + check_reduce([20,20,20,20], 'max', 2, False) + check_reduce([20,20,20,20], 'max', 1, False) + check_reduce([20,20,20,20], 'max', 3, False) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_cuda(self): + check_reduce([5,5,5], 'min', 0, True, True) + check_reduce([5,5,5], 'min', 2, True, True) + check_reduce([5,5,5], 'min', 1, True, True) + check_reduce([5], 'min', 0, True) + check_reduce([20,20,20,20], 'max', 0, True, True) + check_reduce([20,20,20,20], 'max', 2, True, True) + check_reduce([20,20,20,20], 'max', 1, True, True) + check_reduce([20,20,20,20], 'max', 3, True, True) + check_reduce([5,5], 'min', 0, False, True) + check_reduce([5,5,5], 'min', 2, False, True) + check_reduce([5,5,5], 'min', 1, False, True) + check_reduce([5], 'min', 0, False) + check_reduce([20,20,20,20], 'max', 0, False, True) + check_reduce([20,20,20,20], 'max', 2, False, True) + check_reduce([20,20,20,20], 'max', 1, False, True) + check_reduce([20,20,20,20], 'max', 3, False, True) +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_argsort_op.py b/python/jittor/test/test_argsort_op.py new file mode 100644 index 00000000..34a1f0ec --- /dev/null +++ b/python/jittor/test/test_argsort_op.py @@ -0,0 +1,124 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import compile_extern +from .test_log import find_log_with_re +if jt.has_cuda: + from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops +else: + cublas_ops = cudnn_ops = cub_ops = None + +def check_argsort(shape, dim, descending = False): + x = jt.random(shape) + y, y_key = jt.argsort(x, dim=dim, descending=descending) + v = [] + for i in range(len(shape)): + if (i == dim): + v.append(y) + else: + v.append(jt.index(shape, dim=i)) + yk = jt.reindex(x, v) + yk_ = yk.data + y_key_ = y_key.data + x__ = x.data + if descending: + x__ = -x__ + yk__ = np.sort(x__, axis=dim) + if descending: + yk__ = -yk__ + assert np.allclose(y_key_, yk__) + assert np.allclose(yk_, yk__) + +def check_cub_argsort(shape, dim, descending = False): + with jt.log_capture_scope( + log_silent=1, + log_v=0, log_vprefix="op.cc=100" + ) as raw_log: + x = jt.random(shape) + y, y_key = jt.argsort(x, dim=dim, descending=descending) + v = [] + for i in range(len(shape)): + if (i == dim): + v.append(y) + else: + v.append(jt.index(shape, dim=i)) + yk = jt.reindex(x, v) + yk_ = yk.data + y_key_ = y_key.data + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "cub_argsort" + ".*)") + assert len(logs)==1 + x__ = x.data + if descending: + x__ = -x__ + yk__ = np.sort(x__, axis=dim) + if descending: + yk__ = -yk__ + assert np.allclose(y_key_, yk__) + assert np.allclose(yk_, yk__) + +def check_backward(shape, dim, descending = False): + x = jt.random(shape) + y, y_key = jt.argsort(x, dim=dim, descending=descending) + loss = (y_key * y_key).sum() + gs = jt.grad(loss, x) + assert np.allclose(x.data*2, gs.data) + +class TestArgsortOp(unittest.TestCase): + def test(self): + check_argsort([5,5], 0, False) + check_argsort([5,5], 0, True) + check_argsort([5,5], 1, False) + check_argsort([5,5], 1, True) + check_argsort([12, 34, 56, 78], 1, True) + check_argsort([12, 34, 56, 78], 3, True) + check_argsort([12, 34, 56, 78], 2, False) + check_argsort([12, 34, 56, 78], 0, False) + + def test_backward(self): + check_backward([5,5], 0, False) + check_backward([5,5], 0, True) + check_backward([5,5], 1, False) + check_backward([5,5], 1, True) + check_backward([12, 34, 56, 78], 1, True) + check_backward([12, 34, 56, 78], 3, True) + check_backward([12, 34, 56, 78], 2, False) + check_backward([12, 34, 56, 78], 0, False) + + def test_doc(self): + assert "Argsort Operator" in jt.argsort.__doc__ + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_cub(self): + check_cub_argsort([5,5], 0, False) + check_cub_argsort([5,5], 0, True) + check_cub_argsort([5,5], 1, False) + check_cub_argsort([5,5], 1, True) + check_cub_argsort([12, 34, 56, 78], 1, True) + check_cub_argsort([12, 34, 56, 78], 3, True) + check_cub_argsort([12, 34, 56, 78], 2, False) + check_cub_argsort([12, 34, 56, 78], 0, False) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_cub_backward(self): + check_backward([5,5], 0, False) + check_backward([5,5], 0, True) + check_backward([5,5], 1, False) + check_backward([5,5], 1, True) + check_backward([12, 34, 56, 78], 1, True) + check_backward([12, 34, 56, 78], 3, True) + check_backward([12, 34, 56, 78], 2, False) + check_backward([12, 34, 56, 78], 0, False) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_array.py b/python/jittor/test/test_array.py new file mode 100644 index 00000000..788ce026 --- /dev/null +++ b/python/jittor/test/test_array.py @@ -0,0 +1,212 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import compile_extern +from jittor.test.test_core import expect_error + +class TestArray(unittest.TestCase): + def test_data(self): + a = jt.array([1,2,3]) + assert (a.data == [1,2,3]).all() + d = a.data + a.data[1] = -2 + assert (a.data == [1,-2,3]).all() + assert (a.fetch_sync()==[1,-2,3]).all() + li = jt.liveness_info() + del a + assert li == jt.liveness_info() + del d + assert li != jt.liveness_info() + + def test_set_data(self): + a = jt.array([1,2,3]) + assert (a.fetch_sync()==[1,2,3]).all() + a.data = [4,5,6] + assert (a.fetch_sync()==[4,5,6]).all() + a.data = jt.array([7,8,9]) + assert (a.fetch_sync()==[7,8,9]).all() + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_memcopy_overlap(self): + import time + from jittor.models import resnet + + im=np.random.rand(100,3,224,224).astype(np.float32) + net = resnet.Resnet34() + net.eval() + # warm up + x = jt.array(im).stop_grad() + + for i in range(10): + a = net(x) + a.sync() + jt.sync(device_sync=True) + + # pure compute + time_start=time.time() + x = jt.array(im).stop_grad() + for i in range(10): + a = net(x) + a.sync() + jt.sync(device_sync=True) + t1 = time.time() - time_start + + # warm up + for i in range(3): + x = jt.array(im) + b = net(x) + b.fetch(lambda b: None) + b.sync() + jt.sync(device_sync=True) + + # overlap + time_start=time.time() + results = [] + for i in range(10): + x = jt.array(im) + b = net(x) + b.fetch(lambda b: results.append(b)) + b.sync() + # del c + jt.sync(device_sync=True) + t2 = time.time() - time_start + + assert t2-t1 < 0.010, (t2, t1, t2-t1) + assert np.allclose(a.data, b.data) + assert len(results) == 10 + for v in results: + assert np.allclose(a.data, v), (v.shape, a.data.shape) + jt.LOG.v(f"pure compute: {t1}, overlap: {t2}") + + def test_segfault(self): + a = jt.array([1.0,2.0,3.0]) + b = (jt.maximum(a, 0)).sum() * 2.0 + da = jt.grad(b, a) + jt.sync_all() + assert (a.data==[1,2,3]).all() + assert (da.data==[2,2,2]).all() + + def test_segfault2(self): + assert (jt.array([1,2,3]).reshape((1,3)).data==[1,2,3]).all() + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + assert (jt.array([1,2,3]).reshape((1,3)).data==[1,2,3]).all() + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + def test_array_dual(self): + with jt.flag_scope(use_cuda=1): + a = jt.array(np.float32([1,2,3])) + assert (a.data==[1,2,3]).all() + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + def test_array_migrate(self): + with jt.flag_scope(use_cuda=1): + a = jt.array(np.float32([1,2,3])) + b = jt.code(a.shape, a.dtype, [a], cpu_src=""" + for (int i=0; i +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor import LOG +import os +import re +import platform + +class TestAsmTuner(unittest.TestCase): + @classmethod + def setUpClass(self): + inline = "inline" + if jt.flags.cc_type == "clang": + inline = "__attribute__((always_inline))" + self.cc_content=''' +#include +#include +#include "var.h" +#include "ops/broadcast_to_op.h" +#include "ops/binary_op.h" +#include "fused_op.h" +#define op0_Tx float32 +#define op0_DIM 2 +#define op0_BCAST 1 +#define op0_index_t int32_t +#define op1_Tx float +#define op1_DIM 2 +#define op1_BCAST 0 +#define op1_index_t int32_t +#define op2_Tx float +#define op2_Ty float32 +#define op2_Tz float32 +#define op2_OP subtract +#define op2_index_t int32_t +using namespace jittor; +#define INLINE_FUNC '''+inline+''' void +INLINE_FUNC func0(op0_Tx* __restrict__ op0_xp, op1_Tx* __restrict__ op1_xp, op2_Tz* __restrict__ op2_zp) { + //@begin replace "vmova(.*,.*\(.*\))" "vmovnt\g<1>" + (void)(__builtin_assume_aligned(op0_xp, alignment)); + (void)(__builtin_assume_aligned(op1_xp, alignment)); + (void)(__builtin_assume_aligned(op2_zp, alignment)); + op2_index_t range0 = 1048576; + op2_index_t range1 = 32; + op0_index_t op0_xstride1 = 1; + auto op0_xstride0 = op0_xstride1 * range1; + op1_index_t op1_xstride1 = 1; + auto op1_xstride0 = op1_xstride1 * range1; + op2_index_t op2_zstride1 = 1; + auto op2_zstride0 = op2_zstride1 * range1; + for (op2_index_t id0 = 0; id0x; + auto op1_x = ((BroadcastToOp*)(ops[1]))->x; + auto op2_z = ((BinaryOp*)(ops[2]))->z; + auto* __restrict__ op0_xp = op0_x->ptr(); + auto* __restrict__ op1_xp = op1_x->ptr(); + auto* __restrict__ op2_zp = op2_z->ptr(); + func0(op0_xp,op1_xp,op2_zp); +} + ''' + + self.src_path=os.path.join(jt.flags.cache_path, 'jit', 'asm_test_op.cc') + self.asm_path = os.path.join(jt.flags.jittor_path, "utils/asm_tuner.py") + self.so_path=self.src_path.replace(".cc",".so") + + def run_cmd(self, cmd): + return jt.compiler.run_cmd(cmd) + + def check_cc(self, content, check_movnt): + LOG.vv("check_cc") + with open(self.src_path,"w") as f: + f.write(content) + + cmd = jt.flags.python_path + " " + \ + jt.flags.jittor_path+"/utils/asm_tuner.py --cc_path=" + jt.flags.cc_path + " '" + self.src_path + "'" + " -DJIT -DJIT_cpu " + jt.compiler.fix_cl_flags(jt.flags.cc_flags) + " -o '" + self.so_path + "'"; + self.run_cmd(cmd) + + s_path=self.so_path.replace(".so",".s") + bo=False + with open(s_path) as f: + for line in f: + if line.find("vmovnt")!=-1: + bo=True + break + if check_movnt and jt.flags.cc_type == "clang": + assert bo + + @unittest.skipIf(platform.system() == 'Darwin', 'will crash on macOS') + def test_asm_tuner(self): + self.check_cc(self.cc_content,True) + self.check_cc(self.cc_content.replace("@begin","233").replace("@end","666"), False) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_atomic_tuner.py b/python/jittor/test/test_atomic_tuner.py new file mode 100644 index 00000000..dfab32f7 --- /dev/null +++ b/python/jittor/test/test_atomic_tuner.py @@ -0,0 +1,77 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn, Module +import numpy as np +import sys, os +import random +import math +import unittest +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +class testNet(Module): + def __init__(self, op): + self.op = op + return + + def execute(self, x): + N,H,W,C = x.shape + y1=x.reindex_reduce(self.op, [N,H], ["i0","i1",]) + y2=x.reindex_reduce(self.op, [H,W], ["i1","i2",]) + y1=y1.broadcast([N,H,W],[2]) + y2=y2.broadcast([N,H,W],[0]) + return y1+y2 + +class TestAtomicTunerClass(unittest.TestCase): + @classmethod + def setUpClass(self): + self.addNet = testNet("add") + self.maxNet = testNet("maximum") + self.minNet = testNet("minimum") + return + + def check(self, model, std_log): + x=jt.random([100,64,128,128]) + with jt.log_capture_scope( + # log_silent=1, + log_v=0, log_vprefix="atomic=100,data=100", + ) as logs: + y=model(x).numpy() + with jt.log_capture_scope( + log_v=0, + exclude_pass="atomic", + # new options to force recompile + compile_options = {"test_atomic_tuner":1} + ) as logs2: + y_std=model(x).numpy() + + err=np.max(y_std-y)/(np.mean(y_std)+1e-6) + assert err<1e-5, (err) + log_move = find_log_with_re(logs, "atomictuner: move .* to loop .*") + assert len(log_move)==len(std_log), (len(log_move), len(std_log)) + assert sorted(log_move) == sorted(std_log) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_atomic_tuner(self): + self.check(self.addNet, ['atomictuner: move atomicAdd to loop 1', 'atomictuner: move atomicAdd to loop 2']) + self.check(self.maxNet, ['atomictuner: move cuda_atomic_max to loop 1', 'atomictuner: move cuda_atomic_max to loop 2']) + self.check(self.minNet, ['atomictuner: move cuda_atomic_min to loop 1', 'atomictuner: move cuda_atomic_min to loop 2']) + + self.check(lambda x: x.sum()+x.sqr().mean(), [ + 'atomictuner: move atomicAdd to loop -1', + 'atomictuner: move atomicAdd to loop -1', + ]) + + self.check(lambda x: x.reindex_reduce("add", x.shape, ["i2","i3","i0","i1"]), []) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_attention.py b/python/jittor/test/test_attention.py new file mode 100644 index 00000000..7f3cbf63 --- /dev/null +++ b/python/jittor/test/test_attention.py @@ -0,0 +1,66 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import jittor.attention as jtatt +import numpy as np + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn + import fairseq +except: + torch = None + tnn = None + skip_this_test = True + +def check_equal(q,k,v,tatt,jatt): + tq=torch.from_numpy(q) + jq=jt.array(q) + tk=torch.from_numpy(k) + jk=jt.array(k) + tv=torch.from_numpy(v) + jv=jt.array(v) + + jatt.load_parameters(tatt.state_dict()) + ty, tw = tatt(tq, tk, tv) + jy, jw = jatt(jq, jk, jv) + assert np.allclose(ty.detach().numpy(), jy.numpy(), rtol=1e-3) + assert np.allclose(tw.detach().numpy(), jw.numpy(), rtol=1e-3) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestAttention(unittest.TestCase): + def test_attention(self): + q=np.random.rand(4,8,16).astype(np.float32) + k=np.random.rand(4,8,16).astype(np.float32) + v=np.random.rand(4,8,16).astype(np.float32) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,1) + jatt=jt.attention.MultiheadAttention(16,1) + check_equal(q,k,v,tatt,jatt) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,4) + jatt=jt.attention.MultiheadAttention(16,4) + check_equal(q,k,v,tatt,jatt) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,1,self_attention=True) + jatt=jt.attention.MultiheadAttention(16,1,self_attention=True) + check_equal(q,q,q,tatt,jatt) + + tatt=fairseq.modules.multihead_attention.MultiheadAttention(16,4,self_attention=True) + jatt=jt.attention.MultiheadAttention(16,4,self_attention=True) + check_equal(q,q,q,tatt,jatt) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_auto_diff.py b/python/jittor/test/test_auto_diff.py new file mode 100644 index 00000000..f8a873c2 --- /dev/null +++ b/python/jittor/test/test_auto_diff.py @@ -0,0 +1,70 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import numpy as np +import os +import sys +import jittor as jt + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torchvision.models as tcmodels + from torch import nn +except: + torch = None + skip_this_test = True + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestAutoDiff(unittest.TestCase): + def test_pt_hook(self): + code = ''' +import numpy as np +from jittor_utils import auto_diff +import torch +import torchvision.models as tcmodels +net = tcmodels.resnet50() +net.train() +hook = auto_diff.Hook("resnet50") +hook.hook_module(net) + +np.random.seed(0) +data = np.random.random((2,3,224,224)).astype('float32') +data = torch.Tensor(data) +net(data) +# assert auto_diff.has_error == 0, auto_diff.has_error +''' + with open("/tmp/test_pt_hook.py", 'w') as f: + f.write(code) + print(jt.flags.cache_path) + os.system(f"rm -rf {jt.flags.cache_path}/../../auto_diff/resnet50") + assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0 + assert os.system(sys.executable+" /tmp/test_pt_hook.py") == 0 + code = ''' +import numpy as np +import jittor as jt +from jittor_utils import auto_diff +from jittor.models import resnet50 +net = resnet50() +net.train() +hook = auto_diff.Hook("resnet50") +hook.hook_module(net) + +np.random.seed(0) +data = np.random.random((2,3,224,224)).astype('float32') +data = jt.array(data) +net(data) +# assert auto_diff.has_error == 0, auto_diff.has_error +''' + with open("/tmp/test_jt_hook.py", 'w') as f: + f.write(code) + assert os.system(sys.executable+" /tmp/test_jt_hook.py") == 0 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_batchnorm.py b/python/jittor/test/test_batchnorm.py new file mode 100644 index 00000000..1f5e271d --- /dev/null +++ b/python/jittor/test/test_batchnorm.py @@ -0,0 +1,135 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +def check_equal_with_istrain(arr, j_layer, p_layer, is_train=True, has_running=True, threshold=1e-5): + jittor_arr = jt.array(arr) + pytorch_arr = torch.Tensor(arr) + if has_running: + if is_train: + assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold) + else: + assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold) + jittor_result = j_layer(jittor_arr) + pytorch_result = p_layer(pytorch_arr) + if has_running: + if is_train: + assert np.allclose(p_layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold) + else: + assert np.allclose(p_layer.layer.running_mean.detach().numpy(), j_layer.running_mean.numpy(), threshold) + assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), 1e-2, threshold), \ + ( np.abs(pytorch_result.detach().numpy() - jittor_result.numpy()).max() ) + +def check_equal_without_istrain(arr, j_layer, p_layer, threshold=1e-5): + jittor_arr = jt.array(arr) + pytorch_arr = torch.Tensor(arr) + jittor_result = j_layer(jittor_arr) + pytorch_result = p_layer(pytorch_arr) + assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), threshold) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestBatchNorm(unittest.TestCase): + @jt.flag_scope(auto_convert_64_to_32=0) + def test_batchnorm(self): + # *************************************************************** + # Test BatchNorm Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal_with_istrain(arr, jnn.BatchNorm(10, is_train=True), tnn.BatchNorm2d(10)) + + class Model(tnn.Module): + def __init__(self): + super(Model, self).__init__() + self.layer = tnn.BatchNorm2d(10) + def forward(self, x): + return self.layer(x) + model = Model() + model.eval() + check_equal_with_istrain(arr, jnn.BatchNorm(10, is_train=False), model, False) + + # *************************************************************** + # Test InstanceNorm2d Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal_without_istrain(arr, jnn.InstanceNorm2d(10, is_train=True), tnn.InstanceNorm2d(10)) + + class Model(tnn.Module): + def __init__(self): + super(Model, self).__init__() + self.layer = tnn.InstanceNorm2d(10) + def forward(self, x): + return self.layer(x) + model = Model() + model.eval() + check_equal_without_istrain(arr, jnn.InstanceNorm2d(10, is_train=False), model) + + # *************************************************************** + # Test BatchNorm1d Layer + # *************************************************************** + arr = np.random.randn(16,10) + check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=True), tnn.BatchNorm1d(10), 1e-3) + + class Model(tnn.Module): + def __init__(self): + super(Model, self).__init__() + self.layer = tnn.BatchNorm1d(10) + def forward(self, x): + return self.layer(x) + model = Model() + model.eval() + check_equal_with_istrain(arr, jnn.BatchNorm1d(10, is_train=False), model, False) + + # *************************************************************** + # Test GroupNorm Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + + class Model(tnn.Module): + def __init__(self): + super(Model, self).__init__() + self.layer = tnn.GroupNorm(2, 10) + def forward(self, x): + return self.layer(x) + model = Model() + model.eval() + check_equal_with_istrain(arr, jnn.GroupNorm(2, 10, is_train=False), model, False, False) + + # *************************************************************** + # Test LayerNorm Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + + class Model(tnn.Module): + def __init__(self): + super(Model, self).__init__() + self.layer = tnn.LayerNorm(224) + def forward(self, x): + return self.layer(x) + model = Model() + model.eval() + check_equal_with_istrain(arr, jnn.LayerNorm(224), model, False, False) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_benchmark.py b/python/jittor/test/test_benchmark.py new file mode 100644 index 00000000..891d8a0a --- /dev/null +++ b/python/jittor/test/test_benchmark.py @@ -0,0 +1,344 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +n = 400000000 +# n = 4000000 +n = 7680000 + +def get_mem_band(): + a = jt.rand((n)).float32() + for i in range(100): + a.copy().sync() + jt.sync_all(True) + import time + t = time.time() + for i in range(1000): + a.copy().sync() + jt.sync_all(True) + dt = time.time() - t + band = a.numel() * 4 * 2000 / dt / 1024**3 + print("Mem band: ", band) + return band + +def check_simple_add_band(): + # copy: 816 + # S=1 128,1024, ILP=1 634 + # S=0 128,1024, ILP=1 734 + # S=0 128,512, ILP=1 716 + # S=0 64,1024, ILP=1 706 + # S=0 256,1024, ILP=1 706 + def test(S=0, B=128, T=1024, ILP=1): + a = jt.rand((n)).float32() + jt.sync_all(True) + jt.flags.log_silent = 1 + with jt.profile_scope(100, 1000) as rep: + b = jt.code(a.shape, a.dtype, [a], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP {ILP} + for (int i=tid*ILP; i(b+i, a+i); + {"__syncthreads();" if S else ""} + }} + }} + kernel<<<{B},{T}>>>(in0_p, out0_p, in0->num); + """) + b.sync() + bw = float(rep[-1][9]) / 1024**3 + s = f"S={S}, B={B}, T={T}, ILP={ILP} BW={bw}" + print(s) + return s, bw + + def test2(S=0, B=128, T=1024, ILP=1): + a = jt.rand((n)).float32() + jt.sync_all(True) + # jt.flags.log_silent = 0 + with jt.profile_scope(10, 1000) as rep: + b = jt.code(a.shape, a.dtype, [a], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __global__ void kernel(float2 * __restrict__ a, float2* __restrict__ b, int num) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP 1 + for (int i=tid*ILP; i(b+i, a+i); + {"__syncthreads();" if S else ""} + }} + }} + kernel<<<{B},{T}>>>((float2*)in0_p, (float2*)out0_p, in0->num/2); + """) + b.sync() + bw = float(rep[-1][9]) / 1024**3 + s = f"T2: S={S}, B={B}, T={T}, ILP={ILP} BW={bw}" + print(s) + return s, bw + + + def test3(S=0, B=128, T=1024, ILP=1, C=0): + a = jt.rand((n)).float32() + b = jt.rand(B) + jt.sync_all(True) + jt.flags.log_silent = 1 + with jt.profile_scope(100, 1000) as rep: + b = jt.code(a.shape, a.dtype, [a, b], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num) {{ + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP {ILP} + for (int i=tid*ILP; i(b+i, a+i); + {"__syncthreads();" if S else ""} + }} + {"__syncthreads();" if C else ""} + }} + kernel<<shape[0],{T}>>>(in0_p, out0_p, in0->num); + """) + b.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C} + # b.compile_options = {"FLAGS: –Xptxas –dlcm=ca ": 1} + b.sync() + + bw = float(rep[-1][9]) / 1024**3 + s = f"T3: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw}" + print(s) + return s, bw + + + def test4(S=0, B=128, T=1024, ILP=1, C=0, name="b.png"): + a = jt.rand((n)).float32() + b = jt.rand(B*4).uint32() + jt.sync_all(True) + # jt.flags.log_silent = 1 + with jt.profile_scope(100, 10000) as rep: + _ = jt.code(a.shape, a.dtype, [a, b], + cuda_header="#include \"type/fp16_compute.h\"", + cuda_src=f""" + __device__ uint get_smid(void) {{ + uint ret; + asm("mov.u32 %0, %smid;" : "=r"(ret) ); + return ret; + }} + __device__ uint get_time(void) {{ + uint ret; + asm volatile("mov.u32 %0, %%globaltimer_lo;" : "=r"(ret)); + return ret; + }} + + __global__ void kernel(in0_type * __restrict__ a, in0_type* __restrict__ b, int num, in1_type* __restrict__ c) {{ + uint t = get_time(); + int tid = threadIdx.x + blockIdx.x * blockDim.x; + int tnum = blockDim.x * gridDim.x; + #define ILP {ILP} + for (int i=tid*ILP; i(b+i, a+i); + {"__syncthreads();" if S else ""} + }} + {"__syncthreads();" if C else ""} + if (threadIdx.x == 0) + ((uint4* __restrict__)c)[blockIdx.x] = + uint4{{get_smid(), t, get_time(), 0}}; + }} + kernel<<shape[0]/4,{T}>>>(in0_p, out0_p, in0->num, in1_p); + """) + _.compile_options = {"FLAGS: -Xptxas -dlcm=ca ": C} + # b.compile_options = {"FLAGS: –Xptxas –dlcm=ca ": 1} + _.sync() + + bw = float(rep[-1][9]) / 1024**3 + b = b.data.reshape(-1, 4)[:,:3] + mint = b[:,1].min() + b[:,1:] -= mint + smmax = int(b[:,0].max()) + smmin = int(b[:,0].min()) + maxt = b.max() + + # print(b) + + s = f"T4: S={S}, B={B}, T={T}, ILP={ILP} C={C} BW={bw:.3f} sm={smmin},{smmax} maxt={maxt}" + print(s) + import pylab as pl + pl.figure(figsize=(16,16)) + texts = [] + pret = np.zeros(200, dtype="uint32") + for i in range(B): + smid, s, t = b[i] + pl.plot([s,t], [smid, smid], 'ro-') + texts.append((s, smid, i)) + texts.append((t, smid, i)) + + texts = sorted(texts) + for (s, smid, bid) in texts: + cpos = max(pret[smid], s) + pl.text(cpos, smid, str(bid)) + pret[smid] = cpos + maxt // 30 + + + # print("???") + # adjust_text(texts, arrowprops=dict(arrowstyle='->', color='blue')) + # print("???") + pl.savefig(name) + pl.close() + return s, bw + # test(S=0, B=128, T=1024, ILP=1) + # test(S=1, B=128, T=1024, ILP=1) + # test(S=0, B=64, T=1024, ILP=1) + # test(S=0, B=256, T=1024, ILP=1) + # test(S=1, B=128, T=512, ILP=1) + # test(S=1, B=128, T=256, ILP=1) + + # test(S=0, B=128, T=1024, ILP=2) + # test(S=0, B=128, T=1024, ILP=4) + # test(S=0, B=128, T=512, ILP=2) + # test(S=0, B=128, T=512, ILP=4) + + # test(S=1, B=128, T=1024, ILP=2) + # test(S=1, B=128, T=1024, ILP=4) + # test(S=1, B=128, T=1024, ILP=8) + # test(S=1, B=128, T=1024, ILP=16) + # test(S=1, B=128, T=512, ILP=2) + # test(S=1, B=128, T=512, ILP=4) + + # test(S=1, B=256, T=1024, ILP=2) + # test(S=1, B=512, T=1024, ILP=2) + # test(S=1, B=256, T=1024, ILP=4) + # test(S=1, B=256, T=1024, ILP=8) + # test(S=1, B=256, T=1024, ILP=16) + # test(S=1, B=256, T=512, ILP=2) + # test(S=1, B=256, T=512, ILP=4) + + # test(S=1, B=128, T=256, ILP=2) + # test(S=1, B=128, T=256, ILP=4) + # test(S=0, B=128, T=256, ILP=2) + # test(S=0, B=128, T=256, ILP=4) + + # for b in [1, 2, 4, 8, 16, 32, 64, 128,256]: + # test(S=1, B=b, T=512, ILP=2) + + import matplotlib as mpl + mpl.use('Agg') + import pylab as pl + import numpy as np + + # test4(S=1, B=82, T=1024, ILP=2, C=0, name="b.png") + # test4(S=1, B=83, T=1024, ILP=2, C=0, name="c.png") + # test4(S=1, B=82*3, T=512, ILP=2, C=0, name="d1.png") + # test4(S=1, B=82*3+1, T=512, ILP=2, C=0, name="d2.png") + # test4(S=1, B=82*6+1, T=512, ILP=2, C=0, name="d3.png") + # test4(S=0, B=82*6+1, T=512, ILP=2, C=0, name="d4.png") + + for b in range(70, 83): + test4(S=1, B=b, T=1024, ILP=2, C=0, name=f"b-{b}.png") + + # data = [] + # for b in range(32, 2000, 8): + # _, bw = test3(S=0, B=b, T=32, ILP=2) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # for t in [32, 64, 128, 256, 512, 1024]: + # data = [] + # for b in range(32, 2000, 8): + # _, bw = test3(S=1, B=b*(1024//t), T=t, ILP=2) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # for t in [1024]: + # for c in [0,1]: + # data = [] + # # for b in range(32, 1000, 8): + # for b in range(32, 33, 8): + # _, bw = test3(S=c, B=b*(1024//t), T=t, ILP=2, C=0) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # for ilp in [2]: + # for s in [1]: + # for t in [1024,512,256,128]: + # data = [] + # for b in range(32, 1100, 8): + # _, bw = test3(S=s, B=b*(1024//t), T=t, ILP=ilp) + # data.append([b, bw]) + # data = np.array(data) + # pl.plot(data[:,0], data[:,1]) + + # pl.savefig("a.png") + # pl.close() + # for b in range(80, 90, 1): + # _, bw = test3(S=1, B=b, T=1024, ILP=2) + # # 82 + # for b in range(240, 260, 1): + # _, bw = test3(S=1, B=b, T=512, ILP=2) + # # 82*3 = 246 + # for b in range(240, 500, 1): + # _, bw = test3(S=1, B=b, T=256, ILP=2) + # # 492 = 82*6 + # for b in range(240, 1000, 1): + # _, bw = test3(S=1, B=b, T=128, ILP=2) + # # 984 = 82*12 + + + # for b in [128,256]: + # test(S=1, B=b, T=1024, ILP=2) + # for b in [128,256]: + # test(S=0, B=b, T=512, ILP=2) + # for b in [128,256]: + # test(S=0, B=b, T=1024, ILP=2) + # for b in [128,256]: + # test(S=1, B=b, T=512, ILP=1) + # for b in [128,256]: + # test(S=1, B=b, T=1024, ILP=1) + # for b in [128,256]: + # test(S=0, B=b, T=512, ILP=1) + # for b in [128,256]: + # test(S=0, B=b, T=1024, ILP=1) + # test(S=1, B=128, T=512, ILP=4) + # test(S=1, B=64, T=512, ILP=2) + # test(S=1, B=80, T=512, ILP=2) + # test(S=1, B=100, T=512, ILP=2) + # test(S=1, B=110, T=512, ILP=2) + # test(S=1, B=115, T=512, ILP=2) + # test(S=1, B=120, T=512, ILP=2) + # test(S=1, B=130, T=512, ILP=2) + # test(S=1, B=140, T=512, ILP=2) + # test2(S=1, B=128, T=512, ILP=2) + # test(S=1, B=128, T=256, ILP=4) + # test(S=1, B=128, T=128, ILP=8) + # test(S=1, B=128, T=64, ILP=16) + + + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestBenchmarkCUDA(unittest.TestCase): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + + def test_main(self): + return + get_mem_band() + check_simple_add_band() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_bf16.py b/python/jittor/test/test_bf16.py new file mode 100644 index 00000000..c2c51fec --- /dev/null +++ b/python/jittor/test/test_bf16.py @@ -0,0 +1,384 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +def transpose0231(x): + s0, s1, s2, s3 = x.shape + asize = 16 + bsize = 16 + ILP = 2 + return jt.code([s0, s2, s3, s1], x.dtype, [x], + cuda_header="#include \n#include ", + cuda_src=f""" + __global__ void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{ + __shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}]; + int t3 = threadIdx.x % {bsize}; + int t1 = threadIdx.x / {bsize}; + int b3 = blockIdx.x; + int b2 = blockIdx.y; + int b0 = blockIdx.z; + int x3 = 1; + int x2 = s3; + int x1 = s2*x2; + int x0 = s1*x1; + int y3 = 1; + int y2 = s1; + int y1 = s3*y2; + int y0 = s2*y1; + in0_type tmp[{ILP}]; + for (int i=0; i<(s1-1)/{asize*ILP}+1; i++) + {{ + int _b3 = b3 * {bsize*ILP} + t3*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + vload( + tmp, + &x[b0*x0+(t1*{ILP}+j+i*{asize*ILP})*x1+b2*x2+_b3*x3] + ); + #pragma unroll + for (int k=0; k<{ILP}; k++) + t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k]; + + }} + }} + __syncthreads(); + int t3_ = threadIdx.x % {asize}; + int t1_ = threadIdx.x / {asize}; + _b3 = b3 * {bsize*ILP} + t1_*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + #pragma unroll + for (int k=0; k<{ILP}; k++) {{ + tmp[k] = + t[(t3*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j]; + }} + vload( + &y[b0*y0+b2*y1+(_b3+j)*y2+((t3*{ILP})+i*{asize*ILP})*y3], + tmp + ); + }} + }} + __syncthreads(); + }} + }} + int s0, s1, s2, s3; + in0->shape.unpack(s0, s1, s2, s3); + kernel<<<{{(s3-1)/{bsize*ILP}+1, s2, s0 }}, {bsize*asize}>>> + (in0_p, out0_p, s0, s1, s2, s3); + """) + +def transpose0231_2(x): + s0, s1, s2, s3 = x.shape + asize = 16 + bsize = 8 + ILP = 2 + return jt.code([s0, s2, s3, s1], x.dtype, [x], + cuda_header="#include \n#include ", + cuda_src=f""" + __global__ __launch_bounds__({asize*bsize}) void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{ + __shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}]; + int t3 = threadIdx.x % {bsize}; + int t1 = threadIdx.x / {bsize}; + int b3 = blockIdx.x; + int b1 = blockIdx.y; + int b2 = 0; + int b0 = blockIdx.z; + int x3 = 1; + int x2 = s3; + int x1 = s2*x2; + int x0 = s1*x1; + int y3 = 1; + int y2 = s1; + int y1 = s3*y2; + int y0 = s2*y1; + in0_type tmp[{ILP}]; + {{ + int _b3 = b3 * {bsize*ILP} + t3*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + if (t1*{ILP}+j+b1*{asize*ILP} >= s1) + continue; + vload( + tmp, + &x[b0*x0+(t1*{ILP}+j+b1*{asize*ILP})*x1+b2*x2+_b3*x3] + ); + #pragma unroll + for (int k=0; k<{ILP}; k++) + t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k]; + + }} + }} + __syncthreads(); + int t3_ = threadIdx.x % {asize}; + int t1_ = threadIdx.x / {asize}; + _b3 = b3 * {bsize*ILP} + t1_*{ILP}; + int yy3 = (t3_*{ILP})+b1*{asize*ILP}; + if (_b3 < s3 && yy3 < s1) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + #pragma unroll + for (int k=0; k<{ILP}; k++) {{ + tmp[k] = + t[(t3_*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j]; + }} + vload( + &y[b0*y0+b2*y1+(_b3+j)*y2+yy3*y3], + tmp + ); + // printf("%d %d %d %d %d\\n", b0*y0+b2*y1+(_b3+j)*y2+yy3*y3, + // b0, b2, (_b3+j), yy3); + }} + }} + __syncthreads(); + }} + }} + int s0, s1, s2, s3; + in0->shape.unpack(s0, s1, s2, s3); + kernel<<<{{(s3-1)/{bsize*ILP}+1, (s1-1)/{asize*ILP}+1, s0 }}, {bsize*asize}>>> + (in0_p, out0_p, s0, s1, s2, s3); + """) + +def check_share(): + return + a = jt.rand((30, 32, 4, 2000)).float32() + jt.code(a.shape, a.dtype, [a], + cuda_header="#include \n#include ", + cuda_src=""" + __global__ void kernel(in0_type* __restrict__ a, in0_type* __restrict__ b) { + __shared__ float x[32*33]; + for (int i=0; i<3; i++) { + ((float2*)&x[i])[0] = ((float2*)&a[i])[0]; + ((float2*)&b[i])[0] = ((float2*)&x[i+1])[0]; + } + } + kernel<<<1024,16*16>>>(in0_p, out0_p); + """).sync() + jt.sync_all(True) + # print(a[0]+1) + print("pass test") + +class TestBF16(unittest.TestCase): + def test_array(self): + a = np.array([1,2,3], dtype="float") + b = jt.array(a).bfloat16() + np.testing.assert_allclose(a, b.float().numpy()) + + def test_add(self): + a = np.array([1,2,3], dtype="float32") + b = jt.bfloat16(a) + c = b+b + assert c.dtype == "bfloat16" + np.testing.assert_allclose(c.numpy(), a+a) + d = c.sum() + np.testing.assert_allclose(d.numpy(), [12]) + c = c+1 + print(c) + + def test_matmul(self): + a = jt.random((100,100)).bfloat16() + b = jt.random((100,100)).bfloat16() + c = jt.matmul(a, b) + c.sync() + print(c) + assert c.dtype == "bfloat16" + + def test_bmm(self): + a = jt.random((10,3,4)).bfloat16() + b = jt.random((10,4,5)).bfloat16() + c = jt.matmul(a, b) + c.sync() + + def test_matmul_grad(self): + a = jt.random((100,100)).bfloat16() + b = jt.random((100,100)).bfloat16() + c = jt.matmul(a, b) + c.sync() + da, db = jt.grad(c, [a,b]) + jt.sync_all() + assert da.dtype == "bfloat16" + assert db.dtype == "bfloat16" + + def test_conv(self): + a = jt.random((3,4,5,5)).bfloat16() + b = jt.random((4,4,3,3)).bfloat16() + c = jt.nn.conv(a, b) + c.sync() + + def test_max(self): + a = jt.random((100,)).bfloat16() + b = jt.random((100,)).bfloat16() + c = a.maximum(b) + c.sync() + + def test_reduce_dtype_infer(self): + return + # this test cannot pass now + with jt.flag_scope(amp_reg=1): + a = jt.random((3,4,5,5)).bfloat16() + b = a.sum() + b.sync() + assert b.dtype == "float32", b.dtype + with jt.flag_scope(amp_reg=2): + a = jt.random((3,4,5,5)).bfloat16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=0): + a = jt.random((3,4,5,5)).bfloat16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2+4): + a = jt.random((3,4,5,5)).bfloat16() + b = a.sum() + b.sync() + assert b.dtype == "bfloat16", b.dtype + + def test_white_dtype_infer(self): + with jt.flag_scope(amp_reg=1): + a = jt.random((3,4,5,5)).bfloat16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2): + a = jt.random((3,4,5,5)).bfloat16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=0): + a = jt.random((3,4,5,5)).bfloat16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2+8): + a = jt.random((3,4,5,5)).bfloat16() + b = a**a + b.sync() + assert b.dtype == "bfloat16", b.dtype + + def test_module_half(self): + a = jt.nn.Linear(10,10) + assert a.weight.dtype == "float32" + a.bfloat16() + assert a.weight.dtype == "bfloat16" + + def test_scalar(self): + a = jt.bfloat16([1,2,3]) + assert (a*1).dtype == "bfloat16" + assert (a*jt.bfloat16([1,2,3])).dtype == "bfloat16" + assert (a*jt.float32([1,2,3])).dtype == "float32" + assert (a*jt.float32([1,2,3]).sum()).dtype == "bfloat16" + assert jt.int([0,1,0]).ternary(a, jt.float32(1)).dtype == "bfloat16" + + def test_amp_level3(self): + with jt.flag_scope(amp_level = 3): + a = jt.bfloat16([1,2,3]) + assert (a.sum()).dtype == "bfloat16" + assert (a.mean()).dtype == "bfloat16" + assert (a.log()).dtype == "bfloat16" + assert (a.exp()).dtype == "bfloat16" + + def test_safe_clip(self): + import math + assert not jt.bfloat16(math.inf).isfinite() + assert jt.safe_clip(jt.bfloat16(math.inf)).isfinite() + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestBF16CUDA(TestBF16): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + + def test_add_correct(self): + na = np.random.rand(10000) + nb = np.random.rand(10000) + a = jt.array(na).bfloat16() + b = jt.array(nb).bfloat16() + c = a + b + nc = c.numpy() + np.testing.assert_allclose(nc, na+nb, atol=1e-2) + + def test_matmul_correct(self): + na = np.random.rand(64,64) + nb = np.random.rand(64,64) + a = jt.array(na).bfloat16() + b = jt.array(nb).bfloat16() + c = jt.matmul(a, b) + nc = c.numpy() + nc2 = np.matmul(na, nb) + np.testing.assert_allclose(nc, nc2, rtol=1e-2) + + + def test_softmax(self): + a = jt.rand((120, 2000, 2000)).bfloat16() + # a = jt.rand((1, 2000, 2000)).float32() + jt.sync_all() + with jt.profile_scope(10, 100): + a.log_softmax(-1).sync() + + def test_transpose(self): + check_share() + # return + a = jt.rand((30, 32, 4, 2000)).float32() + # a = jt.rand((1, 1024, 1, 2000)).float32() + diff = transpose0231(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + # return + jt.sync_all() + # with jt.profile_scope(100, 11000): + with jt.profile_scope(100, 11000): + # a.log_softmax(-1).sync() + transpose0231(a).sync() + + a.transpose((0,2,3,1)).sync() + # a.transpose((0,2,1,3)).sync() + a.fuse_transpose((0,2,1,3)).sync() + (a+1).sync() + jt.sync_all(True) + diff = transpose0231(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + np.testing.assert_allclose(transpose0231(a).data, a.transpose((0,2,3,1)).data) + + def test_transpose2(self): + # check_share() + # return + # a = jt.rand((30, 32, 4, 2000)).float32() + # a = jt.rand((1, 10000, 1, 2000)).float32() + a = jt.rand((1, 10000, 1, 2048)).float32() + print("transpose") + transpose0231_2(a).sync() + print("add") + (a+1).sync() + return + # a = jt.arange(32*16).reshape((1, 32, 1, 16)) + diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + # return + jt.sync_all() + # with jt.profile_scope(100, 11000): + with jt.profile_scope(100, 1100): + # a.log_softmax(-1).sync() + transpose0231_2(a).sync() + + a.transpose((0,2,3,1)).sync() + # a.transpose((0,2,1,3)).sync() + a.fuse_transpose((0,2,1,3)).sync() + (a+1).sync() + jt.sync_all(True) + diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + np.testing.assert_allclose(transpose0231_2(a).data, a.transpose((0,2,3,1)).data) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_bicubic.py b/python/jittor/test/test_bicubic.py new file mode 100644 index 00000000..3f3a4443 --- /dev/null +++ b/python/jittor/test/test_bicubic.py @@ -0,0 +1,53 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch.nn import functional as F +except: + torch = None + skip_this_test = True + + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestBicubicInterpolate(unittest.TestCase): + # this is for testing bicubic interpolate + def test_bicubic(self): + for _ in range(20): + try: + tn = np.random.randn(1,1,5,5).astype('float32') + ja = jt.array(tn) + ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True) + # test upsample + ju = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic') + tu = F.interpolate(ta,scale_factor=2,mode='bicubic') + assert np.allclose(ju.data,tu.detach().numpy(),rtol=1e-03,atol=1e-06) + gju = jt.grad(ju,ja) + gtu = torch.autograd.grad(tu,ta,torch.ones_like(tu),retain_graph=True)[0] + assert np.allclose(gju.data,gtu.detach().numpy(),rtol=1e-03,atol=1e-06) + # test align + je = jt.nn.interpolate(ja,scale_factor=2,mode='bicubic',align_corners=True) + te = F.interpolate(ta,scale_factor=2,mode='bicubic',align_corners=True) + assert np.allclose(je.data,te.detach().numpy(),rtol=1e-03,atol=1e-06) + gje = jt.grad(je,ja) + gte = torch.autograd.grad(te,ta,torch.ones_like(tu),retain_graph=True)[0] + assert np.allclose(gje.data,gte.detach().numpy(),rtol=1e-03,atol=1e-06) + except AssertionError: + print(ju,tu) + print(je,te) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_binary_op.py b/python/jittor/test/test_binary_op.py new file mode 100644 index 00000000..ee260ed0 --- /dev/null +++ b/python/jittor/test/test_binary_op.py @@ -0,0 +1,218 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from .test_grad import ngrad +from .test_cuda import test_cuda + +def all_eq(x, y): + if len(x.shape) == 0: x = np.array([x]) + if len(y.shape) == 0: y = np.array([y]) + convert = lambda x: x.astype("uint8") if x.dtype=="bool" else x + x = convert(x) + y = convert(y) + if str(x.dtype).startswith("float"): + return str(y.dtype).startswith("float") and x.shape == y.shape and (x==y).all() + return x.dtype == y.dtype and x.shape == y.shape and np.testing.assert_allclose(x, y) + +def check(op, *args): + x = eval(f"np.{op}(*args)") + y = eval(f"jt.{op}(*args).data") + all_eq(x, y) + +class TestBinaryOp(unittest.TestCase): + def test_binary_op(self): + assert np.all(jt.binary(1,2,'maximum').data == 2) + assert np.all(jt.binary([[1,2]],[[3,4]],'add').data == [[4,6]]) + assert np.all(jt.less(1,2).data) + assert jt.less(1,2).data.dtype == "bool" + x = (jt.array(1) << jt.array(3)).data + assert (x == 8).all() + x = (jt.array(2) ** jt.array(3)).data + assert (x == 8).all() + a = np.array([1,2,3]) + b = np.array([7,10,13]) + check("logical_and", a, b) + check("logical_or", a, b) + check("logical_xor", a, b) + check("bitwise_and", a, b) + check("bitwise_or", a, b) + check("bitwise_xor", a, b) + + def test_i(self): + def check(op, a, b): + if isinstance(a, list): + a = np.array(a) + b = np.array(b) + if jt.flags.use_cuda and op == "@": + return + if op=="@": + a = np.float32(a) + b = np.float32(b) + ja = jt.array(a) + jb = jt.array(b) + exec(f"ja {op}= jb") + ja = ja.fetch_sync() + + if op == "@": + # numpy do not support @= + a = np.array(a) @ np.array(b) + else: + a = eval(f"a {op} b") + a = np.float32(a) + ja = np.float32(ja) + + all_eq(ja, a) + check("+", 5, 2) + check("-", 5, 2) + check("*", 5, 2) + check("/", 5, 2) + check("//", 5, 2) + # check("@", [[5]], [[2]]) + check("%", 5, 2) + check("**", 5, 2) + check("<<", 5, 2) + check(">>", 5, 2) + check("&", 5, 2) + check("^", 5, 2) + check("|", 5, 2) + + check("+", [5.0,6.0], [2.0,3.0]) + check("-", [5.0,6.0], [2.0,3.0]) + check("*", [5.0,6.0], [2.0,3.0]) + check("/", [5.0,6.0], [2.0,3.0]) + check("//", [5.0,6.0], [2.0,3.0]) + check("@", [[5,6],[7,8]], [[2,3],[4,5]]) + check("%", [5.0,6.0], [2.0,3.0]) + check("**", [5.0,6.0], [2.0,3.0]) + + def test_r(self): + def check(op, a, b): + a = np.array(a) + b = np.array(b) + if jt.flags.use_cuda and op == "@": + return + jb = jt.array(b) + jc = eval(f"a {op} jb").data + + + if op == "@": + # numpy do not support @= + a = np.array(a) @ np.array(b) + else: + a = eval(f"a {op} b") + a = np.array(a) + + all_eq(jc, a) + check("+", 5, 2) + check("-", 5, 2) + check("*", 5, 2) + check("/", 5, 2) + check("//", 5, 2) + # check("@", [[5]], [[2]]) + check("%", 5, 2) + check("**", 5, 2) + check("<<", 5, 2) + check(">>", 5, 2) + check("&", 5, 2) + check("^", 5, 2) + check("|", 5, 2) + + def test_grad(self): + ops = ["+", "-", "*", "/", "**"] + np.random.seed(3) + a = np.random.rand(10) + b = np.random.rand(10) + c = np.random.rand(10) + tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-4 + for op in ops: + func = lambda x: eval(f"((x[0]{op}x[1])*x[2]).sum()") + x, grads = ngrad(func, [a,b,c], 1e-8) + ja = jt.array(a).name("ja") + jb = jt.array(b).name("jb") + jc = jt.array(c).name("jc") + jx = eval(f"(ja{op}jb)*jc") + jgrads = jt.grad(jx, [ja,jb,jc]) + for jd, nd in zip(jgrads, grads): + np.testing.assert_allclose(jd.data, nd, atol=tol, rtol=tol) + + def test_mod_float(self): + a = jt.random((10,)) + b = jt.random((10,)) + c = a % b + assert np.allclose(c.data, a.data % b.data) + a = jt.random((10,), 'float64') + b = jt.random((10,), 'float64') + c = a % b + assert np.allclose(c.data, a.data % b.data, a.data, b.data) + if jt.flags.amp_reg & 2: return + a = jt.random((10,)) * 1000 + b = (jt.random((10,)) * 10).int() + 1 + c = a % b + assert np.allclose(c.data, a.data % b.data), (c.data, a.data%b.data) + + def test_mod_grad(self): + a = jt.random((100,)) + b = jt.random((100,)) + c = a % b + da, db = jt.grad(c, [a, b]) + np.testing.assert_allclose(da.data, 1) + np.testing.assert_allclose(db.data, -np.floor(a.data/b.data)) + + def test_mod_negtive(self): + a = jt.random((100,)) - 0.5 + b = jt.random((100,)) - 0.5 + c = a % b + nc = a.data % b.data + np.testing.assert_allclose(c.data, nc.data, atol=1e-5, rtol=1e-5) + + def test_pow(self): + # win cuda 10.2 cannot pass + a = jt.random((100,)) + b = a**3 + b.sync() + + def test_binary_op_bool(self): + a = np.array([0,1,0,1]).astype(bool) + b = np.array([0,1,1,0]).astype(bool) + c = np.array([1,1,1,1]).astype(bool) + check("add", a, b) + all_eq(np.logical_xor(a, b), jt.subtract(a, b).data) + check("multiply", a, b) + check("logical_and", a, b) + check("logical_or", a, b) + check("logical_xor", a, b) + check("bitwise_and", a, b) + check("bitwise_or", a, b) + check("bitwise_xor", a, b) + check("divide", a, c) + check("floor_divide", a, c) + check("mod", a, c) + + +class TestBinaryOpCuda(TestBinaryOp, test_cuda(2)): + pass + +class TestBinaryOpCpuFp16(TestBinaryOp): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + def tearDown(self): + jt.flags.amp_reg = 0 + +@unittest.skipIf(not jt.has_cuda, "no cuda found") +class TestBinaryOpCudaFp16(TestBinaryOp): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.amp_reg = 0 + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_bmm.py b/python/jittor/test/test_bmm.py new file mode 100644 index 00000000..40075868 --- /dev/null +++ b/python/jittor/test/test_bmm.py @@ -0,0 +1,45 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Meng-Hao Guo +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn +import unittest +import numpy as np + +class TestBMM(unittest.TestCase): + @unittest.skipIf(not jt.has_cuda, "No cuda found") + def test_bmm_cuda(self): + def check(batch, n, m, k): + def calc(use_cuda, a, b, mask): + jt.flags.use_cuda = use_cuda + a = jt.array(a) + b = jt.array(b) + mask = jt.array(mask) + c = nn.bmm(a, b) + da, db = jt.grad(c*mask, [a, b]) + return c.data, da.data, db.data + mask = np.random.rand(batch, n, k).astype("float32") + a = np.random.rand(batch, n, m).astype("float32") + b = np.random.rand(batch, m, k).astype("float32") + a1,a2,a3 = calc(0, a, b, mask) + b1,b2,b3 = calc(1, a, b, mask) + assert np.allclose(a1, b1) + assert np.allclose(a2, b2) + assert np.allclose(a3, b3) + check(10,3,4,5) + check(10,8,8,8) + check(10,8,1,8) + check(10,8,8,1) + check(10,1,8,8) + check(1,7,8,8) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_broadcast_to_op.py b/python/jittor/test/test_broadcast_to_op.py new file mode 100644 index 00000000..ce1b8e07 --- /dev/null +++ b/python/jittor/test/test_broadcast_to_op.py @@ -0,0 +1,144 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from .test_cuda import test_cuda +import contextlib + +def gen_data(shape): + num = np.multiply.reduce(shape) + a = np.arange(0, num) + return a.reshape(shape) + +class TestBroadcastToOp(unittest.TestCase): + def setUp(self): + self.use_shape = False + + def test1(self): + def check(shape1, shape2): + a = gen_data(shape1) + b = gen_data(shape2) + aa,bb = np.broadcast_arrays(a, b) + if self.use_shape: + ja = jt.ops.broadcast(a, shape2).data + else: + ja = jt.ops.broadcast_var(a, b).data + assert ja.shape == aa.shape and (ja==aa).all(), f"{ja}, {aa}" + check([1], [3]) + check([3,1], [3]) + check([3,1,3], [1,3,1]) + check([2,3,4], [2,3,4,1,1,1]) + check([2,3], [2,3,1,1]) + check([2,1,3,1,4], [1,3,4]) + + expect_error(lambda: jt.ops.broadcast_var([1,2],[1,2,3])) + + def test_binary_op(self): + if self.use_shape: return + def check(shape1, shape2): + a = gen_data(shape1) + b = gen_data(shape2) + x = y = None + try: + x = a+b + except Exception as e: + pass + try: + y = jt.ops.add(a, b).data + except Exception as e: + pass + assert (x==y).all(), f"{x}\n{y}" + check([1], [3]) + check([3,1], [3]) + check([3,1,3], [1,3,1]) + check([2,3,4], [2,3,4,1,1,1]) + check([2,3], [2,3,1,1]) + check([2,1,3,1,4], [1,3,4]) + +class TestBroadcastToOpForward(unittest.TestCase): + def test_forward(self): + @contextlib.contextmanager + def check(bop_num): + jt.clean() + yield + graph = jt.dump_all_graphs() + bop = [ node for node in graph.nodes_info + if node.startswith("Op") and "broadcast_to" in node] + assert len(bop)==bop_num, (len(bop), bop_num) + + with check(1): + a = jt.array([1,2,3]) + b = a+1 + assert (b.data==[2,3,4]).all() + del a, b + + with check(0): + a = jt.array([1,2,3]) + b = a+a + assert (b.data==[2,4,6]).all() + del a, b + + def test_shape(shape1, shape2, bop_num): + with check(bop_num): + a = jt.random(shape1) + b = jt.random(shape2) + c = a+b + test_shape([3,3,3], [3,3,3], 0) + test_shape([3,3,3], [3,3,1], 1) + test_shape([3,3,3], [3,1,1], 1) + test_shape([3,3,3], [1,1,1], 1) + test_shape([3,3,3], [1,1,3], 1) + test_shape([3,3,3], [1,3,3], 1) + test_shape([3,3,1], [1,3,3], 2) + test_shape([3,1,3], [1,3,3], 2) + test_shape([3,3], [1,3,3], 1) + test_shape([3,3], [1,3,1], 2) + + +class TestBroadcastToOp2(TestBroadcastToOp): + def setUp(self): + self.use_shape = True + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestBroadcastToOpCuda(TestBroadcastToOp): + def setUp(self): + jt.flags.use_cuda = 2 + self.use_shape = False + def tearDown(self): + jt.flags.use_cuda = 0 + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestBroadcastToOp2Cuda(TestBroadcastToOp): + def setUp(self): + jt.flags.use_cuda = 2 + self.use_shape = True + def tearDown(self): + jt.flags.use_cuda = 0 + + +class TestBroadcastToOpMisc(unittest.TestCase): + def test_negtive_dim(self): + a = jt.array([1,2]) + assert (a.broadcast([2,2], [-1]).data == [[1,1],[2,2]]).all() + assert (a.broadcast([2,2], [-2]).data == [[1,2],[1,2]]).all() + + def test_negtive_dim2(self): + a = jt.array([1,2]) + b = jt.zeros((2,2)) + assert (a.broadcast(b, [-1]).data == [[1,1],[2,2]]).all() + assert (a.broadcast(b, [-2]).data == [[1,2],[1,2]]).all() + + def test_zero_dim(self): + a = jt.array(1.0) + b = a.broadcast([0]) + assert b.shape == [0] + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_broadcast_tuner.py b/python/jittor/test/test_broadcast_tuner.py new file mode 100644 index 00000000..693f571a --- /dev/null +++ b/python/jittor/test/test_broadcast_tuner.py @@ -0,0 +1,49 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import sys +import os +import jittor as jt +import unittest +import time +import numpy as np +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +class TestBroadcastTuner(unittest.TestCase): + @classmethod + def setUpClass(self): + return + + def check(self, h, w, cs, rs, pa, rtp, dim): + a = jt.random([h,w]) + a.data + + + with jt.log_capture_scope( + log_v=0, log_vprefix="tuner_manager=100", + # this value is used for force compile + compile_options={"test_broadcast_tuner":1} + ) as logs: + amean=jt.mean(a, dims=[dim], keepdims=1) + a2mean=jt.mean(a*a, dims=[dim], keepdims=1) + norm_aa=(a-amean.broadcast_var(a))/(jt.sqrt(a2mean-amean*amean).broadcast_var(a)) + norm_aa.data + logs = find_log_with_re(logs, + "Run tuner broadcast: confidence\\((20)\\) candidates\\((.*)\\)$") + assert len(logs) == 1, logs + assert logs[0][0] == "20", "confidence of reorder should be 20" + candidates = simple_parser(logs[0][1]) + assert candidates == {"order0": [0,], "order1": [1,], "order2": [0,], "split1": [2048,], "use_movnt": [1,],}, candidates + + def test_broadcast_tuner(self): + self.check(8192,8192, 0, 0, 0, 5, 0) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_cache.py b/python/jittor/test/test_cache.py new file mode 100644 index 00000000..b41def91 --- /dev/null +++ b/python/jittor/test/test_cache.py @@ -0,0 +1,97 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' + +import unittest +import time +import jittor as jt +from jittor import LOG +import math +import numpy as np +from .test_core import expect_error +from .test_fused_op import retry + +def check_cache_code(fname): + check_code = True + error_line_num = -1 + with open(fname) as f: + lines = f.readlines() + for i in range(len(lines)): + if ('memory_checker->check_hit(' in lines[i]): + continue + code = lines[i] + address_pos = [] + for j in range(len(code)): + if code[j] == '[': + address_pos.append(j) + if code[j] == ']': + sp = address_pos[-1] - 1 + address_pos = address_pos[:-1] + if sp>=4 and code[sp-4:sp+1]=="shape": + continue + while ((sp >= 0) and ((code[sp] >= 'A' and code[sp] <= 'Z') or (code[sp] >= 'a' and code[sp] <= 'z') or + (code[sp] >= '0' and code[sp] <= '9') or code[sp] == '_' or code[sp] == '.' or (sp > 0 and code[sp] == '>' and code[sp - 1] == '-'))): + if (sp > 0 and code[sp] == '>' and code[sp - 1] == '-'): + sp -= 2 + else: + sp -= 1 + sp += 1 + check_var = code[sp:j + 1] + temp_i = i - 1 + have_check = False + while (temp_i >= 0 and 'memory_checker->check_hit(' in lines[temp_i]): + if check_var in lines[temp_i]: + have_check = True + break + temp_i -= 1 + if (not have_check): + check_code = False + error_line_num = i + break + if (not check_code): + break + assert check_code, "check cache not found in line " + str(error_line_num) + " of file " + fname + +class TestCache(unittest.TestCase): + def test_reduce(self): + @retry(10) + def check(n, m, reduce_dim, cache_report_, error_rate_threshold): + a = jt.random([n,m]) + a.sync() + with jt.profile_scope(compile_options = { + "check_cache": 1, "replace_strategy": 1, "page_size": 4 << 10, #2 << 20 + "vtop": 0, + "tlb_size": 64, "tlb_ways": 4, "tlb_line_size": 1, + "L1_size": 32 << 10, "L1_ways": 8, "L1_line_size": 64, + "L2_size": 256 << 10, "L2_ways": 8, "L2_line_size": 64, + "L3_size": 15 << 20, "L3_ways": 20, "L3_line_size": 64 + }, enable_tuner=0) as report: + c = a.sum(reduce_dim) + c.sync() + + check_cache_code(report[1][1]) + cache_report = report[-1][-5:] + for i in range(len(cache_report)): + cache_report[i] = int(cache_report[i]) + for i in range(len(cache_report)): + assert abs(cache_report[i] - cache_report_[i]) <= int(cache_report_[i] * error_rate_threshold), "cache report error: " + report[-2][-(len(cache_report) - i)] + " error, " + str(cache_report[i]) + "!=" + str(cache_report_[i]) + error_threshold = 0.02 + check(100, 10000, 0, [3010004, 989, 125729, 63129, 63129], error_threshold) + check(100, 10000, 1, [3000104, 981, 62510, 62510, 62510], error_threshold) + check(10, 98765, 0, [3061719, 2034, 129645, 129645, 67905], error_threshold) + check(10, 98765, 1, [2962964, 969, 61733, 61733, 61733], error_threshold) + check(7779, 97, 0, [2263790, 740, 47170, 47170, 47170], error_threshold) + check(7779, 97, 1, [2271472, 748, 47650, 47650, 47650], error_threshold) + check(1024, 1024, 0, [3146756, 1029, 65603, 65603, 65603], error_threshold) + check(1024, 1024, 1, [3146756, 1028, 65603, 65603, 65603], error_threshold) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_candidate.py b/python/jittor/test/test_candidate.py new file mode 100644 index 00000000..4021016a --- /dev/null +++ b/python/jittor/test/test_candidate.py @@ -0,0 +1,69 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +def check_candidate(x, fail_func): + ans = [] + n = x.shape[0] + for i in range(n): + ok = True + for j in range(len(ans)): + if (fail_func(x[ans[j], :], x[i, :])): + ok = False + break + if (ok): + ans.append(i) + return np.array(ans, dtype=int) + +def check(shape, fail_cond, fail_func): + a = jt.random(shape) + selected = jt.candidate(a, fail_cond) + a_ = a.data + selected_out = selected.data + selected_ans = check_candidate(a_, fail_func) + assert selected_out.tolist() == selected_ans.tolist(), (selected_out, selected_ans) + +def check1(selected, comer): + return selected[0]>comer[0] or selected[1]>comer[1] or selected[2]>comer[2] + +def check2(selected, comer): + return selected[0]>comer[0] and selected[1]>comer[1] and selected[2]>comer[2] + +def check3(selected, comer): + threshold = 0.01 + s_1 = selected[2]*selected[3] + s_2 = comer[2]*comer[3] + s_inter_h = max(0,min(selected[2]+selected[0],comer[2]+comer[0])-max(selected[0],comer[0])) + s_inter_w = max(0,min(selected[3]+selected[1],comer[3]+comer[1])-max(selected[1],comer[1])) + s_inter = s_inter_h*s_inter_w + iou = s_inter / (s_1 + s_2 - s_inter) + return iou < threshold + +class TestCandidateOp(unittest.TestCase): + def test(self): + # increse sequence + check([100000,3], '(@x(j,0)>@x(i,0))or(@x(j,1)>@x(i,1))or(@x(j,2)>@x(i,2))', check1) + # no all increse sequence + check([100000,3], '(@x(j,0)>@x(i,0))and(@x(j,1)>@x(i,1))and(@x(j,2)>@x(i,2))', check2) + # nms + # [x0, y0, h, w] + threshold = '0.01' + s_1 = '@x(j,2)*@x(j,3)' + s_2 = '@x(i,2)*@x(i,3)' + s_inter_h = 'std::max((Tx)0,std::min(@x(j,2)+@x(j,0),@x(i,2)+@x(i,0))-std::max(@x(j,0),@x(i,0)))' + s_inter_w = 'std::max((Tx)0,std::min(@x(j,3)+@x(j,1),@x(i,3)+@x(i,1))-std::max(@x(j,1),@x(i,1)))' + s_inter = s_inter_h+'*'+s_inter_w + iou = s_inter + '/(' + s_1 +'+' + s_2 + '-' + s_inter + ')' + check([3000,4], iou+'<'+threshold, check3) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_clone.py b/python/jittor/test/test_clone.py new file mode 100644 index 00000000..557e85a5 --- /dev/null +++ b/python/jittor/test/test_clone.py @@ -0,0 +1,38 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +class TestClone(unittest.TestCase): + def test_mid_stop_grad(self): + jt.clean() + b = a = jt.array(1.0) + for i in range(10): + b = b.clone() + if i==5: c=b + b.sync() + assert jt.number_of_lived_vars()==11 + c.name("c") + c.stop_grad() + for n in jt.dump_all_graphs().nodes_info: + print(n) + assert jt.number_of_lived_vars()==3, jt.number_of_lived_vars() + + def test2(self): + a = jt.array([1,2]) + print(a.detach()) + + @jt.flag_scope(lazy_execution=0) + def test3(self): + a = jt.array([1,2]) + print(a.detach()) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_code_op.py b/python/jittor/test/test_code_op.py new file mode 100644 index 00000000..5420400c --- /dev/null +++ b/python/jittor/test/test_code_op.py @@ -0,0 +1,386 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import Function + +class TestCodeOp(unittest.TestCase): + def test(self): + a = jt.random([10]) + b = jt.code(a.shape, a.dtype, [a], + cpu_src=''' + for (int i=0; i + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i + using namespace std; + """, + cpu_src=""" + @alias(a, in0) + @alias(b, out0) + @alias(c, out1) + @b(0) = @c(0) = @a(0); + for (int i=0; i0) + @b(num_b++) = @a(i); + else + @c(num_c++) = @a(i); + } + b->set_shape({num_b}); + c->set_shape({num_c}); + """ + ) + assert (b.data == [5,3,1]).all() + assert (c.data == [-4,-2]).all() + + def test_comment(self): + a = jt.array([3,2,1]) + b = jt.code(a.shape, a.dtype, [a], + cpu_header=''' + #include + // asd + /* asd + */ + ''', + cpu_src=""" + // test comment + /* + multi line + */ + @alias(a, in0) + for (int i=0; i>>(@ARGS); + ''', + cuda_grad_src = [''' + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + ''', ''' + __global__ static void kernel3(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + ''']) + da, db = jt.grad(c, [a, b]) + assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data) + assert np.allclose(da.data, b.data) + assert np.allclose(db.data, a.data) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cuda2(self): + a = jt.random((100,100)) + b = jt.random((100,100)) + c = jt.code(a.shape, a.dtype, [a,b], + cuda_src=''' + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''', + cuda_grad_src = [''' + __global__ static void kernel(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''', ''' + __global__ static void kernel(@ARGS_DEF) { + @PRECALC + @pout(0,0); + for (int i=blockIdx.x; i>>(@ARGS); + ''']) + da, db = jt.grad(c, [a, b]) + assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data) + assert np.allclose(da.data, b.data) + assert np.allclose(db.data, a.data) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cuda2_use_func(self): + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=''' + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''') + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=''' + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''') + + a = jt.random((100,100)) + b = jt.random((100,100)) + + func = Func() + c = func(a,b) + da, db = jt.grad(c, [a, b]) + assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data) + assert np.allclose(da.data, b.data) + assert np.allclose(db.data, a.data) + + def test_simple_var(self): + a = jt.code([1], "float32", inputs=[], + data = {"x":123}, + cpu_src=''' + @out0(0) = data["x"]; + ''').sync() + assert a.item() == 123 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_compile_options.py b/python/jittor/test/test_compile_options.py new file mode 100644 index 00000000..17aba0bd --- /dev/null +++ b/python/jittor/test/test_compile_options.py @@ -0,0 +1,31 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +from .test_log import find_log_with_re +from .test_fused_op import retry + +class TestCompileOptions(unittest.TestCase): + def test(self): + a = jt.array([1,2,3]) + a.sync() + assert a.compile_options=={} + a.compile_options = {"compile_shapes":1} + assert a.compile_options=={"compile_shapes":1} + b = a+a + assert b.compile_options=={} + with jt.flag_scope(compile_options={"compile_shapes":1}): + c = a+b + assert c.compile_options=={"compile_shapes":1} + with jt.profile_scope() as report: + c.sync() + assert len(report)==2 and "compile_shapes:1" in report[1][0] + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_complex.py b/python/jittor/test/test_complex.py new file mode 100644 index 00000000..19686009 --- /dev/null +++ b/python/jittor/test/test_complex.py @@ -0,0 +1,200 @@ +import jittor as jt +from jittor.nn import ComplexNumber +import unittest +import numpy as np + +__skip_torch_test = False +try: + import torch +except: + __skip_torch_test = True + +class TestResultAndGrad: + def check_results(self, rlist1, rlist2): + assert len(rlist1) == len(rlist2) + for r1, r2 in zip(rlist1, rlist2): + assert r1.shape == r2.shape + assert np.allclose(r1, r2, rtol=1e-3, atol=1e-3) + + def grad_jittor(self, inputs, losses): + grads = [] + for i in inputs: + for loss in losses: + if isinstance(i, ComplexNumber): + g = jt.grad(loss, i.value, retain_graph=True) + grads.append(g[..., 0].numpy() + 1j * g[..., 1].numpy()) + else: + g = jt.grad(loss, i, retain_graph=True) + grads.append(g.numpy()) + return grads + + def grad_torch(self, inputs, losses): + grads = [] + for i in inputs: + for loss in losses: + g = torch.autograd.grad(loss, i, retain_graph=True)[0] + grads.append(g.detach().cpu().numpy()) + return grads + + def run_jittor_op(self, op, input_list, weights=None): + def _np_to_jittor(x:np.ndarray): + if x.dtype == np.complex64 or x.dtype == np.complex128: + nx = np.stack([np.real(x), np.imag(x)], axis=-1) + return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True) + elif x.dtype == np.float32 or x.dtype == np.float64: + return jt.array(x, dtype=jt.float32) + else: + assert False + def _jittor_to_np(x): + if isinstance(x, jt.Var): + return x.numpy() + elif isinstance(x, ComplexNumber): + return x.real.numpy() + 1j * x.imag.numpy() + assert False + + ninput_list = [_np_to_jittor(x) for x in input_list] + output_list = op(*ninput_list) + if isinstance(output_list, (jt.Var, ComplexNumber)): + output_list = [output_list] + losses = [] + if weights is None: + weights = [] + for o in output_list: + no = o.value if isinstance(o, ComplexNumber) else o + w = np.random.randn(*no.shape) + weights.append(w) + losses.append(jt.sum(no * jt.array(w))) + else: + assert len(output_list) == len(weights) + for o, w in zip(output_list, weights): + no = o.value if isinstance(o, ComplexNumber) else o + assert w.shape == no.shape + losses.append(jt.sum(no * jt.array(w))) + output_list = [_jittor_to_np(x) for x in output_list] + return ninput_list, output_list, losses, weights + + def run_torch_op(self, op, input_list, weights=None): + def _np_to_torch(x:np.ndarray): + return torch.from_numpy(x).requires_grad_(True) + def _torch_to_np(x:torch.Tensor) -> np.ndarray: + return x.detach().cpu().numpy() + ninput_list = [_np_to_torch(x) for x in input_list] + output_list = op(*ninput_list) + if isinstance(output_list, torch.Tensor): + output_list = [output_list] + losses = [] + if weights is None: + weights = [] + for o in output_list: + no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o + w = np.random.randn(*no.shape) + weights.append(w) + losses.append(torch.sum(no * torch.from_numpy(w))) + else: + assert len(output_list) == len(weights) + for o, w in zip(output_list, weights): + no = torch.stack([torch.real(o), torch.imag(o)], dim=-1) if o.is_complex() else o + assert w.shape == no.shape + losses.append(torch.sum(no * torch.from_numpy(w))) + output_list = [_torch_to_np(x) for x in output_list] + return ninput_list, output_list, losses, weights + + def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True): + weights = None + jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights) + torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights) + self.check_results(jittor_output, torch_output) + + if check_grad: + jittor_grads = self.grad_jittor(jittor_input, jittor_losses) + torch_grads = self.grad_torch(torch_input, torch_losses) + self.check_results(jittor_grads, torch_grads) + + def check_op_with_numpy(self, jittor_op, numpy_op, input_list): + _, jittor_output, _, _ = self.run_jittor_op(jittor_op, input_list, None) + numpy_output = numpy_op(*input_list) + if isinstance(numpy_output, np.ndarray): + numpy_output = [numpy_output] + + self.check_results(jittor_output, numpy_output) + +@unittest.skipIf(__skip_torch_test, "No Torch found") +class TestComplexLinalg(unittest.TestCase, TestResultAndGrad): + def random_complex_matrix(self, shape): + r = np.random.randn(*shape) + i = np.random.randn(*shape) + return r + 1j * i + + def test_complex_matmul(self): + s1 = (50, 200) + s2 = (200, 50) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + + inputs = [m1, m2] + self.check_op_with_torch(jt.matmul, torch.matmul, inputs) + + def test_complex_matmul_batch(self): + s1 = (10, 50, 30) + s2 = (10, 30, 40) + m1 = self.random_complex_matrix(s1) + m2 = self.random_complex_matrix(s2) + + inputs = [m1, m2] + self.check_op_with_torch(jt.matmul, torch.matmul, inputs) + + def test_complex_inv(self): + s1 = (200, 200) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs) + + def test_complex_inv_batch(self): + s1 = (10, 50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.inv, torch.linalg.inv, inputs) + + def test_complex_eig(self): + # Unstable + s1 = (20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs) + + def test_complex_eig_batch(self): + # Unstable + s1 = (5, 10, 10) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.eig, np.linalg.eig, inputs) + + def test_complex_qr(self): + s1 = (50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs) + + def test_complex_qr_batch(self): + s1 = (10, 20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_torch(jt.linalg.qr, torch.linalg.qr, inputs) + + def test_complex_svd(self): + # Unstable + s1 = (50, 50) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) + + def test_complex_svd_batch(self): + # Unstable + s1 = (10, 20, 20) + m1 = self.random_complex_matrix(s1) + inputs = [m1] + self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_concat_op.py b/python/jittor/test/test_concat_op.py new file mode 100644 index 00000000..4386e481 --- /dev/null +++ b/python/jittor/test/test_concat_op.py @@ -0,0 +1,173 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +def concat2(arr, dim): + '''Concat Operator can concat a list of jt Var at a specfic dimension. + + * [in] x: input var list for concat + + * [in] dim: concat which dim + + * [out] out: concat result + +Example:: + + jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1) + # return [[1],[2],[2],[2]] + ''' + # TODO: low performance when concat lots of vars + total_dim = 0 + if dim < 0: dim += len(arr[0].shape) + for a in arr: + total_dim += a.shape[dim] + cdim = 0 + shape = list(a.shape) + shape[dim] = total_dim + s = jt.empty(shape, a.dtype) + slices = [slice(None)]*len(a.shape) + for a in arr: + slices[dim] = slice(cdim, cdim+a.shape[dim]) + # print(slices, type(a)) + s = s.setitem(tuple(slices), a) + # s = jt.setitem(s, tuple(slices), a) + cdim += a.shape[dim] + return s + +def numpy_concat(arr, dim): + arr = [ a.numpy() for a in arr ] + return np.concatenate(arr, dim) + +class TestConcatOp(unittest.TestCase): + def test_concat_op(self): + def check(tmp, dim=0): + res1 = numpy_concat(tmp, dim=dim) + res2 = jt.concat(tmp, dim=dim) + assert (res2!=res1).data.sum()==0, "concat fail..." + check([jt.array([[1],[2]]), jt.array([[2],[2]])]) + check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))]) + check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))]) + check([jt.array(np.array(range(5))).reshape((5,1)), jt.array(np.array(range(1))).reshape((1,1))]) + print('concat success...') + + + @unittest.skipIf(not jt.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda = 1) + def test_concat_perf(self): + def check(dim, size, backward=False): + n = 64 + a = jt.random((n,n,n,n)) + a.sync() + m = n // size + arr = [] + for i in range(m): + arr.append(a[(slice(None),)*dim + (slice(i*size,i*size+size),)]) + b = jt.concat(arr, dim) + if backward: + loss = b * a + b = jt.grad(loss, a) + with jt.profile_scope(1, 0) as rep: + b.sync() + # print(rep) + i = rep[0].index("TotalTime") + stime = 0 + for r in rep[1:]: + stime += float(r[i]) + bw = 4*64**4*2*2 / stime + # sizeof(float) * numel * (split and concat) * (read and write) + print(f"{dim} {size} {stime/1e6}ms, {bw}GB/s") + return bw + ndim = 4 + splits = [1, 2, 4, 8, 16, 32, 64] + m = len(splits) + result = np.zeros((4, m)) + result_back = np.zeros((4, m)) + for i in range(ndim): + for j in range(m): + result[i,j] = check(i, splits[j]) + result_back[i,j] = check(i, splits[j], True) + print(result.T) + print(result_back.T) + ''' +[[ 17.02802497 17.12933081 17.10814418 15.49217942] + [ 33.10922467 33.01865886 33.08940182 30.24637466] + [ 62.27219795 62.06702029 61.90039457 58.68727009] + [112.31933307 111.89659519 111.02357161 108.98520165] + [187.24806534 190.68837367 186.73965711 186.32242015] + [280.28594579 278.94498734 284.42015302 284.98722929] + [387.03887468 386.14916854 386.47551229 385.28621521]] + +[[ 5.04141217 4.55677858 4.55677363 3.79321142] + [ 9.05243799 8.99777599 8.96021333 7.49345194] + [ 17.45032635 17.36882645 17.14316909 14.98928307] + [ 35.60450372 35.55333375 35.32826879 32.00750909] + [ 61.72854251 62.285231 61.64460882 58.17541776] + [ 97.44981525 96.79104909 95.38118155 95.09154931] + [135.11495888 134.60444658 135.41807381 135.38139881]] + + ''' + + @unittest.skipIf(not jt.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda = 1) + def test_concat2_perf(self): + def check(dim, size, backward=False): + n = 64 + a = jt.random((n,n,n,n)) + a.sync() + m = n // size + arr = [] + for i in range(m): + arr.append(a.getitem((slice(None),)*dim + (slice(i*size,i*size+size),))) + b = concat2(arr, dim) + if backward: + loss = b * a + b = jt.grad(loss, a) + with jt.profile_scope(1, 0) as rep: + b.sync() + # print(rep) + i = rep[0].index("TotalTime") + stime = 0 + for r in rep[1:]: + stime += float(r[i]) + bw = 4*64**4*2*2 / stime + # sizeof(float) * numel * (split and concat) * (read and write) + print(f"{dim} {size} {stime/1e6}ms, {bw}GB/s") + return bw + ndim = 4 + splits = [1, 2, 4, 8, 16, 32, 64] + m = len(splits) + result = np.zeros((4, m)) + result_back = np.zeros((4, m)) + for i in range(ndim): + for j in range(m): + result[i,j] = check(i, splits[j]) + result_back[i,j] = check(i, splits[j], True) + print(result.T) + print(result_back.T) + ''' +[[ 15.59142118 15.8001291 15.77589713 11.79319714] + [ 31.33130734 31.2476813 31.20394782 23.19700034] + [ 57.90763098 57.71203221 58.02228419 45.60297828] + [104.20428796 104.08291412 104.18568373 91.648383 ] + [175.21896606 175.44422637 176.57915576 168.33344684] + [264.35929995 267.63202466 262.92687504 268.41854563] + [352.36998687 355.89200025 360.95753527 361.34916742]] +[[ 3.39802237 3.42782551 3.43126375 2.85884566] + [ 7.12993628 7.11445323 7.11482319 5.90134142] + [ 15.13540229 15.11031669 15.12954432 12.76302703] + [ 28.08930928 28.09445985 28.01005224 25.43536254] + [ 49.58246623 49.70843778 49.49253912 48.07459389] + [ 80.3745414 80.85044884 79.74203591 80.97114412] + [117.14450249 119.22320442 119.2380328 119.63622556]] + + ''' + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_console.py b/python/jittor/test/test_console.py new file mode 100644 index 00000000..8c6b1dfe --- /dev/null +++ b/python/jittor/test/test_console.py @@ -0,0 +1,22 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor_utils import run_cmd +import sys + +class TestConsole(unittest.TestCase): + def test_console(self): + run_cmd(f"{sys.executable} -m jittor_utils.config --cxx-example > tmp.cc", jt.flags.cache_path) + s = run_cmd(f"{jt.flags.cc_path} tmp.cc $({sys.executable} -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o tmp.out && ./tmp.out", jt.flags.cache_path) + print(s) + assert "jt.Var" in s + assert "pred.shape 2 1000" in s + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_contrib.py b/python/jittor/test/test_contrib.py new file mode 100644 index 00000000..b19fc92a --- /dev/null +++ b/python/jittor/test/test_contrib.py @@ -0,0 +1,62 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error + +class TestContrib(unittest.TestCase): + def test_concat(self): + def check(shape, dim, n): + num = np.prod(shape) + arr1 = [] + arr2 = [] + for i in range(n): + a = (np.array(range(num)) + i*num).reshape(shape) + arr1.append(a) + arr2.append(jt.array(a)) + x = np.concatenate(tuple(arr1), dim) + y = jt.concat(arr2, dim) + assert (x==y.data).all(), (x, y.data, arr1, arr2) + check([2,3,4], 0, 2) + check([2,3,4], 1, 3) + check([2,3,4], 2, 4) + check([2,3,4,5], 0, 4) + check([2,3,4,5], 2, 4) + check([2,3,4,5], 3, 4) + check([1], 0, 20) + + def test_slice(self): + def check(shape, slices): + x = jt.random(shape) + a = x[slices].data + b = x.data[slices] + assert (a==b).all(), (a, b) + y = x.numpy() + v = jt.random(a.shape) + x[slices] = v + y[slices] = v.data + assert (x.data==y).all() + # TODO: when slice same row/col many times and assign value, numpy will retain the last value but we assign their sum. eg: check([3,3,3,3], ([[0,1,1]],slice(None),[[1],[2],[0]],1)) + check([3], (1)) + check([3,3,3,3], ([[0],[1]],slice(None),[1,2],1)) + check([3,3,3,3], (slice(None),slice(None),slice(None),slice(None))) + check([3,3,3,3], ([0,1],[0,1],[0,1],[0,1])) + check([3,3,3,3], ([0,1],-2,slice(None),[0,1])) + check([3,3,3,3], ([0,1],slice(1,2,2),[1,2],1)) + check([3,3,3,3], ([0,1],slice(None),[1,2],1)) + check([10,10,10,10], (slice(1,None,2),slice(-1,None,2),[1,2],-4)) + check([20], 0) + check([20], 10) + check([20], -10) + check([10,10,10,10], 1) + check([10,10,10,10], (1,slice(None),2)) + check([10,10,10,10], (-2,slice(None),2,slice(1,9,2))) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_conv_transpose.py b/python/jittor/test/test_conv_transpose.py new file mode 100644 index 00000000..c33a6272 --- /dev/null +++ b/python/jittor/test/test_conv_transpose.py @@ -0,0 +1,112 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import numpy as np + +from jittor.test.test_log import find_log_with_re +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestConvTranspose(unittest.TestCase): + + @unittest.skipIf(not jt.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cuda(self): + self.test() + + def test(self): + def check(data_shape, weights_shape, stride=1, dilation=1, groups=1): + N,C,H,W = data_shape + i,o,h,w = weights_shape + img = np.random.rand(N,C,H,W).astype("float32") + weights = np.random.rand(i,o//groups,h,w).astype("float32") + m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups) + m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False, groups=groups) + m1.weight.data = weights + m2.weight.data = torch.Tensor(weights) + x = jt.array(img) + out1 = m1(x) + mask = jt.random(out1.shape) + out1 = out1*mask + tx = torch.Tensor(img) + tx.requires_grad = True + out2 = m2(tx) * torch.Tensor(mask.data) + with jt.log_capture_scope(log_silent=1, + log_vprefix="var_re=0,conv=0,op.cc=100") as logs: + assert np.allclose(out1.data, out2.data) + dx, dw = jt.grad(out1, [x, m1.weight]) + jt.sync([dx, dw]) + out2.sum().backward() + assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3) + assert np.allclose(dx.data, tx.grad.numpy()) + assert len(find_log_with_re(logs, "conv")) == 3 + check((4, 5, 10, 10), (5, 6, 3, 3)) + check((4, 5, 10, 10), (5, 6, 3, 3), 2) + check((4, 5, 100, 100), (5, 6, 4, 4), 2) + check((4, 5, 100, 100), (5, 6, 4, 4), 3) + check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2) + check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2) + check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3) + check((4, 6, 100, 100), (6, 6, 5, 5), 2, 3, 2) + + def test_function(self): + def check(data_shape, weights_shape, stride=1, dilation=1): + N,C,H,W = data_shape + i,o,h,w = weights_shape + img = np.random.rand(N,C,H,W).astype("float32") + weights = np.random.rand(i,o,h,w).astype("float32") + m1 = jt.nn.ConvTranspose(i,o,h, stride=stride, dilation=dilation, bias=False) + m2 = torch.nn.ConvTranspose2d(i,o,h, stride=stride, dilation=dilation, bias=False) + m1.weight.data = weights + m2.weight.data = torch.Tensor(weights) + x = jt.array(img) + # out1 = m1(x) + out1 = jt.nn.conv_transpose2d(x, m1.weight, stride=stride, dilation=dilation, bias=False) + mask = jt.random(out1.shape) + out1 = out1*mask + tx = torch.Tensor(img) + tx.requires_grad = True + out2 = m2(tx) * torch.Tensor(mask.data) + with jt.log_capture_scope(log_silent=1, + log_vprefix="var_re=0,conv=0,op.cc=100") as logs: + assert np.allclose(out1.data, out2.data) + dx, dw = jt.grad(out1, [x, m1.weight]) + jt.sync([dx, dw]) + out2.sum().backward() + assert np.allclose(dw.data, m2.weight.grad.numpy(), 1e-3) + assert np.allclose(dx.data, tx.grad.numpy()) + assert len(find_log_with_re(logs, "conv")) == 3 + check((4, 5, 10, 10), (5, 6, 3, 3)) + check((4, 5, 10, 10), (5, 6, 3, 3), 2) + check((4, 5, 100, 100), (5, 6, 4, 4), 2) + check((4, 5, 100, 100), (5, 6, 4, 4), 3) + check((4, 5, 100, 100), (5, 6, 5, 5), 1, 2) + check((4, 5, 100, 100), (5, 6, 5, 5), 2, 2) + check((4, 5, 100, 100), (5, 6, 5, 5), 2, 3) + + def test_conv1d(self): + conv1d = jt.nn.Conv1d(10,20,5) + a = jt.rand((3,10,15)) + b = conv1d(a) + b.sync() + assert b.shape == [3,20,11] + b = jt.nn.Conv1d(10,20,5, padding=2)(a) + assert b.shape == [3,20,15] + assert sorted(list(conv1d.state_dict().keys())) == ['bias', 'weight'], conv1d.state_dict().keys() + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_conv_tuner.py b/python/jittor/test/test_conv_tuner.py new file mode 100644 index 00000000..643f2bac --- /dev/null +++ b/python/jittor/test/test_conv_tuner.py @@ -0,0 +1,172 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import numpy as np +from jittor import compile_extern +# TODO: compare with pytorch + +from jittor.test.test_log import find_log_with_re +if jt.has_cuda: + from jittor.compile_extern import cublas_ops, cudnn_ops +else: + cublas_ops = cudnn_ops = None + +def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride = 1, dilation=1, init_method=None, w_ = None): + Kw = kernel_size + Kh = kernel_size + _C = in_planes + Kc = out_planes + N,C,H,W = x.shape + + assert C==_C + if w_ is None: + assert 0 + else: + w = w_ + oh = (H-Kh*dilation+dilation-1+padding*2)//stride+1 + ow = (W-Kw*dilation+dilation-1+padding*2)//stride+1 + xx = x.reindex([N,Kc,C,oh,ow,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride}-{padding}+i5*{dilation}', # Hid+Khid + f'i4*{stride}-{padding}+i6*{dilation}', # Wid+KWid + ]) + ww = w.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # C, Kh, Kw + return y + +def conv_nhwc(x, in_planes, out_planes, kernel_size, padding, stride = 1, dilation=1, init_method=None, w_ = None): + Kw = kernel_size + Kh = kernel_size + _C = in_planes + Kc = out_planes + N,H,W,C = x.shape + + assert C==_C + if w_ is None: + assert 0 + else: + w = w_ + oh = (H-Kh*dilation+dilation-1+padding*2)//stride+1 + ow = (W-Kw*dilation+dilation-1+padding*2)//stride+1 + xx = x.reindex([N,Kc,C,oh,ow,Kh,Kw], [ + 'i0', # Nid + f'i3*{stride}-{padding}+i5*{dilation}', # Hid+Khid + f'i4*{stride}-{padding}+i6*{dilation}', # Wid+KWid + 'i2', # Cid + ]) + ww = w.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # C, Kh, Kw + return y + +def test_nhwc(x, w, stride, padding, dilation): + out_planes, in_planes, kernel_size, _ = w.shape + return conv_nhwc(x, in_planes, out_planes, kernel_size, padding, stride=stride, dilation=dilation, w_=w) + +def test_nchw(x, w, stride, padding, dilation): + out_planes, in_planes, kernel_size, _ = w.shape + return conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=stride, dilation=dilation, w_=w) + +def check_forward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc): + if nhwc: + test_func = test_nhwc + else: + test_func = test_nchw + if use_cuda == 1: + op_name = "cudnn_conv" + else: + op_name = "mkl_conv" + with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1, + log_v=0, log_vprefix="op.cc=100,conv_tuner=1000", compile_options={"test":266} + ) as raw_log: + x = jt.random(xshape) + w = jt.random(wshape) + y = test_func(x, w, stride, padding, dilation) + y.sync() + with jt.flag_scope(use_cuda=0, enable_tuner=0, + compile_options={"test":255}): + cy = test_func(x, w, stride, padding, dilation) + cy.sync() + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + op_name + ".*)") + assert len(logs)==1 and "oihw" in logs[0][0], logs + assert np.allclose(y.data, cy.data) + +def check_backward(xshape, wshape, stride, padding, dilation, use_cuda, nhwc): + if nhwc: + test_func = test_nhwc + else: + test_func = test_nchw + if use_cuda == 1: + op_name = "cudnn_conv" + else: + op_name = "mkl_conv" + + with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1, + log_v=1, log_vprefix="op.cc=1000,exe=1000,conv_t=1000", compile_options={"test":244} + ) as raw_log: + x = jt.random(xshape) + w = jt.random(wshape) + y = test_func(x, w, stride, padding, dilation) + loss = y.mean() + dx, dw = jt.grad(loss, [x, w]) + jt.sync([y, loss, dx, dw]) + with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test":233}): + cy = test_func(x, w, stride, padding, dilation) + closs = cy.mean() + cdx, cdw = jt.grad(closs, [x, w]) + jt.sync([cy, closs, cdx, cdw]) + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + op_name + ".*)") + assert len(logs)==3 and "oihw" in logs[0][0], (logs) + assert np.allclose(y.data, cy.data, 1e-3) + assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data) + assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(cdx.data).max(), np.abs(dx.data - cdx.data).max()) + +class TestConvTuner(unittest.TestCase): + def test_forward(self): + for dilation in [1,2,3]: + check_forward([10,100,100,3], [5,3,3,3], 2, 0, dilation, 0, True) + check_forward([10,40,50,4], [5,4,5,5], 1, 1, dilation, 0, True) + check_forward([10,40,50,4], [5,4,4,4], 3, 1, dilation, 0, True) + + check_forward([10,3,100,100], [5,3,3,3], 2, 0, dilation, 0, False) + check_forward([10,4,40,50], [5,4,5,5], 1, 1, dilation, 0, False) + check_forward([10,4,40,50], [5,4,4,4], 3, 1, dilation, 0, False) + + def test_backward(self): + for dilation in [1,2,3]: + check_backward([10,3,100,100], [5,3,3,3], 2, 0, dilation, 0, False) + check_backward([10,4,40,50], [5,4,5,5], 1, 1, dilation, 0, False) + check_backward([10,4,40,50], [5,4,4,4], 3, 1, dilation, 0, False) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_forward_cuda(self): + for dilation in [1,2,3]: + check_forward([10,100,100,3], [5,3,3,3], 2, 0, dilation, 1, True) + check_forward([10,40,50,4], [5,4,5,5], 1, 1, dilation, 1, True) + check_forward([10,40,50,4], [5,4,4,4], 3, 1, dilation, 1, True) + + check_forward([10,3,100,100], [5,3,3,3], 2, 0, dilation, 1, False) + check_forward([10,4,40,50], [5,4,5,5], 1, 1, dilation, 1, False) + check_forward([10,4,40,50], [5,4,4,4], 3, 1, dilation, 1, False) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_backward_cuda(self): + for dilation in [1,2,3]: + check_backward([10,3,100,100], [5,3,3,3], 2, 0, dilation, 1, False) + check_backward([10,4,40,50], [5,4,5,5], 1, 1, dilation, 1, False) + check_backward([10,4,40,50], [5,4,4,4], 3, 1, dilation, 1, False) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_core.py b/python/jittor/test/test_core.py new file mode 100644 index 00000000..30f9411a --- /dev/null +++ b/python/jittor/test/test_core.py @@ -0,0 +1,314 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +def expect_error(func): + try: + func() + except Exception as e: + return + raise Exception("Expect an error, but nothing catched.") + +class TestCore(unittest.TestCase): + + def test_number_of_hold_vars(self): + assert jt.random([1,2,3]).peek() == "float32[1,2,3,]" + assert jt.core.number_of_hold_vars() == 0 + x = jt.random([1,2,3]) + assert jt.core.number_of_hold_vars() == 1 + del x + assert jt.core.number_of_hold_vars() == 0 + + def test_fetch_sync(self): + dtypes = ["float32", "float64"] + for dtype in dtypes: + x = jt.random([1,2,3], dtype) + res = x.data + assert res.dtype == dtype and res.shape == (1,2,3) + + def test_set_seed(self): + a = jt.random([1,2,3]).data + b = jt.random([1,2,3]).data + assert str(a) != str(b) + jt.set_seed(1) + a = jt.random([1,2,3]).data + jt.set_seed(1) + b = jt.random([1,2,3]).data + assert str(a) == str(b) + + def test_array_op(self): + data = [ + np.array([1,2,3]), + np.int32([1,2,3]), + np.int64([1,2,3]), + np.float32([1,2,3]), + np.float64([1,2,3]), + ] + for a in data: + assert sum(jt.array(a).data) == 6 + assert np.all(jt.array(np.int32([1,2,3])[::-1]).data == [3,2,1]) + assert jt.array(1).data.shape == (1,) + + def test_matmul_op(self): + a = np.array([[1, 0], [0, 1]]).astype("float32") + b = np.array([[4, 1], [2, 2]]).astype("float32") + c = np.matmul(a, b) + jtc = jt.matmul(jt.array(a), jt.array(b)).data + assert np.allclose(jtc, c) + + a = np.random.random((128,3,10,20)) + b = np.random.random((20,30)) + c = np.matmul(a, b) + jtc = jt.matmul(jt.array(a), jt.array(b)).data + assert np.allclose(jtc, c) + + a = np.random.random((128,3,10,20)) + b = np.random.random((128,3,20,30)) + c = np.matmul(a, b) + jtc = jt.matmul(jt.array(a), jt.array(b)).data + assert np.allclose(jtc, c), np.abs(jtc-c).max() + + def test_var_holder(self): + jt.clean() + self.assertEqual(jt.number_of_lived_vars(), 0) + expect_error(lambda: jt.matmul(1,1)) + expect_error(lambda: jt.matmul([1],[1])) + expect_error(lambda: jt.matmul([[1]],[1])) + self.assertEqual(jt.number_of_lived_vars(), 0) + a = jt.matmul(jt.float32([[3]]), jt.float32([[4]])).data + assert a.shape == (1,1) and a[0,0] == 12 + a = np.array([[1, 0], [0, 1]]).astype("float32") + b = np.array([[4, 1], [2, 2]]).astype("float32") + c = np.matmul(a, b) + jtc = jt.matmul(jt.array(a), jt.array(b)).data + assert np.all(jtc == c) + + def test_save_load_sub_module(self): + class Net(jt.Module): + def __init__(self): + self.conv1 = jt.nn.Conv(3,3,3) + net = Net() + assert list(net.state_dict().keys()) == ['conv1.weight', 'conv1.bias'] + assert list(net.conv1.state_dict().keys()) == ['weight', 'bias'] + pkl_name = os.path.join(jt.flags.cache_path, "sub.pkl") + net.conv1.save(pkl_name) + net.conv1.load(pkl_name) + + def test_module(self): + a = jt.Module() + a.__setattr__("x", 1) + assert a.__getattr__("x") == 1 + a.y = 2 + assert a.y == 2 + + def test_modules(self): + a = jt.Module() + a.x = jt.Module() + a.y = jt.Module() + a.a = jt.array([1,2,3]) + a.b = jt.array([1,2,3]) + assert list(a._modules.keys()) == ["x", "y"] + assert a._modules['x'] is a.x + assert a._modules['y'] is a.y + assert list(a._parameters.keys()) == ['a', 'b'] + assert a._parameters['a'] is a.a + assert a._parameters['b'] is a.b + + def test_copy_memopt(self): + # exe: post run + # remove pending done + # add hold pending done + # pending release mem done + a = jt.rand(10) + b = a.copy().copy().copy() + a.name("aa") + b.name("bb") + + cnt = 0 + graphs = jt.dump_all_graphs() + for x in graphs.nodes_info: + if "Var" not in x: continue + print(x) + if ",aa," in x: + assert ":2:i" in x, x + elif ",bb," in x: + assert ":1:i" in x + else: + assert ":1:i" in x + + b.sync() + cnt = 0 + graphs = jt.dump_all_graphs() + for x in graphs.nodes_info: + # print(x) + if "Var" in x and ",0)" in x: + cnt += 1 + assert cnt == 2 + + def test_fuse_memopt(self): + def check(): + a = jt.rand(10) + b = (a.copy().name("copy_out1") + 1).sqr() + a.copy().name("copy_out2") + b.sync() + for n in jt.dump_all_graphs().nodes_info: + if "Var" not in n: continue + # print(n) + + if "copy_out1" in n: + # copy out1 is not free + assert ",0)" not in n + if "copy_out2" in n: + # copy out2 is freeed + assert ",0)" in n + da = jt.grad(b, a) + da.sync() + check() + jt.gc() + assert jt.liveness_info()['lived_vars'] == 0 + + def test_out_hint1(self): + a = jt.rand(10) + b = jt.rand(10) + c = jt.ternary_out_hint((a0.0).name("b"+str(i)), a, 0.0) + a = jt.matmul(a.name("m1"),jt.rand(10,10).name("m2")).name("m3-"+str(i)) + da = jt.grad(a, x, True) + # jt.clean_graph() + da.sync() + cnt1 = 0 + cnt2 = 0 + for n in jt.dump_all_graphs().nodes_info: + if "Var" in n and ",0)" not in n: + cnt1 +=1 + if "bool" in n: + cnt2 += 1 + print(cnt1, cnt2) + assert cnt2 == 10 + assert cnt1 <= 33, cnt1 + + def test_node_order(self): + a = jt.nn.Sequential() + for i in range(10): + a.append(jt.nn.Linear(10,10, bias=False)) + sgd = jt.optim.SGD(a.parameters(), 0.1) + jt.sync_all() + with jt.log_capture_scope(log_silent=1, + log_vprefix="exe=100") as logs: + x = jt.rand(3,10) + y = a(x) + sgd.step(y*y) + jt.sync_all() + orders = [] + for l in logs: + msg = l["msg"] + if "Finished" in msg: + # print(msg) + if "weight" in msg: + assert msg.count("Var") >= 2 + order = int(msg.split('fused ')[1].split("/")[0]) + # print(order) + orders.append(order) + assert len(orders) == 10, orders + for i in range(10): + assert orders[i] <= 14+i*3 + + def test_bc_bug(self): + a = jt.zeros((1,1)) + b = a * 0.5 + b.sync() + da = jt.grad(b, a) + da.sync() + + def test_attr_dict(self): + a = jt.array([1,2,3]) + a.hahaha = 1 + assert a.hahaha == 1 + + def test_swap(self): + # TODO: add skip + return + jt.gc() + jt.display_memory_info() + with jt.flag_scope(cpu_mem_limit=50*1024*1024): + np_arrays = [] + jt_arrays = [] + for i in range(100): + jt_arrays.append(jt.randn(1*1024*1024//4)) + np_arrays.append(jt_arrays[-1].numpy()) + for j in range(i): + is_swapped = jt_arrays[j].location() == "disk" + assert is_swapped == (j<=i-50), (i,j,jt_arrays[j].debug_msg(), is_swapped, j<=i-50, jt.display_memory_info()) + for i in range(len(np_arrays)): + np.testing.assert_allclose(jt_arrays[i].numpy(), np_arrays[i]) + for i in range(len(np_arrays)): + np.testing.assert_allclose(jt_arrays[i].numpy(), np_arrays[i]) + + def test_swap_cuda(self): + # TODO: add skip + return + jt.gc() + jt.display_memory_info() + if jt.has_cuda: + np_arrays = [] + jt_arrays = [] + jt.gc() + jt.display_memory_info() + cpu_mem_limit=50*1024*1024 + device_mem_limit=20*1024*1024 + with jt.flag_scope(use_cuda=1, + cpu_mem_limit=cpu_mem_limit, + device_mem_limit=device_mem_limit): + for i in range(100): + jt_arrays.append(jt.random([1*1024*1024//4])) + np_arrays.append(jt_arrays[-1].numpy()) + jt_arrays[-1].sync().name(str(i)) + meminfo = jt.get_mem_info() + assert meminfo.total_cpu_used < cpu_mem_limit + 2*1024*1024 + assert meminfo.total_cuda_used < device_mem_limit + 2*1024*1024 + for i in range(len(np_arrays)): + np.testing.assert_allclose(jt_arrays[i].numpy(), np_arrays[i]) + meminfo = jt.get_mem_info() + assert meminfo.total_cpu_used < cpu_mem_limit + 2*1024*1024 + assert meminfo.total_cuda_used < device_mem_limit + 2*1024*1024 + for i in range(len(np_arrays)): + np.testing.assert_allclose(jt_arrays[i].numpy(), np_arrays[i]) + meminfo = jt.get_mem_info() + assert meminfo.total_cpu_used < cpu_mem_limit + 2*1024*1024 + assert meminfo.total_cuda_used < device_mem_limit + 2*1024*1024 + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_cub_cumsum.py b/python/jittor/test/test_cub_cumsum.py new file mode 100644 index 00000000..482b1e45 --- /dev/null +++ b/python/jittor/test/test_cub_cumsum.py @@ -0,0 +1,106 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import compile_extern +if jt.has_cuda: + from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops +else: + cublas_ops = cudnn_ops = cub_ops = None + + +def test_forward(shape, dim=None): + x = jt.random(shape) + y = jt.numpy_cumsum(x, dim=dim) + y_ = jt.cub_cumsum(x, dim=dim) + assert(np.allclose(y.data, y_.data)) + +def test_backward(shape, dim=None): + x = jt.random(shape) + z = jt.random(shape) + + y = jt.numpy_cumsum(x, dim=dim) + loss = (y * z).sum() + grad = jt.grad(loss, x) + + y_ = jt.cub_cumsum(x, dim=dim) + loss_ = (y_ * z).sum() + grad_ = jt.grad(loss_, x) + assert(np.allclose(grad.data, grad_.data)) + +class TestCubCumsumOp(unittest.TestCase): + def setUp(self): + self.is_reversed = False + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_1d(self): + test_forward([20]) + test_forward([6007]) + test_forward([6007], 0) + test_forward([6007], -1) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_1d_backward(self): + test_backward([20]) + test_backward([6007]) + test_backward([6007], 0) + test_backward([6007], -1) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_2d(self): + test_forward([5,5]) + test_forward([2000, 6007]) + test_forward([2000, 6007], 1) + test_forward([2000, 6007], -1) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_2d_backward(self): + test_backward([5,5]) + test_backward([2000, 6007]) + test_backward([2000, 6007], 1) + test_backward([2000, 6007], -1) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_nd(self): + test_forward([5,6,7,8], 0) + test_forward([5,6,7,8], 1) + test_forward([5,6,7,8], 2) + test_forward([5,6,7,8], 3) + test_forward([5,6,7,8], -1) + test_forward([16,14,14,4096], 0) + test_forward([16,14,14,4096], 1) + test_forward([16,14,14,4096], 2) + test_forward([16,14,14,4096], 3) + test_forward([16,14,14,4096], -1) + test_forward([16,14,14,4095], 3) + + @unittest.skipIf(cub_ops==None, "Not use cub, Skip") + @jt.flag_scope(use_cuda=1) + def test_nd_backward(self): + test_backward([5,6,7,8], 0) + test_backward([5,6,7,8], 1) + test_backward([5,6,7,8], 2) + test_backward([5,6,7,8], 3) + test_backward([5,6,7,8], -1) + test_backward([16,14,14,4096], 0) + test_backward([16,14,14,4096], 1) + test_backward([16,14,14,4096], 2) + test_backward([16,14,14,4096], 3) + test_backward([16,14,14,4096], -1) + test_backward([16,14,14,4095], 3) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_cublas_test_op.py b/python/jittor/test/test_cublas_test_op.py new file mode 100644 index 00000000..01e34d06 --- /dev/null +++ b/python/jittor/test/test_cublas_test_op.py @@ -0,0 +1,38 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +from jittor import compile_extern +if jt.has_cuda: + from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops +else: + cublas_ops = cudnn_ops = cub_ops = None + +@unittest.skipIf(cublas_ops==None, "Not use cublas, Skip") +class TestCublasTestOp(unittest.TestCase): + def test(self): + assert cublas_ops.cublas_test(2).data==123 + assert cublas_ops.cublas_test(5).data==123 + assert cublas_ops.cublas_test(10).data==123 + assert cublas_ops.cublas_test(20).data==123 + +@unittest.skipIf(cudnn_ops==None, "Not use cudnn, Skip") +class TestCudnnTestOp(unittest.TestCase): + def test(self): + assert cudnn_ops.cudnn_test("").data == 123 + assert cudnn_ops.cudnn_test("-c2048 -h7 -w7 -k512 -r1 -s1 -pad_h0 -pad_w0 -u1 -v1").data == 123 + +@unittest.skipIf(cub_ops==None, "Not use cub, Skip") +class TestCubTestOp(unittest.TestCase): + @jt.flag_scope(use_cuda=1) + def test(self): + assert cub_ops.cub_test("xx").data == 123 + assert cub_ops.cub_test("xx --n=100000").data == 123 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_cuda.py b/python/jittor/test/test_cuda.py new file mode 100644 index 00000000..56397bd8 --- /dev/null +++ b/python/jittor/test/test_cuda.py @@ -0,0 +1,117 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from .test_core import expect_error + +def test_cuda(use_cuda=1): + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + class TestCudaBase(unittest.TestCase): + def setUp(self): + jt.flags.use_cuda = use_cuda + def tearDown(self): + jt.flags.use_cuda = 0 + return TestCudaBase + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestCuda(unittest.TestCase): + @jt.flag_scope(use_cuda=1) + def test_cuda_flags(self): + a = jt.random((10, 10)) + a.sync() + + @jt.flag_scope(use_cuda=2) + def test_no_cuda_op(self): + no_cuda_op = jt.compile_custom_op(""" + struct NoCudaOp : Op { + Var* output; + NoCudaOp(NanoVector shape, string dtype="float"); + + const char* name() const override { return "my_cuda"; } + DECLARE_jit_run; + }; + """, """ + #ifndef JIT + NoCudaOp::NoCudaOp(NanoVector shape, string dtype) { + flags.set(NodeFlags::_cpu); + output = create_output(shape, dtype); + } + + void NoCudaOp::jit_prepare(JK& jk) { + add_jit_define(jk, "T", output->dtype()); + } + + #else // JIT + void NoCudaOp::jit_run() {} + #endif // JIT + """, + "no_cuda") + # force use cuda + a = no_cuda_op([3,4,5], 'float') + expect_error(lambda: a()) + + @jt.flag_scope(use_cuda=1) + def test_cuda_custom_op(self): + my_op = jt.compile_custom_op(""" + struct MyCudaOp : Op { + Var* output; + MyCudaOp(NanoVector shape, string dtype="float"); + + const char* name() const override { return "my_cuda"; } + DECLARE_jit_run; + }; + """, """ + #ifndef JIT + MyCudaOp::MyCudaOp(NanoVector shape, string dtype) { + flags.set(NodeFlags::_cuda); + output = create_output(shape, dtype); + } + + void MyCudaOp::jit_prepare(JK& jk) { + add_jit_define(jk, "T", output->dtype()); + } + + #else // JIT + #ifdef JIT_cuda + + __global__ void kernel(index_t n, T *x) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < n; i += stride) + x[i] = (T)-i; + } + + void MyCudaOp::jit_run() { + index_t num = output->num; + auto* __restrict__ x = output->ptr(); + int blockSize = 256; + int numBlocks = (num + blockSize - 1) / blockSize; + kernel<<>>(num, x); + } + #endif // JIT_cuda + #endif // JIT + """, + "my_cuda") + a = my_op([3,4,5], 'float') + na = a.data + assert a.shape == [3,4,5] and a.dtype == 'float' + assert (-na.flatten() == range(3*4*5)).all(), na + + def test_cuda_fused_op(self): + a = jt.array([1,2,3]) + a.sync() + with jt.flag_scope(use_cuda=1): + ((a+a)*2).data + + +@unittest.skipIf(jt.compiler.has_cuda, "Only test without CUDA") +class TestNoCuda(unittest.TestCase): + def test_cuda_flags(self): + expect_error(lambda: setattr(jt.flags, "use_cuda",1)) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_cudnn_op.py b/python/jittor/test/test_cudnn_op.py new file mode 100644 index 00000000..853833f1 --- /dev/null +++ b/python/jittor/test/test_cudnn_op.py @@ -0,0 +1,194 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import numpy as np +from jittor import compile_extern + +from jittor.test.test_log import find_log_with_re +if jt.has_cuda: + from jittor.compile_extern import cublas_ops, cudnn_ops +else: + cublas_ops = cudnn_ops = None + +def conv_oihw(x, w, stride=1, padding=0, dilation=1): + assert type(stride)==int and type(padding)==int + N,H,W,C = x.shape + # Kh,Kw,C2,c = w.shape + c,C2,Kh,Kw = w.shape + oh, ow = (H-Kh*dilation+dilation-1+padding*2)//stride+1, (W-Kw*dilation+dilation-1+padding*2)//stride+1 + assert C2==C or C2==1, (C2, C) + x = x.reindex([N,oh,ow,c,C2,Kh,Kw], [ + 'i0', # Nid = Nid + f'i1*{stride}+i5*{dilation}-{padding}', # Hid = ohid*stride+Khid + f'i2*{stride}+i6*{dilation}-{padding}', # Wid = owid*stride+Kwid + 'i3' if C2==1 and C>1 else 'i4', # depthwise or normal + ]) + y = (x*w).sum([4,5,6]) # Kh, Kw, C + return y + +def conv(x, w, stride, padding): + out_planes, in_planes, kernel_size, _ = w.shape + Kw = kernel_size + Kh = kernel_size + _C = in_planes + Kc = out_planes + N,C,H,W = x.shape + assert C==_C + xx = x.reindex([N,Kc,C,(H+padding*2-kernel_size)//stride+1,(W+padding*2-kernel_size)//stride+1,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride}-{padding}+i5', # Hid+Khid + f'i4*{stride}-{padding}+i6', # Wid+KWid + ]) + ww = w.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw + return y + +@unittest.skipIf(cudnn_ops==None, "Not use cudnn, Skip") +class TestCudnnConvOp(unittest.TestCase): + def test(self): + def check(xshape, wshape, stride=1, padding=0, dilation=1): + with jt.log_capture_scope(use_cuda=1, enable_tuner=1, + log_v=0, log_vprefix="op.cc=100" + ) as raw_log: + x = jt.random(xshape) + w = jt.random(wshape) + y = conv_oihw(x, w, stride, padding, dilation) + y.sync() + with jt.flag_scope(use_cuda=0, enable_tuner=1): + cy = conv_oihw(x, w, stride, padding, dilation) + cy.sync() + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)") + assert len(logs)==1 and "oihw" in logs[0][0], logs + assert np.allclose(y.data, cy.data), np.abs(y.data-cy.data).max() + check([10,100,100,3], [5,3,3,3], stride=2, padding=0, dilation=1) + check([10,40,50,4], [5,4,5,5], stride=1, padding=1, dilation=1) + check([10,40,50,4], [5,4,4,4], stride=3, padding=1, dilation=1) + + def test_backward_nhwc(self): + # TODO: cudnn backward do not support nhwc + return + def check(xshape, wshape, stride=1, padding=0, dilation=1): + with jt.log_capture_scope(use_cuda=1, enable_tuner=1, + log_v=0, log_vprefix="op.cc=100" + ) as raw_log: + x = jt.random(xshape) + w = jt.random(wshape) + y = conv_oihw(x, w, stride, padding, dilation) + mask = jt.random(y.shape) + loss = mask * y + dx, dw = jt.grad(loss, [x, w]) + jt.sync([y, loss, dx, dw]) + + with jt.flag_scope(use_cuda=0, enable_tuner=0): + cy = conv_oihw(x, w, stride, padding, dilation) + closs = mask * cy + cdx, cdw = jt.grad(closs, [x, w]) + jt.sync([cy, closs, cdx, cdw]) + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)") + assert len(logs)==3 and "oihw" in logs[0][0], logs + assert np.allclose(y.data, cy.data) + assert np.allclose(dx.data, cdx.data) + assert np.allclose(dw.data, cdw.data) + check([10,100,100,3], [5,3,3,3], stride=2, padding=0, dilation=1) + check([10,40,50,4], [5,4,5,5], stride=1, padding=1, dilation=1) + check([10,40,50,4], [5,4,4,4], stride=3, padding=1, dilation=1) + + def test_backward(self): + def check(xshape, wshape, stride=1, padding=0, dilation=1): + with jt.log_capture_scope(use_cuda=1, enable_tuner=1, + log_v=1, log_vprefix="op.cc=100,exe=1000" + ) as raw_log: + x = jt.random(xshape) + w = jt.random(wshape) + y = conv(x, w, stride, padding) + mask = jt.random(y.shape) + loss = mask * y + dx, dw = jt.grad(loss, [x, w]) + jt.sync([y, loss, dx, dw]) + + # fails when enable_tuner=1, something wrong with mkl_conv_backward_x maybe. + with jt.flag_scope(use_cuda=0, enable_tuner=0): + cy = conv(x, w, stride, padding) + closs = mask * cy + cdx, cdw = jt.grad(closs, [x, w]) + jt.sync([cy, closs, cdx, cdw]) + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: cudnn_conv.*)") + assert len(logs)==3 and "oihw" in logs[0][0], logs + assert np.allclose(y.data, cy.data) + np.testing.assert_allclose(dx.data, cdx.data, atol=1e-2, rtol=1e-3) + np.testing.assert_allclose(dw.data, cdw.data, atol=1e-2, rtol=1e-3) + if os.name == 'nt': return + check([10,3,100,100], [5,3,3,3], stride=2, padding=0, dilation=1) + check([10,4,40,50], [5,4,5,5], stride=1, padding=1, dilation=1) + check([10,4,40,50], [5,4,4,4], stride=3, padding=1, dilation=1) + + def test_conv3d(self): + def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1): + with jt.flag_scope(use_cuda=1): + x = jt.random(xshape) + w = jt.random(wshape) + # y = jt.cudnn.ops.cudnn_conv3d(x, w, *stride, *padding, *dilation, group) + y = jt.nn.conv3d(x, w, None, stride, padding, dilation, group) + masky = jt.rand_like(y) + dx, dw = jt.grad(masky*y, [x, w]) + jt.sync_all() + + y2 = jt.nn.conv3d(x, w, None, stride, padding, dilation, group) + dx2, dw2 = jt.grad(masky*y2, [x, w]) + np.testing.assert_allclose(y.data, y2.data, rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(dx.data, dx2.data, rtol=1e-3, atol=1e-3) + np.testing.assert_allclose(dw.data, dw2.data, rtol=1e-3, atol=1e-3) + + check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1)) + check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1)) + check((2,4,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0)) + # TODO: check why windows failed in this test + if os.name == "nt": return + check((2,4,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0)) + check((2,4,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1)) + check((2,4,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0)) + check((2,4,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3)) + + def test_conv_transpose3d(self): + jt.set_global_seed(10) + def check(xshape, wshape, stride=(1,1,1), padding=(0,0,0), dilation=(1,1,1), group=1): + with jt.flag_scope(use_cuda=1): + x = jt.random(xshape) + w = jt.random(wshape) + jt.sync_all() + + y2 = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation) + jt.sync_all() + + with jt.flag_scope(use_cuda=1): + # y = jt.cudnn.ops.cudnn_conv3d_backward_x(w, x, *y2.shape[2:], *stride, *padding, *dilation, group) + y = jt.nn.conv_transpose3d(x, w, None, stride, padding, 0, group, dilation) + masky = jt.rand_like(y) + dx, dw = jt.grad(masky*y, [x, w]) + jt.sync_all() + + dx2, dw2 = jt.grad(masky*y2, [x, w]) + jt.sync_all() + np.testing.assert_allclose(y.numpy(), y2.numpy(), rtol=1e-3, atol=1e-4) + np.testing.assert_allclose(dx.numpy(), dx2.numpy(), rtol=1e-3, atol=1e-4) + np.testing.assert_allclose(dw.numpy(), dw2.numpy(), rtol=1e-3, atol=1e-3) + + check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1)) + check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (1,1,1)) + check((2,5,10,10,10), (5,4,3,3,3), (2,2,2), (0,0,0)) + if os.name == 'nt': return + check((2,5,10,10,10), (5,4,3,3,3), (1,2,3), (0,0,0)) + check((2,5,10,10,10), (5,4,3,4,5), (1,1,1), (1,1,1)) + check((2,5,10,10,10), (5,4,3,4,5), (1,2,3), (0,0,0)) + check((2,5,10,10,10), (5,4,3,3,3), (1,1,1), (1,1,1), dilation=(1,2,3)) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_cumprod_op.py b/python/jittor/test/test_cumprod_op.py new file mode 100644 index 00000000..04ed896b --- /dev/null +++ b/python/jittor/test/test_cumprod_op.py @@ -0,0 +1,46 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch.autograd import Variable +except: + torch = None + skip_this_test = True + + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestCumprod(unittest.TestCase): + def test_cumprod_cpu(self): + for i in range(1,6): + for j in range(i): + print("test", i, j) + x = np.random.rand(*((10,)*i)) + x_jt = jt.array(x) + y_jt = jt.cumprod(x_jt, j).sqr() + g_jt = jt.grad(y_jt.sum(), x_jt) + x_tc = Variable(torch.from_numpy(x), requires_grad=True) + y_tc = torch.cumprod(x_tc, j)**2 + y_tc.sum().backward() + g_tc = x_tc.grad + assert np.allclose(y_jt.numpy(), y_tc.data) + np.testing.assert_allclose(g_jt.numpy(), g_tc.data, atol=1e-5) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cumprod_gpu(self): + self.test_cumprod_cpu() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_custom_op.py b/python/jittor/test/test_custom_op.py new file mode 100644 index 00000000..26ed1580 --- /dev/null +++ b/python/jittor/test/test_custom_op.py @@ -0,0 +1,110 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os +import jittor as jt +from .test_core import expect_error + +header =""" +#pragma once +#include "op.h" + +namespace jittor { + +struct CustomOp : Op { + Var* output; + CustomOp(NanoVector shape, NanoString dtype=ns_float32); + + const char* name() const override { return "custom"; } + DECLARE_jit_run; +}; + +} // jittor +""" + +src = """ +#include "var.h" +#include "custom_op.h" + +namespace jittor { +#ifndef JIT +CustomOp::CustomOp(NanoVector shape, NanoString dtype) { + output = create_output(shape, dtype); +} + +void CustomOp::jit_prepare(JK& jk) { + add_jit_define(jk, "T", output->dtype()); +} + +#else // JIT +#ifdef JIT_cpu +void CustomOp::jit_run() { + index_t num = output->num; + auto* __restrict__ x = output->ptr(); + for (index_t i=0; idtype()); + } + + #else // JIT + void MyOp::jit_run() { + index_t num = output->num; + auto* __restrict__ x = output->ptr(); + for (index_t i=0; i +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import compile_extern +from .test_log import find_log_with_re +import copy +if jt.has_cuda: + from jittor.compile_extern import cutt_ops +else: + cutt_ops = None + +class TestCutt(unittest.TestCase): + @unittest.skipIf(cutt_ops==None, "Not use cutt, Skip") + @jt.flag_scope(use_cuda=1) + def test(self): + t = cutt_ops.cutt_test("213") + assert t.data == 123 +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_cutt_transpose_op.py b/python/jittor/test/test_cutt_transpose_op.py new file mode 100644 index 00000000..8558176d --- /dev/null +++ b/python/jittor/test/test_cutt_transpose_op.py @@ -0,0 +1,103 @@ +# *************************************************************** +# Copyright (c) 2019 Dun Liang . All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from .test_grad import ngrad +from itertools import permutations +from jittor import compile_extern +from .test_log import find_log_with_re +if jt.has_cuda: + from jittor.compile_extern import cutt_ops +else: + cutt_ops = None + +def gen_data(shape): + num = np.multiply.reduce(shape) + a = np.arange(0, num) + return a.reshape(shape) + +class TestCuttTransposeOp(unittest.TestCase): + @unittest.skipIf(cutt_ops==None, "Not use cutt, Skip") + @jt.flag_scope(use_cuda=1) + def test_with_np(self): + def check(a): + perms = list(permutations(range(a.ndim))) + [None] + for perm in perms: + with jt.log_capture_scope( + log_silent=1, + log_v=0, log_vprefix="cutt=100" + ) as raw_log: + if perm: + x = np.transpose(a, perm) + y = jt.transpose(a, perm).data + else: + x = np.transpose(a) + y = jt.transpose(a).data + self.assertEqual(x.shape, y.shape) + logs = find_log_with_re(raw_log, "(Run cutt_transpose with key.*)") + if perm is None: + continue + last = -1 + in_order = True + for i in range(len(perm)): + if a.shape[perm[i]] == 1: + continue + if last != -1 and last > perm[i]: + in_order = False + break + last = perm[i] + # if not in_order: + # assert len(logs)==1 + assert (x==y).all(), f"\n{x}\n{y}\n{perm}\n{a.shape}" + + ia = [gen_data([5, 7]), gen_data([2,2,2]), gen_data([2,3,4,5]), gen_data([5,3]), gen_data([3,1,5,3,1])] + for a in ia: check(a) + + @unittest.skipIf(cutt_ops==None, "Not use cutt, Skip") + @jt.flag_scope(use_cuda=1) + def test_grad(self): + def check(a): + perms = list(permutations(range(a.ndim))) + [None] + for perm in perms: + x = jt.array(a).float() + if perm: + y = jt.transpose(x, perm) + else: + y = jt.transpose(x) + dx = jt.grad(y*y, x).data + self.assertEqual(dx.shape, a.shape) + assert (dx==a*2).all(), f"\n{dx}\n{a}\n{perm}" + ia = [gen_data([2,2,2]), gen_data([2,3,4,5]), gen_data([5,3]), gen_data([3,1,5,3,1])] + for a in ia: check(a) + + @unittest.skipIf(cutt_ops==None, "Not use cutt, Skip") + @jt.flag_scope(use_cuda=1) + def test_matmul_grad(self): + np.random.seed(0) + for i in range(10): + a = np.random.rand(2,3).astype("float32") + b = np.random.rand(3,4).astype("float32") + out, (da, db) = ngrad(lambda vars: np.matmul(vars[0],vars[1]).sum(), [a,b], 1e-1) + ja = jt.array(a) + jb = jt.array(b) + jc = ja.matmul(jb) + jda, jdb = jt.grad(jc, [ja,jb]) + assert ((da-jda.data)<1e-5).all(), (da, jda.data, da-jda.data) + assert ((db-jdb.data)<1e-5).all(), (db-jdb.data) + + @unittest.skipIf(cutt_ops==None, "Not use cutt, Skip") + @jt.flag_scope(use_cuda=1) + def test_matmul_grad(self): + a = jt.zeros((0, 10)) + b = a.transpose(1, 0) + c = b.data + assert c.shape[0] == 10 + assert c.shape[1] == 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_dataset.py b/python/jittor/test/test_dataset.py new file mode 100644 index 00000000..a9029062 --- /dev/null +++ b/python/jittor/test/test_dataset.py @@ -0,0 +1,348 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor.dataset.dataset import ImageFolder, Dataset +import jittor.transform as transform + +import jittor as jt +import unittest +import os +import numpy as np +import random + +pass_this_test = False +msg = "" +mid = 0 +if hasattr(os, "uname") and os.uname()[1] == "jittor-ce": + mid = 1 +try: + traindir = ["/data1/cjld/imagenet/train/","/home/cjld/imagenet/train/"][mid] + assert os.path.isdir(traindir) +except Exception as e: + pass_this_test = True + msg = str(e) + +@unittest.skipIf(pass_this_test, f"can not run imagenet dataset test: {msg}") +class TestDataset(unittest.TestCase): + def test_multi_workers(self): + check_num_batch = 10 + tc_data = [] + + def get_dataset(): + dataset = ImageFolder(traindir).set_attrs(batch_size=256, shuffle=False) + dataset.set_attrs(transform = transform.Compose([ + transform.Resize(224), + transform.ImageNormalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]), num_workers=0) + return dataset + + dataset = get_dataset() + + for i, data in enumerate(dataset): + print("get batch", i) + tc_data.append(data) + if i==check_num_batch: break + + def check(num_workers, epoch=1): + dataset = get_dataset().set_attrs(num_workers=num_workers) + + random.seed(0) + + for _ in range(epoch): + for i, (images, labels) in enumerate(dataset): + print("compare", i) + assert np.allclose(images.data, tc_data[i][0].data), \ + (images.sum(), tc_data[i][0].sum(), images.shape, + tc_data[i][0].shape) + assert np.allclose(labels.data, tc_data[i][1].data) + if i==check_num_batch: break + # dataset.terminate() + check(1) + check(2) + check(4,2) + + def test_collate_batch(self): + from jittor.dataset.utils import collate_batch + batch = collate_batch([(1,1),(1,2),(1,3)]) + assert isinstance(batch[0], np.ndarray) + assert isinstance(batch[1], np.ndarray) + + +class YourDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=10240) + + def __getitem__(self, k): + self.tmp = None + x = jt.array(k) + y = x + for i in range(10): + for j in range(i+2): + y = y + j - j + y.stop_fuse() + return x, y + + +class YourDataset2(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=16) + + def __getitem__(self, k): + return np.random.rand(2) + + +class YourDataset3(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=16) + + def __getitem__(self, k): + return random.randint(0,1000) + + +class YourDataset4(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=160) + + def __getitem__(self, k): + return jt.rand(2) + + +class YourDataset5(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=160) + + def __getitem__(self, k): + return { "a":np.array([1,2,3]) } + +class TestDataset2(unittest.TestCase): + def test_dataset_use_jittor(self): + dataset = YourDataset().set_attrs(batch_size=256, shuffle=True, num_workers=4) + dataset.tmp = jt.array([1,2,3,4,5]) + dataset.tmp.sync() + for x, y in dataset: + # dataset.display_worker_status() + pass + + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_dataset_use_jittor_cuda(self): + self.test_dataset_use_jittor() + +class TestDatasetSeed(unittest.TestCase): + def test_np(self): + + dataset = YourDataset2().set_attrs(batch_size=1, shuffle=True, num_workers=4) + for _ in range(10): + dd = [] + for d in dataset: + dd.append(d.numpy()) + for i in range(len(d)): + for j in range(i+1, len(d)): + assert not np.allclose(dd[i], dd[j]) + + def test_py_native(self): + import random + + jt.set_global_seed(0) + dataset = YourDataset3().set_attrs(batch_size=1, shuffle=True, num_workers=4) + for _ in range(10): + dd = [] + for d in dataset: + dd.append(d.numpy()) + for i in range(len(d)): + for j in range(i+1, len(d)): + assert not np.allclose(dd[i], dd[j]) + + def test_jtrand(self): + import random + + jt.set_global_seed(0) + dataset = YourDataset4().set_attrs(batch_size=1, shuffle=True, num_workers=4) + for _ in range(10): + dd = [] + for d in dataset: + dd.append(d.numpy()) + for i in range(len(d)): + for j in range(i+1, len(d)): + assert not np.allclose(dd[i], dd[j]) + + def test_dict(self): + import random + + jt.set_global_seed(0) + dataset = YourDataset5().set_attrs(batch_size=1, shuffle=True, num_workers=4) + for _ in range(10): + dd = [] + for d in dataset: + # breakpoint() + assert isinstance(d, dict) + assert isinstance(d['a'], jt.Var) + np.testing.assert_allclose(d['a'].numpy(), [[1,2,3]]) + + def test_cifar(self): + from jittor.dataset.cifar import CIFAR10 + a = CIFAR10() + a.set_attrs(batch_size=16) + for imgs, labels in a: + print(imgs.shape, labels.shape) + assert imgs.shape == [16,32,32,3,] + assert labels.shape == [16,] + break + + def test_tensor_dataset(self): + import jittor as jt + from jittor.dataset import TensorDataset + + x = jt.array([1,2,3]) + y = jt.array([4,5,6]) + z = jt.array([7,8,9]) + + dataset = TensorDataset(x, y, z) + # dataset.set_attrs(batch_size=2) + dataset.set_attrs(batch_size=1) + + for i,(a,b,c) in enumerate(dataset): + # print(a,b,c) + # print(a.shape) + assert a.shape == [1] + assert x[i] == a + assert y[i] == b + assert z[i] == c + + def test_children_died(self): + if os.name == 'nt': + # TODO: windows cannot pass this test now + # don't know how to detect child died in windows + # some clue: https://ikriv.com/blog/?p=1431 + return + src = """ +import jittor as jt +from jittor.dataset import Dataset +import numpy as np + +class YourDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=160) + + def __getitem__(self, k): + if k>100: + while 1: + pass + return { "a":np.array([1,2,3]) } +if __name__ == "__main__": + dataset = YourDataset() + dataset.set_attrs(num_workers=2) + + for d in dataset: + dataset.workers[0].p.kill() + pass +""" + fname = os.path.join(jt.flags.cache_path, "children_dead_test.py") + with open(fname, 'w') as f: + f.write(src) + import subprocess as sp + import sys + cmd = sys.executable + " " + fname + print(cmd) + r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE) + s = r.stderr.decode() + print(s) + assert r.returncode != 0 + assert "SIGCHLD" in s + assert "quick exit" in s + + + @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") + def test_dataset_shuffle_mpi(self): + src = """ +import jittor as jt +from jittor.dataset import Dataset +import numpy as np + +class YourDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=160, shuffle=True) + + def __getitem__(self, k): + return k + +dataset = YourDataset() +dataset.set_attrs(num_workers=2) + +for d in dataset: + for a in d: + print("CHECK: ", a.item()) +""" + fname = os.path.join(jt.flags.cache_path, "test_dataset_shuffle_mpi.py") + with open(fname, 'w') as f: + f.write(src) + import subprocess as sp + import sys + cmd = sys.executable + " " + fname + mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun") + cmd = mpirun_path + " -np 2 " + cmd + print(cmd) + r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE) + s = r.stdout.decode() + # print(s) + st = set([ l for l in s.splitlines() if l.startswith("CHECK:") ]) + assert r.returncode == 0 + # print(st) + assert len(st) == 160, len(st) + + def test_children_died2(self): + src = """ +import jittor as jt +from jittor.dataset import Dataset +import numpy as np + +class YourDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=160) + + def __getitem__(self, k): + if k>100: + while 1: + pass + return { "a":np.array([1,2,3]) } + +if __name__ == "__main__": + dataset = YourDataset() + dataset.set_attrs(num_workers=2) + + for d in dataset: + break + dataset.terminate() +""" + fname = os.path.join(jt.flags.cache_path, "children_dead_test.py") + with open(fname, 'w') as f: + f.write(src) + import subprocess as sp + import sys + cmd = sys.executable + " " + fname + print(cmd) + r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE) + s = r.stderr.decode() + print(s) + assert r.returncode == 0 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_default_var.py b/python/jittor/test/test_default_var.py new file mode 100644 index 00000000..d6db7e63 --- /dev/null +++ b/python/jittor/test/test_default_var.py @@ -0,0 +1,49 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import sys +import os +import jittor as jt +import unittest +import time +import numpy as np +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +class TestDefaultVar(unittest.TestCase): + @classmethod + def setUpClass(self): + return + + @jt.flag_scope(auto_convert_64_to_32=0) + def test_default_var(self): + a=jt.array((2,3,3), np.float32) + b=a*2.0 + assert str(b.dtype) == "float32" + b=a*2 + assert str(b.dtype) == "float32" + a=jt.array((2,3,3), np.int32) + b=a*2.0 + assert str(b.dtype) == "float32" + b=a*2 + assert str(b.dtype) == "int32" + + a=jt.array((2,3,3), np.float64) + b=a*2.0 + assert str(b.dtype) == "float64" + b=a*2 + assert str(b.dtype) == "float64" + a=jt.array((2,3,3), np.int64) + b=a*2.0 + assert str(b.dtype) == "float64" + b=a*2 + assert str(b.dtype) == "int64" + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_densenet.py b/python/jittor/test/test_densenet.py new file mode 100644 index 00000000..2d32bc3a --- /dev/null +++ b/python/jittor/test/test_densenet.py @@ -0,0 +1,103 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn, Module +from jittor.models import densenet +import numpy as np +import sys, os +import random +import math +import unittest +from jittor.test.test_reorder_tuner import simple_parser +from jittor.test.test_log import find_log_with_re +from jittor.dataset.mnist import MNIST +import jittor.transform as trans +import time + +skip_this_test = True + +class MnistNet(Module): + def __init__(self): + self.model = densenet.densenet169() + self.layer = nn.Linear(1000,10) + def execute(self, x): + x = self.model(x) + x = self.layer(x) + return x + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestDensenet(unittest.TestCase): + @classmethod + def setUpClass(self): + # hyper-parameters + self.batch_size = 100 + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.1 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + self.train_loader.num_workers = 4 + + # setup random seed + def setup_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + jt.seed(seed) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1, use_stat_allocator=1) + def test_densenet(self): + self.setup_seed(1) + loss_list=[] + acc_list=[] + mnist_net = MnistNet() + global prev + prev = time.time() + SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) + # SGD = jt.optim.Adam(mnist_net.parameters(), lr=0.0001) + + for batch_idx, (data, target) in enumerate(self.train_loader): + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + SGD.step(loss) + def callback(batch_idx, loss, output, target): + # print train info + global prev + pred = np.argmax(output, axis=1) + acc = np.mean(target==pred) + loss_list.append(loss[0]) + acc_list.append(acc) + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' + .format(0, batch_idx, 600,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev)) + # prev = time.time() + jt.fetch(batch_idx, loss, output, target, callback) + # Train Epoch: 0 [0/600 (0%)] Loss: 2.402650 Acc: 0.060000 + # Train Epoch: 0 [1/600 (0%)] Loss: 2.770145 Acc: 0.100000 + # Train Epoch: 0 [2/600 (0%)] Loss: 3.528072 Acc: 0.100000 + # Train Epoch: 0 [3/600 (0%)] Loss: 2.992042 Acc: 0.100000 + # Train Epoch: 0 [4/600 (1%)] Loss: 4.672772 Acc: 0.060000 + # Train Epoch: 0 [5/600 (1%)] Loss: 5.003410 Acc: 0.080000 + # Train Epoch: 0 [6/600 (1%)] Loss: 5.417546 Acc: 0.100000 + # Train Epoch: 0 [7/600 (1%)] Loss: 5.137665 Acc: 0.100000 + # Train Epoch: 0 [8/600 (1%)] Loss: 5.241075 Acc: 0.070000 + # Train Epoch: 0 [9/600 (2%)] Loss: 4.515363 Acc: 0.100000 + # Train Epoch: 0 [10/600 (2%)] Loss: 3.357187 Acc: 0.170000 + # Train Epoch: 0 [20/600 (3%)] Loss: 2.265879 Acc: 0.100000 + # Train Epoch: 0 [30/600 (5%)] Loss: 2.107000 Acc: 0.250000 + # Train Epoch: 0 [40/600 (7%)] Loss: 1.918214 Acc: 0.290000 + # Train Epoch: 0 [50/600 (8%)] Loss: 1.645694 Acc: 0.400000 + + jt.sync_all(True) + assert np.mean(loss_list[-50:])<0.3 + assert np.mean(acc_list[-50:])>0.9 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_depthwise_conv.py b/python/jittor/test/test_depthwise_conv.py new file mode 100644 index 00000000..59a6dcef --- /dev/null +++ b/python/jittor/test/test_depthwise_conv.py @@ -0,0 +1,89 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.models as jtmodels + +def load_parameters(m1, m2): + m1.save('/tmp/temp.pk') + m2.load('/tmp/temp.pk') + +def compare_parameters(m1, m2): + ps1 = m1.parameters() + ps2 = m2.parameters() + for i in range(len(ps1)): + x = ps1[i].data + 1e-8 + y = ps2[i].data + 1e-8 + relative_error = abs(x - y) / abs(y) + diff = relative_error.mean() + assert diff < 1e-4, (diff, 'backward', ps2[i].name(), ps1[i].mean(), ps1[i].std(), ps2[i].mean(), ps2[i].std()) + +class TestDepthwiseConv(unittest.TestCase): + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_data(self): + test_img = np.random.random((64,3,224,224)).astype('float32') + jittor_test_img = jt.array(test_img) + lr = 100 + + jittor_model = jtmodels.__dict__['mobilenet_v2']() + jittor_model2 = jtmodels.__dict__['mobilenet_v2']() + # Set eval to avoid dropout layer & bn errors + jittor_model.train() + jittor_model.classifier[0].eval() + for m in jittor_model.modules(): + if isinstance(m, jt.nn.BatchNorm): + m.eval() + + jittor_model2.train() + jittor_model2.classifier[0].eval() + for m in jittor_model2.modules(): + if isinstance(m, jt.nn.BatchNorm): + m.eval() + + load_parameters(jittor_model2, jittor_model) + for m in jittor_model.modules(): + if isinstance(m, jt.nn.Conv): + m.is_depthwise_conv = False + cnt = 0 + for m in jittor_model2.modules(): + if isinstance(m, jt.nn.Conv): + if (m.is_depthwise_conv): + cnt += 1 + assert cnt == 17, (cnt, '!=', 17) + jt_optimizer = jt.nn.SGD(jittor_model.parameters(), lr = lr) + jt_optimizer2 = jt.nn.SGD(jittor_model2.parameters(), lr = lr) + + jittor_result = jittor_model(jittor_test_img) + mask = jt.random(jittor_result.shape, jittor_result.dtype) + loss = jittor_result * mask + jt_optimizer.step(loss) + jt.sync_all(True) + + jittor_result2 = jittor_model2(jittor_test_img) + loss = jittor_result2 * mask + + x = jittor_result2.data + 1e-8 + y = jittor_result.data + 1e-8 + relative_error = abs(x - y) / abs(y) + diff = relative_error.mean() + assert diff < 1e-4, (diff, 'forword') + + jt_optimizer2.step(loss) + jt.sync_all(True) + compare_parameters(jittor_model, jittor_model2) + + + jt.clean() + jt.gc() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_digamma.py b/python/jittor/test/test_digamma.py new file mode 100644 index 00000000..308c5f8d --- /dev/null +++ b/python/jittor/test/test_digamma.py @@ -0,0 +1,43 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np +import unittest + +try: + import torch + from torch.autograd import Variable + has_autograd = True +except: + has_autograd = False + +@unittest.skipIf(not has_autograd, "No autograd found.") +class TestDigamma(unittest.TestCase): + def test_digamma(self): + for i in range(30): + nx = np.random.uniform(0, 1, (32, 32)) + x = jt.array(nx) + tx = torch.autograd.Variable(torch.tensor(nx, dtype=torch.float32), requires_grad=True) + dx = jt.digamma.apply(x) + tdx = torch.digamma(tx) + np.testing.assert_allclose(dx.data, tdx.detach().numpy(), rtol=1e-4, atol=1e-6) + jgdx = jt.grad(dx, x) + tgdx = torch.autograd.grad(tdx, tx, torch.ones_like(tx))[0] + np.testing.assert_allclose(jgdx.data, tgdx.detach().numpy(), rtol=1e-4, atol=1e-6) + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestCudaDigamma(TestDigamma): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_distributions.py b/python/jittor/test/test_distributions.py new file mode 100644 index 00000000..0143c5a7 --- /dev/null +++ b/python/jittor/test/test_distributions.py @@ -0,0 +1,155 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.distributions as jd + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + torch = None + skip_this_test = True + + +class TestOneHot(unittest.TestCase): + def test_presum(self): + a = jt.array([[1,2,3,4]]) + b = jd.simple_presum(a) + assert (b.data == [[0,1,3,6,10]]).all() + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_one_hot(self): + a = jd.OneHotCategorical(jt.array([0.25, 0.25, 0.25, 0.25])) + x = a.sample().numpy() + for i in range(1000): + x += a.sample().numpy() + assert (x > 200).all() + y = a.sample([2,3]) + y.sync() + assert y.shape == [2,3,4] + probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) + probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() + + jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2)) + tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2)) + assert np.allclose(jc.entropy().data,tc.entropy().numpy()) + x = np.zeros((4,10)) + for _ in range(4): + nx = np.random.randint(0,9) + x[_,nx] = 1 + np.testing.assert_allclose(jc.log_prob(jt.array(x)),tc.log_prob(torch.tensor(x)), atol=1e-5) + assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) + + def test_cate(self): + a = jd.Categorical(jt.array([0.25, 0.25, 0.25, 0.25])) + x =np.array([0,0,0,0]) + for i in range(1000): + x[a.sample().item()]+=1 + assert (x > 200).all() + y = a.sample([2,3]) + y.sync() + assert y.shape == [2,3] + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_normal(self): + for _ in range(4): + mu = np.random.uniform(-1,1) + sigma = np.random.uniform(0,2) + jn = jd.Normal(mu,sigma) + tn = torch.distributions.Normal(mu,sigma) + assert np.allclose(jn.entropy().data,tn.entropy().numpy()) + x = np.random.uniform(-1,1) + np.testing.assert_allclose(jn.log_prob(x),tn.log_prob(torch.tensor(x))) + mu2 = np.random.uniform(-1,1) + sigma2 = np.random.uniform(0,2) + jn2 = jd.Normal(mu2,sigma2) + tn2 = torch.distributions.Normal(mu2,sigma2) + assert np.allclose(jd.kl_divergence(jn,jn2).data,torch.distributions.kl_divergence(tn,tn2).numpy()) + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_categorical1(self): + for _ in range(4): + probs,probs2 = np.random.uniform(0,1,(10)), np.random.uniform(0,1,(10)) + probs,probs2 = probs / probs.sum(),probs2 / probs2.sum() + jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2)) + tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2)) + assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy()) + x = np.random.randint(0,10,(4)) + np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5) + assert np.allclose(jd.kl_divergence(jc,jc2),torch.distributions.kl_divergence(tc,tc2)) + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_categorical2(self): + def check(prob_shape, sample_shape): + for _ in range(4): + probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape) + + jc, jc2 = jd.Categorical(jt.array(probs)),jd.Categorical(jt.array(probs2)) + tc, tc2 = torch.distributions.Categorical(torch.tensor(probs)),torch.distributions.Categorical(torch.tensor(probs2)) + assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy()) + x1 = jc.sample(sample_shape) + x2 = tc.sample(sample_shape) + assert tuple(x1.shape) == tuple(x2.shape) + x = np.random.randint(0,prob_shape[-1], tuple(x1.shape)) + np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5) + np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5) + check((10,), (4,)) + check((2,3), (4,)) + check((3,4,5,6), (2,)) + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_one_hot_categorical2(self): + def check(prob_shape, sample_shape): + for _ in range(4): + probs,probs2 = np.random.uniform(0,1,prob_shape), np.random.uniform(0,1, prob_shape) + + jc, jc2 = jd.OneHotCategorical(jt.array(probs)),jd.OneHotCategorical(jt.array(probs2)) + tc, tc2 = torch.distributions.OneHotCategorical(torch.tensor(probs)),torch.distributions.OneHotCategorical(torch.tensor(probs2)) + assert np.allclose(jc.entropy().data, tc.entropy().numpy()), (jc.entropy().data, tc.entropy().numpy()) + x1 = jc.sample(sample_shape) + x2 = tc.sample(sample_shape) + assert tuple(x1.shape) == tuple(x2.shape) + x = np.random.randint(0,prob_shape[-1], tuple(x1.shape)) + np.testing.assert_allclose(jc.log_prob(x), tc.log_prob(torch.tensor(x)), atol=1e-5) + np.testing.assert_allclose(jd.kl_divergence(jc,jc2), torch.distributions.kl_divergence(tc,tc2), atol=1e-5) + check((10,), (4,)) + check((2,3), (4,)) + check((3,4,5,6), (2,)) + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_uniform(self): + for _ in range(4): + low, low2 = np.random.randint(-1,2), np.random.randint(-1,2) + leng, leng2 = np.random.uniform(0,2), np.random.uniform(0,2) + high, high2 = low + leng, low2 + leng2 + ju, ju2 = jd.Uniform(low,high),jd.Uniform(low2,high2) + tu, tu2 = torch.distributions.Uniform(low,high),torch.distributions.Uniform(low2,high2) + assert np.allclose(ju.entropy().data,tu.entropy().numpy()) + x = np.random.uniform(low,high) + assert np.allclose(ju.log_prob(x),tu.log_prob(torch.tensor(x))) + assert np.allclose(jd.kl_divergence(ju,ju2),torch.distributions.kl_divergence(tu,tu2)) + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_geometric(self): + for _ in range(4): + prob, prob2 = np.random.uniform(0,1), np.random.uniform(0,1) + jg, jg2 = jd.Geometric(prob),jd.Geometric(prob2) + tg, tg2 = torch.distributions.Geometric(prob),torch.distributions.Geometric(prob2) + np.testing.assert_allclose(jg.entropy().data,tg.entropy().numpy(), atol=1e-4) + x = np.random.randint(1,10) + np.testing.assert_allclose(jg.log_prob(x),tg.log_prob(torch.tensor(x)), atol=1e-4) + # print(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2)) + np.testing.assert_allclose(jd.kl_divergence(jg,jg2),torch.distributions.kl_divergence(tg,tg2), atol=1e-4) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_einops.py b/python/jittor/test/test_einops.py new file mode 100644 index 00000000..fb9facdf --- /dev/null +++ b/python/jittor/test/test_einops.py @@ -0,0 +1,616 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: DongYang Li . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from collections import namedtuple +import tempfile +import pickle +import itertools +from jittor.einops.einops import (rearrange, reduce, _enumerate_directions, _reductions) +from jittor.einops import EinopsError +import jittor as jt +import numpy +import unittest + +# tests/__init__.py +import os +from jittor.einops import _backends +import warnings + +flag_to_bool = { + '': False, + '0': False, + '1': True, +} + + +def collect_test_backends(symbolic=False, layers=False): + """ + :param symbolic: symbolic or imperative frameworks? + :param layers: layers or operations? + :return: list of backends satisfying set conditions + """ + if not symbolic: + if not layers: + backend_types = [ + _backends.NumpyBackend, + _backends.JittorBackend, + ] + else: + backend_types = [ + _backends.JittorBackend, + ] + else: + backend_types = [] + result = [] + for backend_type in backend_types: + try: + result.append(backend_type()) + except ImportError: + # problem with backend installation fails a specific test function, + # but will be skipped in all other test cases + warnings.warn('backend could not be initialized for tests: {}'.format(backend_type)) + return result + + +# test/test_ops.py + +imp_op_backends = collect_test_backends(symbolic=False, layers=False) + +# test/test_layer.py + + +class TestSlice(unittest.TestCase): + + def test_anonymous_axes(self): + x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6]) + for pattern, axis_dimensions in test_cases_repeat_anonymous: + check_reversion(x, pattern, **axis_dimensions) + + def test_repeat_imperatives(self): + x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) + for backend in imp_op_backends: + print('Repeat tests for ', backend.framework_name) + + for pattern, axis_dimensions in repeat_test_cases: + expected = reduce(x, pattern, reduction='repeat', **axis_dimensions) + converted = backend.from_numpy(x) + repeated = reduce(converted, pattern, reduction='repeat', **axis_dimensions) + result = backend.to_numpy(repeated) + assert numpy.array_equal(result, expected) + + def test_repeat_numpy(self): + # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well + x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) + x1 = reduce(x, 'a b c -> copy a b c ', reduction='repeat', copy=1) + assert numpy.array_equal(x[None], x1) + for pattern, axis_dimensions in repeat_test_cases: + check_reversion(x, pattern, **axis_dimensions) + + def test_tiling_imperatives(self): + for backend in imp_op_backends: + print('Tiling tests for ', backend.framework_name) + input = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5]) + test_cases = [ + (1, 1, 1, 1, 1), + (1, 2, 1, 3, 1), + (3, 1, 1, 4, 1), + ] + for repeats in test_cases: + expected = numpy.tile(input, repeats) + converted = backend.from_numpy(input) + repeated = backend.tile(converted, repeats) + result = backend.to_numpy(repeated) + assert numpy.array_equal(result, expected) + + def test_gradients_imperatives(self): + # lazy - just checking reductions + for reduction in _reductions: + x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype('float32') + results = {} + for backend in imp_op_backends: + y0 = backend.from_numpy(x) + if not 'jittor' in backend.framework_name and not hasattr(y0, 'grad'): + continue + y1 = reduce(y0, 'a b c -> c a', reduction=reduction) + y2 = reduce(y1, 'c a -> a c', reduction=reduction) + y3 = reduce(y2, 'a (c1 c2) -> a', reduction=reduction, c1=2) + y4 = reduce(y3, '... -> ', reduction=reduction) + if 'jittor' in backend.framework_name: + grad = backend.jittor.grad(y4, y0) + else: + y4.backward() + grad = y0.grad + results[backend.framework_name] = backend.to_numpy(grad) + + print('comparing gradients for', results.keys()) + for name1, grad1 in results.items(): + for name2, grad2 in results.items(): + assert numpy.allclose(grad1, grad2), [name1, name2, 'provided different gradients'] + + def test_concatenations_and_stacking(self): + for backend in imp_op_backends: + print('testing shapes for ', backend.framework_name) + for n_arrays in [1, 2, 5]: + shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] + for shape in shapes: + if (backend.framework_name == 'jittor')\ + and len(shape) == 0: + # jittor stores scalar in 1d array + continue + arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)] + arrays2 = [backend.from_numpy(array) for array in arrays1] + result0 = numpy.asarray(arrays1) + result1 = rearrange(arrays1, '...->...') + result2 = rearrange(arrays2, '...->...') + assert numpy.array_equal(result0, result1) + assert numpy.array_equal(result1, backend.to_numpy(result2)) + + result1 = rearrange(arrays1, 'b ... -> ... b') + result2 = rearrange(arrays2, 'b ... -> ... b') + assert numpy.array_equal(result1, backend.to_numpy(result2)) + + def test_enumerating_directions(self): + for backend in imp_op_backends: + print('testing directions for', backend.framework_name) + for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]: + if (backend.framework_name == 'jittor')\ + and len(shape) == 0: + # jittor stores scalar in 1d array + continue + x = numpy.arange(numpy.prod(shape)).reshape(shape) + axes1 = _enumerate_directions(x) + axes2 = _enumerate_directions(backend.from_numpy(x)) + assert len(axes1) == len(axes2) == len(shape) + for ax1, ax2 in zip(axes1, axes2): + ax2 = backend.to_numpy(ax2) + assert ax1.shape == ax2.shape + assert numpy.allclose(ax1, ax2) + + def test_reduction_with_callable_imperatives(self): + x_numpy = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype('float32') + x_numpy /= x_numpy.max() + + def logsumexp_jittor(x, tuple_of_axes): + import jittor as jt + return jt.nn.logsumexp(x, tuple_of_axes) + + def logsumexp_numpy(x, tuple_of_axes): + # very naive logsumexp to compare to + minused = x.max(tuple_of_axes) + y = x - x.max(tuple_of_axes, keepdims=True) + y = numpy.exp(y) + y = numpy.sum(y, axis=tuple_of_axes) + return numpy.log(y) + minused + + from jittor.einops._backends import JittorBackend, NumpyBackend + backend2callback = { + JittorBackend.framework_name: logsumexp_jittor, + NumpyBackend.framework_name: logsumexp_numpy, + } + + for backend in imp_op_backends: + if backend.framework_name not in backend2callback: + continue + + backend_callback = backend2callback[backend.framework_name] + + x_backend = backend.from_numpy(x_numpy) + for pattern1, pattern2 in equivalent_reduction_patterns: + print('Test reduction with callable for ', backend.framework_name, pattern1, pattern2) + output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy) + output_backend = reduce(x_backend, pattern1, reduction=backend_callback) + assert numpy.allclose( + output_numpy, + backend.to_numpy(output_backend), + ) + + def test_reduction_stress_imperatives(self): + for backend in imp_op_backends: + print('Stress-testing reduction for ', backend.framework_name) + for reduction in _reductions + ('rearrange',): + dtype = 'int64' + coincide = numpy.array_equal + if reduction in ['mean', 'prod']: + dtype = 'float64' + coincide = numpy.allclose + for n_axes in range(11): + shape = numpy.random.randint(2, 4, size=n_axes) + permutation = numpy.random.permutation(n_axes) + skipped = 0 if reduction == 'rearrange' else numpy.random.randint(n_axes + 1) + left = ' '.join('x' + str(i) for i in range(n_axes)) + right = ' '.join('x' + str(i) for i in permutation[skipped:]) + pattern = left + '->' + right + x = numpy.arange(1, 1 + numpy.prod(shape), dtype=dtype).reshape(shape) + if reduction == 'prod': + x /= x.mean() # to avoid overflows + result1 = reduce(x, pattern, reduction=reduction) + result2 = x.transpose(permutation) + if skipped > 0: + result2 = getattr(result2, reduction)(axis=tuple(range(skipped))) + assert coincide(result1, result2) + check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False) + + def test_reduction_imperatives(self): + for backend in imp_op_backends: + print('Reduction tests for ', backend.framework_name) + for reduction in _reductions: + # slight redundancy for simpler order - numpy version is evaluated multiple times + input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape([2, 3, 4, 5, 6]) + if reduction in ['mean', 'prod']: + input = input / input.astype('float64').mean() + test_cases = [ + ['a b c d e -> ', {}, + getattr(input, reduction)()], + ['a ... -> ', {}, + getattr(input, reduction)()], + ['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2), + getattr(input, reduction)()], + ['a b c d e -> (e c) a', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])], + ['a ... c d e -> (e c) a', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])], + ['a b c d e ... -> (e c) a', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])], + ['a b c d e -> (e c a)', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])], + ['(a a2) ... -> (a2 a) ...', dict(a2=1), + input], + ] + for pattern, axes_lengths, expected_result in test_cases: + result = reduce(backend.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths) + result = backend.to_numpy(result) + assert numpy.allclose(result, expected_result) + + def test_rearrange_permutations_numpy(self): + # tests random permutation of axes against two independent numpy ways + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)) + right_expression = ' '.join('i' + str(axis) for axis in permutation) + expression = left_expression + ' -> ' + right_expression + result = rearrange(input, expression) + + for pick in numpy.random.randint(0, 2, [10, n_axes]): + assert input[tuple(pick)] == result[tuple(pick[permutation])] + + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1]) + right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1]) + expression = left_expression + ' -> ' + right_expression + result = rearrange(input, expression) + assert result.shape == input.shape + expected_result = numpy.zeros_like(input) + for original_axis, result_axis in enumerate(permutation): + expected_result |= ((input >> original_axis) & 1) << result_axis + + assert numpy.array_equal(result, expected_result) + + def test_rearrange_consistency_numpy(self): + shape = [1, 2, 3, 5, 7, 11] + x = numpy.arange(numpy.prod(shape)).reshape(shape) + for pattern in [ + 'a b c d e f -> a b c d e f', + 'b a c d e f -> a b d e f c', + 'a b c d e f -> f e d c b a', + 'a b c d e f -> (f e) d (c b a)', + 'a b c d e f -> (f e d c b a)', + ]: + result = rearrange(x, pattern) + assert len(numpy.setdiff1d(x, result)) == 0 + assert result.dtype == x.dtype + + result = rearrange(x, 'a b c d e f -> a (b) (c d e) f') + assert numpy.array_equal(x.flatten(), result.flatten()) + + result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11') + assert numpy.array_equal(x, result) + + result1 = rearrange(x, 'a b c d e f -> f e d c b a') + result2 = rearrange(x, 'f e d c b a -> a b c d e f') + assert numpy.array_equal(result1, result2) + + result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5) + assert numpy.array_equal(x, result) + + sizes = dict(zip('abcdef', shape)) + temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes) + result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes) + assert numpy.array_equal(x, result) + + x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4]) + result = rearrange(x2, 'a b c -> b c a') + assert x2[1, 2, 3] == result[2, 3, 1] + assert x2[0, 1, 2] == result[1, 2, 0] + + def test_ellipsis_ops_imperative(self): + """ Checking various patterns against numpy """ + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + for is_symbolic in [True, False]: + for backend in collect_test_backends(symbolic=is_symbolic, layers=False): + for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)): + check_op_against_numpy(backend, x, pattern, axes_lengths={}, + reduction='rearrange', is_symbolic=is_symbolic) + + for reduction in ['min', 'max', 'sum']: + for pattern in itertools.chain(*equivalent_reduction_patterns): + check_op_against_numpy(backend, x, pattern, axes_lengths={}, + reduction=reduction, is_symbolic=is_symbolic) + + def test_ellipsis_ops_numpy(self): + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + for pattern in identity_patterns: + assert numpy.array_equal(x, rearrange(x, pattern)), pattern + + for pattern1, pattern2 in equivalent_rearrange_patterns: + assert numpy.array_equal(rearrange(x, pattern1), rearrange(x, pattern2)) + + for reduction in ['min', 'max', 'sum']: + for pattern1, pattern2 in equivalent_reduction_patterns: + assert numpy.array_equal(reduce(x, pattern1, reduction=reduction), + reduce(x, pattern2, reduction=reduction)) + + # now just check coincidence with numpy + all_rearrange_patterns = [*identity_patterns] + for pattern_pairs in equivalent_rearrange_patterns: + all_rearrange_patterns.extend(pattern_pairs) + + def test_collapsed_ellipsis_errors_out(self): + x = numpy.zeros([1, 1, 1, 1, 1]) + rearrange(x, 'a b c d ... -> a b c ... d') + error = 0 + try: + rearrange(x, 'a b c d (...) -> a b c ... d') + except Exception as e: + error = 1 + assert error == 1 + + rearrange(x, '... -> (...)') + error = 0 + try: + rearrange(x, '(...) -> (...)') + except Exception as e: + error = 1 + assert error == 1 + + def test_rearrange_imperative(self): + for backend in collect_test_backends(symbolic=False, layers=True): + print('Test layer for ', backend.framework_name) + + for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns: + x = numpy.arange(numpy.prod(input_shape), dtype='float32').reshape(input_shape) + result_numpy = rearrange(x, pattern, **axes_lengths) + layer = backend.layers().Rearrange(pattern, **axes_lengths) + for shape in wrong_shapes: + try: + layer(backend.from_numpy(numpy.zeros(shape, dtype='float32'))) + except: + pass + else: + raise AssertionError('Failure expected') + + # simple pickling / unpickling + layer2 = pickle.loads(pickle.dumps(layer)) + result1 = backend.to_numpy(layer(backend.from_numpy(x))) + result2 = backend.to_numpy(layer2(backend.from_numpy(x))) + assert numpy.allclose(result_numpy, result1) + assert numpy.allclose(result1, result2) + + just_sum = backend.layers().Reduce('...->', reduction='sum') + + + variable = backend.from_numpy(x) + result = just_sum(layer(variable)) + + if 'jittor' in backend.framework_name: + grad = backend.jittor.grad(result, variable) + else: + result.backward() + grad = variable.grad + + assert numpy.allclose(backend.to_numpy(grad), 1) + + def test_reduce_imperative(self): + for backend in collect_test_backends(symbolic=False, layers=True): + print('Test layer for ', backend.framework_name) + for reduction in _reductions: + for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns: + print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes) + x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape) + x /= x.mean() + result_numpy = reduce(x, pattern, reduction, **axes_lengths) + layer = backend.layers().Reduce(pattern, reduction, **axes_lengths) + for shape in wrong_shapes: + try: + layer(backend.from_numpy(numpy.zeros(shape, dtype='float32'))) + except: + pass + else: + raise AssertionError('Failure expected') + + # simple pickling / unpickling + layer2 = pickle.loads(pickle.dumps(layer)) + result1 = backend.to_numpy(layer(backend.from_numpy(x))) + result2 = backend.to_numpy(layer2(backend.from_numpy(x))) + assert numpy.allclose(result_numpy, result1) + assert numpy.allclose(result1, result2) + + just_sum = backend.layers().Reduce('...->', reduction='sum') + + + variable = backend.from_numpy(x) + result = just_sum(layer(variable)) + + if 'jittor' in backend.framework_name: + grad = backend.jittor.grad(result, variable) + grad = backend.to_numpy(grad) + else: + result.backward() + grad = backend.to_numpy(variable.grad) + if reduction == 'sum': + assert numpy.allclose(grad, 1) + if reduction == 'mean': + assert numpy.allclose(grad, grad.min()) + if reduction in ['max', 'min']: + assert numpy.all(numpy.in1d(grad, [0, 1])) + assert numpy.sum(grad) > 0.5 + + def test_jittor_layer(self): + has_jittor = any(backend.framework_name == 'jittor' for backend in collect_test_backends(symbolic=False, layers=True)) + if has_jittor: + # checked that jittor present + import jittor + + rtol = 1e-05 + atol = 1e-08 + def allclose(input, other): return jittor.all(jittor.abs(input-other) <= atol+rtol*jittor.abs(other)) + model1 = create_jittor_model(use_reduce=True) + model2 = create_jittor_model(use_reduce=False) + input = jittor.randn([10, 3, 32, 32]) + # random models have different predictions + assert not allclose(model1(input), model2(input)) + model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict()))) + assert allclose(model1(input), model2(input)) + + +testcase = namedtuple('testcase', ['pattern', 'axes_lengths', 'input_shape', 'wrong_shapes']) + +rearrangement_patterns = [ + testcase('b c h w -> b (c h w)', dict(c=20), (10, 20, 30, 40), + [(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]]), + testcase('b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1', dict(h2=2, w2=2), (10, 20, 30, 40), + [(), (1, 1, 1, 1), (1, 10, 3), ()]), + testcase('b ... c -> c b ...', dict(b=10), (10, 20, 30), + [(), (10,), (5, 10)]), +] + +reduction_patterns = rearrangement_patterns + [ + testcase('b c h w -> b ()', dict(b=10), (10, 20, 30, 40), + [(10,), (10, 20, 30)]), + testcase('b c (h1 h2) (w1 w2) -> b c h1 w1', dict(h1=15, h2=2, w2=2), (10, 20, 30, 40), + [(10, 20, 31, 40)]), + testcase('b ... c -> b', dict(b=10), (10, 20, 30, 40), + [(10,), (11, 10)]), +] + +equivalent_reduction_patterns = [ + ('a b c d e -> ', ' ... -> '), + ('a b c d e -> (e a)', 'a ... e -> (e a)'), + ('a b c d e -> d (a e)', ' a b c d e ... -> d (a e) '), + ('a b c d e -> (a b)', ' ... c d e -> (...) '), +] + +equivalent_rearrange_patterns = [ + ('a b c d e -> (a b) c d e', 'a b ... -> (a b) ... '), + ('a b c d e -> a b (c d) e', '... c d e -> ... (c d) e'), + ('a b c d e -> a b c d e', '... -> ... '), + ('a b c d e -> (a b c d e)', '... -> (...)'), + ('a b c d e -> b (c d e) a', 'a b ... -> b (...) a'), + ('a b c d e -> b (a c d) e', 'a b ... e -> b (a ...) e'), +] + +identity_patterns = [ + '...->...', + 'a b c d e-> a b c d e', + 'a b c d e ...-> ... a b c d e', + 'a b c d e ...-> a ... b c d e', + '... a b c d e -> ... a b c d e', + 'a ... e-> a ... e', + 'a ... -> a ... ', + 'a ... c d e -> a (...) c d e', +] + +test_cases_repeat_anonymous = [ + # all assume that input has shape [1, 2, 4, 6] + ('a b c d -> c a d b', dict()), + ('a b c d -> (c 2 d a b)', dict(a=1, c=4, d=6)), + ('1 b c d -> (d copy 1) 3 b c ', dict(copy=3)), + ('1 ... -> 3 ... ', dict()), + ('() ... d -> 1 (copy1 d copy2) ... ', dict(copy1=2, copy2=3)), + ('1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)', dict()), + +] + +repeat_test_cases = [ + # all assume that input has shape [2, 3, 5] + ('a b c -> c a b', dict()), + ('a b c -> (c copy a b)', dict(copy=2, a=2, b=3, c=5)), + ('a b c -> (a copy) b c ', dict(copy=1)), + ('a b c -> (c a) (copy1 b copy2)', dict(a=2, copy1=1, copy2=2)), + ('a ... -> a ... copy', dict(copy=4)), + ('... c -> ... (copy1 c copy2)', dict(copy1=1, copy2=2)), + ('... -> ... ', dict()), + (' ... -> copy1 ... copy2 ', dict(copy1=2, copy2=3)), + ('a b c -> copy1 a copy2 b c () ', dict(copy1=2, copy2=1)), +] + + +def check_reversion(x, repeat_pattern, **sizes): + """Checks repeat pattern by running reduction """ + left, right = repeat_pattern.split('->') + reduce_pattern = right + '->' + left + repeated = reduce(x, repeat_pattern, reduction='repeat', **sizes) + reduced_min = reduce(repeated, reduce_pattern, reduction='min', **sizes) + reduced_max = reduce(repeated, reduce_pattern, reduction='max', **sizes) + assert numpy.array_equal(x, reduced_min) + assert numpy.array_equal(x, reduced_max) + + +def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction='rearrange', is_symbolic=False): + """ + Helper to test result of operation (rearrange or transpose) against numpy + if reduction == 'rearrange', rearrange op is tested, otherwise reduce + """ + if len(numpy_input.shape) == 0: + return + + def operation(x): + if reduction == 'rearrange': + return rearrange(x, pattern, **axes_lengths) + else: + return reduce(x, pattern, reduction, **axes_lengths) + + numpy_result = operation(numpy_input) + check_equal = numpy.array_equal + p_none_dimension = 0.5 + if 'jittor' in backend.framework_name: + check_equal = numpy.allclose + p_none_dimension = 0 + if is_symbolic: + symbol_shape = [d if numpy.random.random() >= p_none_dimension else None for d in numpy_input.shape] + symbol = backend.create_symbol(shape=symbol_shape) + result_symbol = operation(symbol) + backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)]) + else: + backend_result = operation(backend.from_numpy(numpy_input)) + backend_result = backend.to_numpy(backend_result) + + check_equal(numpy_result, backend_result) + + +def create_jittor_model(use_reduce=False): + from jittor.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU + from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix + return Sequential( + Conv2d(3, 6, kernel_size=(5, 5)), + Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2), + Conv2d(6, 16, kernel_size=(5, 5)), + Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2), + Rearrange('b c h w -> b (c h w)'), + Linear(16 * 5 * 5, 120), + ReLU(), + Linear(120, 84), + ReLU(), + EinMix('b c1 -> (b c2)', weight_shape='c1 c2', bias_shape='c2', c1=84, c2=84), + EinMix('(b c2) -> b c3', weight_shape='c2 c3', bias_shape='c3', c2=84, c3=84), + Linear(84, 10), + ) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_einsum.py b/python/jittor/test/test_einsum.py new file mode 100644 index 00000000..f678d435 --- /dev/null +++ b/python/jittor/test/test_einsum.py @@ -0,0 +1,104 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np +import unittest + +try: + import torch + from torch.autograd import Variable + import autograd.numpy as anp + from autograd import jacobian + + has_autograd = True +except: + has_autograd = False + +cupy = None +try: + import cupy +except: + pass + +@unittest.skipIf(not has_autograd, "No autograd found.") +class TestEinsum(unittest.TestCase): + def test_einsum_ijjk(self): + for i in range(30): + string = "ij,jk->ik" + tn, tm = np.random.randn(3, 3).astype('float32'), np.random.randn(3, 3).astype('float32') + x = jt.array(tn) + y = jt.array(tm) + t_x = torch.from_numpy(tn) + t_y = torch.from_numpy(tm) + t_x = Variable(t_x, requires_grad=True) + t_y = Variable(t_y, requires_grad=True) + jq = jt.linalg.einsum(string, x, y) + tq = torch.einsum(string, t_x, t_y) + np.testing.assert_allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6) + gq = jt.grad(jq, x).data + gr = jt.grad(jq, y).data + tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True) + tgr = torch.autograd.grad(tq, t_y, torch.ones_like(tq)) + np.testing.assert_allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6) + np.testing.assert_allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6) + + def test_einsum_ii(self): + for i in range(30): + string = "ij->i" + tn, tm = np.random.randn(3, 3).astype('float32'), np.random.randn(3, 3).astype('float32') + x = jt.array(tn) + # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + t_x = torch.from_numpy(tn) + t_x = Variable(t_x, requires_grad=True) + jq = jt.linalg.einsum(string, x) + tq = torch.einsum(string, t_x) + np.testing.assert_allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6) + gq = jt.grad(jq, x).data + tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq)) + np.testing.assert_allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6) + + def test_einsum_multi(self): + for i in range(30): + string = "ij,ijk,jk->ik" + tn, tm, tk = np.random.randn(3, 4).astype('float32'), np.random.randn(3, 4, 5).astype('float32'), np.random.randn(4, 5).astype('float32') + x = jt.array(tn) + y = jt.array(tm) + z = jt.array(tk) + # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + t_x = torch.from_numpy(tn) + t_y = torch.from_numpy(tm) + t_z = torch.from_numpy(tk) + t_x = Variable(t_x, requires_grad=True) + t_y = Variable(t_y, requires_grad=True) + t_z = Variable(t_z, requires_grad=True) + jq = jt.linalg.einsum(string, x, y, z) + tq = torch.einsum(string, t_x, t_y, t_z) + np.testing.assert_allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6) + gq = jt.grad(jq, x).data + gr = jt.grad(jq, y).data + gz = jt.grad(jq, z).data + tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True) + tgr = torch.autograd.grad(tq, t_y, torch.ones_like(tq), retain_graph=True) + tgz = torch.autograd.grad(tq, t_z, torch.ones_like(tq), retain_graph=True) + np.testing.assert_allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6) + np.testing.assert_allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6) + np.testing.assert_allclose(gz, tgz[0].numpy(), rtol=1e-4, atol=1e-6) + + +@unittest.skipIf(not jt.compiler.has_cuda or cupy is None, "No CUDA found") +class TestCudaEinsum(TestEinsum): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_emnist.py b/python/jittor/test/test_emnist.py new file mode 100644 index 00000000..f2b18eef --- /dev/null +++ b/python/jittor/test/test_emnist.py @@ -0,0 +1,31 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor.dataset.mnist import EMNIST, MNIST +import jittor.transform as transform + +@unittest.skipIf(True, f"skip emnist test") +class TestEMNIST(unittest.TestCase): + def test_emnist(self): + import pylab as pl + # emnist_dataset = EMNIST() + emnist_dataset = EMNIST() + for imgs, labels in emnist_dataset: + print(imgs.shape, labels.shape) + print(labels.max(), labels.min()) + # imgs = imgs.transpose(0,1,3,2).transpose(1,2,0,3)[0].reshape(28, -1) + imgs = imgs.transpose(1,2,0,3)[0].reshape(28, -1) + print(labels) + pl.imshow(imgs), pl.show() + break + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_error_msg.py b/python/jittor/test/test_error_msg.py new file mode 100644 index 00000000..3b10a103 --- /dev/null +++ b/python/jittor/test/test_error_msg.py @@ -0,0 +1,71 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +class TestErrorMsg(unittest.TestCase): + + def test_error_msg(self): + a = jt.array([3,2,1]) + b = jt.code(a.shape, a.dtype, [a], + cpu_header=""" + #include + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i + @alias(a, in0) + @alias(b, out) + """, + cpu_src=""" + for (int i=0; i. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor import init, Module +import numpy as np + +def matmul(a, b): + (n, m), k = a.shape, b.shape[-1] + a = a.broadcast([n,m,k], dims=[2]) + b = b.broadcast([n,m,k], dims=[0]) + return (a*b).sum(dim=1) + +class Linear(Module): + def __init__(self, in_features, out_features, bias=True): + self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5 + self.b = jt.random((out_features,))-0.5 if bias else None + def execute(self, x): + x = matmul(x, self.w) + if self.b is not None: + return x+self.b + return x + +def relu(x): + return jt.maximum(x, 0.0) +Relu = jt.make_module(relu) + +class Model(Module): + def __init__(self, input_size): + self.linear1 = Linear(input_size, 10) + self.relu1 = Relu() + self.linear2 = Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + +class TestExample(unittest.TestCase): + def test1(self): + np.random.seed(0) + jt.set_seed(3) + n = 1000 + batch_size = 50 + lr = 0.05 + + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + model = Model(input_size=1) + ps = model.parameters() + for p in reversed(ps): p.sync(0,0) + + for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x).name("pred_y") + loss = ((pred_y - y).sqr()).name("loss") + loss_mean = loss.mean() + + gs = jt.grad(loss_mean, ps) + for p, g in zip(ps, gs): + p -= g * lr + + if i>2: + assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}" + prev = jt.liveness_info() + print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}") + + possible_results = [ + 0.0009948202641680837, + 0.001381353591568768, + 0.00110957445576787, + ] + loss_mean = loss_mean.data + assert any(abs(loss_mean - r) < 1e-6 for r in possible_results) + + jt.clean() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_example_accumulate_grad.py b/python/jittor/test/test_example_accumulate_grad.py new file mode 100644 index 00000000..e4685382 --- /dev/null +++ b/python/jittor/test/test_example_accumulate_grad.py @@ -0,0 +1,91 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor import init, Module +import numpy as np +from jittor.optim import Optimizer +f32 = jt.float32 + +def matmul(a, b): + (n, m), k = a.shape, b.shape[-1] + a = a.broadcast([n,m,k], dims=[2]) + b = b.broadcast([n,m,k], dims=[0]) + return (a*b).sum(dim=1) + +class Linear(Module): + def __init__(self, in_features, out_features, bias=True): + self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5 + self.b = jt.random((out_features,))-0.5 if bias else None + def execute(self, x): + x = matmul(x, self.w) + if self.b is not None: + return x+self.b + return x + +def relu(x): + return jt.maximum(x, 0.0) +Relu = jt.make_module(relu) + +class Model(Module): + def __init__(self, input_size): + self.linear1 = Linear(input_size, 10) + self.relu1 = Relu() + self.linear2 = Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + +class TestExample(unittest.TestCase): + def test1(self): + np.random.seed(0) + jt.set_seed(3) + n = 1000 + batch_size = 50 + base_lr = 0.05 + # tune accumulation_steps for step and batch_size + accumulation_steps = 10 + n *= accumulation_steps + batch_size //= accumulation_steps + # we need to stop grad of global value to prevent memory leak + lr = f32(base_lr).name("lr").stop_grad() + + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + model = Model(input_size=1) + ps = model.parameters() + for p in reversed(ps): p.sync(0,0) + opt = Optimizer(ps, lr) + all_loss = 0 + + for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x).name("pred_y") + loss = ((pred_y - y)**f32(2)).name("loss") + loss_mean = loss.mean() / accumulation_steps + all_loss += loss_mean.item() + + opt.backward(loss_mean) + if (i+1) % accumulation_steps == 0: + opt.step() + + if i>50: + assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}" + prev = jt.liveness_info() + print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}") + + print(all_loss) + possible_results = [19.8639366890402, 8.207454475712439] + assert any(abs(all_loss - r) < 1e-3 for r in possible_results) + jt.clean() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_fetcher.py b/python/jittor/test/test_fetcher.py new file mode 100644 index 00000000..0059d2ef --- /dev/null +++ b/python/jittor/test/test_fetcher.py @@ -0,0 +1,35 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import compile_extern + +class TestFetcher(unittest.TestCase): + def test_fetch(self): + a = jt.array([1,2,3]) + a = a*2 + v = [] + jt.fetch(a, lambda a: v.append(a)) + jt.fetch(1, 2, 3, a, + lambda x, y, z, a: self.assertTrue(x==1 and y==2 and z==3 and isinstance(a, np.ndarray)) + ) + jt.sync_all(True) + assert len(v)==1 and (v[0]==[2,4,6]).all() + +@unittest.skipIf(not jt.has_cuda, "Cuda not found") +class TestFetcherCuda(TestFetcher): + @classmethod + def setUpClass(self): + jt.flags.use_cuda = 1 + + @classmethod + def tearDownClass(self): + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_fft_op.py b/python/jittor/test/test_fft_op.py new file mode 100644 index 00000000..0e4fb5fa --- /dev/null +++ b/python/jittor/test/test_fft_op.py @@ -0,0 +1,262 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import unittest +from .test_log import find_log_with_re +import torch # torch >= 1.9.0 needed +import numpy as np +from jittor import nn + +#requires torch>=1.10.1 +class TestFFTOp(unittest.TestCase): + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_fft_forward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + + # torch + x = torch.Tensor(X) + y = torch.fft.fft2(x) + y_torch_real = y.numpy().real + y_torch_imag = y.numpy().imag + + #jittor + x = jt.array(X,dtype=jt.float32) + x = jt.stack([x, jt.zeros_like(x)], 3) + y = nn._fft2(x) + y_jt_real = y[:, :, :, 0].data + y_jt_imag = y[:, :, :, 1].data + assert(np.allclose(y_torch_real, y_jt_real, atol=1)) + assert(np.allclose(y_torch_imag, y_jt_imag, atol=1)) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_ifft_forward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + + # torch + x = torch.Tensor(X) + y = torch.fft.fft2(x) + y_torch_real = y.numpy().real + y_torch_imag = y.numpy().imag + y_ori = torch.fft.ifft2(y) + y_ori_torch_real = y_ori.real.numpy() + assert(np.allclose(y_ori_torch_real, X, atol=1)) + + #jittor + x = jt.array(X,dtype=jt.float32) + x = jt.stack([x, jt.zeros_like(x)], 3) + y = nn._fft2(x) + y_ori = nn._fft2(y, True) + y_jt_real = y[:, :, :, 0].data + y_jt_imag = y[:, :, :, 1].data + y_ori_jt_real = y_ori[:, :, :, 0].data + assert(np.allclose(y_torch_real, y_jt_real, atol=1)) + assert(np.allclose(y_torch_imag, y_jt_imag, atol=1)) + assert(np.allclose(y_ori_jt_real, X, atol=1)) + assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1)) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_fft_backward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + T1 = np.random.rand(1,256,300) + T2 = np.random.rand(1,256,300) + + # torch + x = torch.Tensor(X) + x.requires_grad = True + t1 = torch.Tensor(T1) + t2 = torch.Tensor(T2) + y_mid = torch.fft.fft2(x) + y = torch.fft.fft2(y_mid) + real = y.real + imag = y.imag + loss = (real * t1).sum() + (imag * t2).sum() + loss.backward() + grad_x_torch = x.grad.detach().numpy() + + #jittor + x = jt.array(X,dtype=jt.float32) + t1 = jt.array(T1,dtype=jt.float32) + t2 = jt.array(T2,dtype=jt.float32) + x = jt.stack([x, jt.zeros_like(x)], 3) + y_mid = nn._fft2(x) + y = nn._fft2(y_mid) + real = y[:, :, :, 0] + imag = y[:, :, :, 1] + loss = (real * t1).sum() + (imag * t2).sum() + grad_x_jt = jt.grad(loss, x).data[:, :, :, 0] + assert(np.allclose(grad_x_jt, grad_x_torch, atol=1)) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_ifft_backward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + T1 = np.random.rand(1,256,300) + T2 = np.random.rand(1,256,300) + + # torch + x = torch.Tensor(X) + x.requires_grad = True + t1 = torch.Tensor(T1) + t2 = torch.Tensor(T2) + y_mid = torch.fft.ifft2(x) + y = torch.fft.ifft2(y_mid) + real = y.real + imag = y.imag + loss = (real * t1).sum() + (imag * t2).sum() + loss.backward() + grad_x_torch = x.grad.detach().numpy() + + #jittor + x = jt.array(X,dtype=jt.float32) + t1 = jt.array(T1,dtype=jt.float32) + t2 = jt.array(T2,dtype=jt.float32) + x = jt.stack([x, jt.zeros_like(x)], 3) + y_mid = nn._fft2(x, True) + y = nn._fft2(y_mid, True) + real = y[:, :, :, 0] + imag = y[:, :, :, 1] + loss = (real * t1).sum() + (imag * t2).sum() + grad_x_jt = jt.grad(loss, x).data[:, :, :, 0] + assert(np.allclose(grad_x_jt, grad_x_torch)) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_fft_float64_forward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + + # torch + x = torch.DoubleTensor(X) + y = torch.fft.fft2(x) + y_torch_real = y.numpy().real + y_torch_imag = y.numpy().imag + + #jittor + x = jt.array(X).float64() + x = jt.stack([x, jt.zeros_like(x)], 3) + y = nn._fft2(x) + y_jt_real = y[:, :, :, 0].data + y_jt_imag = y[:, :, :, 1].data + assert(np.allclose(y_torch_real, y_jt_real, atol=1)) + assert(np.allclose(y_torch_imag, y_jt_imag, atol=1)) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_ifft_float64_forward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + + # torch + x = torch.DoubleTensor(X) + y = torch.fft.fft2(x) + y_torch_real = y.numpy().real + y_torch_imag = y.numpy().imag + y_ori = torch.fft.ifft2(y) + y_ori_torch_real = y_ori.real.numpy() + assert(np.allclose(y_ori_torch_real, X, atol=1)) + + #jittor + x = jt.array(X).float64() + x = jt.stack([x, jt.zeros_like(x)], 3) + y = nn._fft2(x) + y_ori = nn._fft2(y, True) + y_jt_real = y[:, :, :, 0].data + y_jt_imag = y[:, :, :, 1].data + y_ori_jt_real = y_ori[:, :, :, 0].data + assert(np.allclose(y_torch_real, y_jt_real, atol=1)) + assert(np.allclose(y_torch_imag, y_jt_imag, atol=1)) + assert(np.allclose(y_ori_jt_real, X, atol=1)) + assert(np.allclose(y_ori_jt_real, y_ori_torch_real, atol=1)) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_fft_float64_backward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + T1 = np.random.rand(1,256,300) + T2 = np.random.rand(1,256,300) + + # torch + x = torch.DoubleTensor(X) + x.requires_grad = True + t1 = torch.DoubleTensor(T1) + t2 = torch.DoubleTensor(T2) + y_mid = torch.fft.fft2(x) + y = torch.fft.fft2(y_mid) + real = y.real + imag = y.imag + loss = (real * t1).sum() + (imag * t2).sum() + loss.backward() + grad_x_torch = x.grad.detach().numpy() + + #jittor + x = jt.array(X).float64() + t1 = jt.array(T1).float64() + t2 = jt.array(T2).float64() + x = jt.stack([x, jt.zeros_like(x)], 3) + y_mid = nn._fft2(x) + y = nn._fft2(y_mid) + real = y[:, :, :, 0] + imag = y[:, :, :, 1] + loss = (real * t1).sum() + (imag * t2).sum() + grad_x_jt = jt.grad(loss, x).data[:, :, :, 0] + assert(np.allclose(grad_x_jt, grad_x_torch, atol=1)) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_ifft_float64_backward(self): + img = np.random.rand(256, 300) + img2 = np.random.rand(256, 300) + X = np.stack([img, img2], 0) + T1 = np.random.rand(1,256,300) + T2 = np.random.rand(1,256,300) + + # torch + x = torch.DoubleTensor(X) + x.requires_grad = True + t1 = torch.DoubleTensor(T1) + t2 = torch.DoubleTensor(T2) + y_mid = torch.fft.ifft2(x) + y = torch.fft.ifft2(y_mid) + real = y.real + imag = y.imag + loss = (real * t1).sum() + (imag * t2).sum() + loss.backward() + grad_x_torch = x.grad.detach().numpy() + + #jittor + x = jt.array(X).float64() + t1 = jt.array(T1).float64() + t2 = jt.array(T2).float64() + x = jt.stack([x, jt.zeros_like(x)], 3) + y_mid = nn._fft2(x, True) + y = nn._fft2(y_mid, True) + real = y[:, :, :, 0] + imag = y[:, :, :, 1] + loss = (real * t1).sum() + (imag * t2).sum() + grad_x_jt = jt.grad(loss, x).data[:, :, :, 0] + assert(np.allclose(grad_x_jt, grad_x_torch)) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_flags.py b/python/jittor/test/test_flags.py new file mode 100644 index 00000000..eff3cf95 --- /dev/null +++ b/python/jittor/test/test_flags.py @@ -0,0 +1,31 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from .test_core import expect_error + +class TestFlags(unittest.TestCase): + def test_error(self): + def check(): jt.flags.asdasd=1 + expect_error(check) + + def test_get_set(self): + prev = jt.flags.log_v + jt.flags.log_v=1 + assert jt.flags.log_v == 1 + jt.flags.log_v=prev + assert jt.flags.log_v == prev + + def test_scope(self): + prev = jt.flags.log_v + with jt.flag_scope(log_v=1): + assert jt.flags.log_v == 1 + assert jt.flags.log_v == prev + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_fold.py b/python/jittor/test/test_fold.py new file mode 100644 index 00000000..34366b4b --- /dev/null +++ b/python/jittor/test/test_fold.py @@ -0,0 +1,48 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch.nn import functional as F +except: + torch = None + skip_this_test = True + + +@unittest.skipIf(skip_this_test, "No Torch Found") +class TestFoldOp(unittest.TestCase): + def test_fold(self): + # test unfold first and the test fold. + for i in range(4,10): + tn = np.random.randn(1,3,i,i).astype('float32') + ja = jt.array(tn) + ta = torch.autograd.Variable(torch.from_numpy(tn),requires_grad=True) + juf = jt.nn.unfold(ja,kernel_size=2,stride=2,dilation=2,padding=2) + tuf = F.unfold(ta,kernel_size=2,stride=2,dilation=2,padding=2) + assert np.allclose(juf.data,tuf.detach().numpy()) + gjuf = jt.grad(juf,ja) + gtuf = torch.autograd.grad(tuf,ta,torch.ones_like(tuf),retain_graph=True)[0] + assert np.allclose(gjuf.data,gtuf.detach().numpy()) + # test fold + jf = jt.nn.fold(juf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2) + tf = F.fold(tuf,output_size=(i,i),kernel_size=2,stride=2,dilation=2,padding=2) + assert np.allclose(jf.data,tf.detach().numpy()) + gjf = jt.grad(jf,juf) + gtf = torch.autograd.grad(tf,tuf,torch.ones_like(tf),retain_graph=True)[0] + assert np.allclose(gjf.data,gtf.detach().numpy()) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_fp16.py b/python/jittor/test/test_fp16.py new file mode 100644 index 00000000..6f69a2ce --- /dev/null +++ b/python/jittor/test/test_fp16.py @@ -0,0 +1,372 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +def transpose0231(x): + s0, s1, s2, s3 = x.shape + asize = 16 + bsize = 16 + ILP = 2 + return jt.code([s0, s2, s3, s1], x.dtype, [x], + cuda_header="#include \n#include ", + cuda_src=f""" + __global__ void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{ + __shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}]; + int t3 = threadIdx.x % {bsize}; + int t1 = threadIdx.x / {bsize}; + int b3 = blockIdx.x; + int b2 = blockIdx.y; + int b0 = blockIdx.z; + int x3 = 1; + int x2 = s3; + int x1 = s2*x2; + int x0 = s1*x1; + int y3 = 1; + int y2 = s1; + int y1 = s3*y2; + int y0 = s2*y1; + in0_type tmp[{ILP}]; + for (int i=0; i<(s1-1)/{asize*ILP}+1; i++) + {{ + int _b3 = b3 * {bsize*ILP} + t3*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + vload( + tmp, + &x[b0*x0+(t1*{ILP}+j+i*{asize*ILP})*x1+b2*x2+_b3*x3] + ); + #pragma unroll + for (int k=0; k<{ILP}; k++) + t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k]; + + }} + }} + __syncthreads(); + int t3_ = threadIdx.x % {asize}; + int t1_ = threadIdx.x / {asize}; + _b3 = b3 * {bsize*ILP} + t1_*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + #pragma unroll + for (int k=0; k<{ILP}; k++) {{ + tmp[k] = + t[(t3*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j]; + }} + vload( + &y[b0*y0+b2*y1+(_b3+j)*y2+((t3*{ILP})+i*{asize*ILP})*y3], + tmp + ); + }} + }} + __syncthreads(); + }} + }} + int s0, s1, s2, s3; + in0->shape.unpack(s0, s1, s2, s3); + kernel<<<{{(s3-1)/{bsize*ILP}+1, s2, s0 }}, {bsize*asize}>>> + (in0_p, out0_p, s0, s1, s2, s3); + """) + +def transpose0231_2(x): + s0, s1, s2, s3 = x.shape + asize = 16 + bsize = 8 + ILP = 2 + return jt.code([s0, s2, s3, s1], x.dtype, [x], + cuda_header="#include \n#include ", + cuda_src=f""" + __global__ __launch_bounds__({asize*bsize}) void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{ + __shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}]; + int t3 = threadIdx.x % {bsize}; + int t1 = threadIdx.x / {bsize}; + int b3 = blockIdx.x; + int b1 = blockIdx.y; + int b2 = 0; + int b0 = blockIdx.z; + int x3 = 1; + int x2 = s3; + int x1 = s2*x2; + int x0 = s1*x1; + int y3 = 1; + int y2 = s1; + int y1 = s3*y2; + int y0 = s2*y1; + in0_type tmp[{ILP}]; + {{ + int _b3 = b3 * {bsize*ILP} + t3*{ILP}; + if (_b3 < s3) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + if (t1*{ILP}+j+b1*{asize*ILP} >= s1) + continue; + vload( + tmp, + &x[b0*x0+(t1*{ILP}+j+b1*{asize*ILP})*x1+b2*x2+_b3*x3] + ); + #pragma unroll + for (int k=0; k<{ILP}; k++) + t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k]; + + }} + }} + __syncthreads(); + int t3_ = threadIdx.x % {asize}; + int t1_ = threadIdx.x / {asize}; + _b3 = b3 * {bsize*ILP} + t1_*{ILP}; + int yy3 = (t3_*{ILP})+b1*{asize*ILP}; + if (_b3 < s3 && yy3 < s1) {{ + #pragma unroll + for (int j=0; j<{ILP}; j++) {{ + #pragma unroll + for (int k=0; k<{ILP}; k++) {{ + tmp[k] = + t[(t3_*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j]; + }} + vload( + &y[b0*y0+b2*y1+(_b3+j)*y2+yy3*y3], + tmp + ); + // printf("%d %d %d %d %d\\n", b0*y0+b2*y1+(_b3+j)*y2+yy3*y3, + // b0, b2, (_b3+j), yy3); + }} + }} + __syncthreads(); + }} + }} + int s0, s1, s2, s3; + in0->shape.unpack(s0, s1, s2, s3); + kernel<<<{{(s3-1)/{bsize*ILP}+1, (s1-1)/{asize*ILP}+1, s0 }}, {bsize*asize}>>> + (in0_p, out0_p, s0, s1, s2, s3); + """) + +def check_share(): + return + a = jt.rand((30, 32, 4, 2000)).float32() + jt.code(a.shape, a.dtype, [a], + cuda_header="#include \n#include ", + cuda_src=""" + __global__ void kernel(in0_type* __restrict__ a, in0_type* __restrict__ b) { + __shared__ float x[32*33]; + for (int i=0; i<3; i++) { + ((float2*)&x[i])[0] = ((float2*)&a[i])[0]; + ((float2*)&b[i])[0] = ((float2*)&x[i+1])[0]; + } + } + kernel<<<1024,16*16>>>(in0_p, out0_p); + """).sync() + jt.sync_all(True) + # print(a[0]+1) + print("pass test") + +class TestFP16(unittest.TestCase): + def test_array(self): + a = np.array([1,2,3], dtype="float16") + b = jt.array(a) + np.testing.assert_allclose(a, b.data) + + def test_add(self): + a = np.array([1,2,3], dtype="float16") + b = jt.array(a) + c = b+b + np.testing.assert_allclose(c.data, a+a) + d = c.sum() + np.testing.assert_allclose(d.data, [12]) + c = c+1 + print(c) + + def test_matmul(self): + a = jt.random((100,100)).float16() + b = jt.random((100,100)).float16() + c = jt.matmul(a, b) + c.sync() + + def test_bmm(self): + a = jt.random((10,3,4)).float16() + b = jt.random((10,4,5)).float16() + c = jt.matmul(a, b) + c.sync() + + def test_matmul_grad(self): + a = jt.random((100,100)).float16() + b = jt.random((100,100)).float16() + c = jt.matmul(a, b) + c.sync() + da, db = jt.grad(c, [a,b]) + jt.sync_all() + assert da.dtype == "float16" + assert db.dtype == "float16" + + def test_array_random_auto_cast(self): + a = jt.array([1.0,2.0]) + assert a.dtype == "float32" + with jt.flag_scope(amp_reg=2+16): + a = jt.array([1.0,2.0]) + assert a.dtype == "float16", a.dtype + + a = jt.random([10]) + assert a.dtype == "float32" + with jt.flag_scope(amp_reg=2+16): + a = jt.random([10]) + assert a.dtype == "float16", a.dtype + + def test_conv(self): + a = jt.random((3,4,5,5)).float16() + b = jt.random((4,4,3,3)).float16() + c = jt.nn.conv(a, b) + c.sync() + + def test_max(self): + a = jt.random((100,)).float16() + b = jt.random((100,)).float16() + c = a.maximum(b) + c.sync() + + def test_reduce_dtype_infer(self): + with jt.flag_scope(amp_reg=1): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=0): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2+4): + a = jt.random((3,4,5,5)).float16() + b = a.sum() + b.sync() + assert b.dtype == "float16", b.dtype + + def test_white_dtype_infer(self): + with jt.flag_scope(amp_reg=1): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=0): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float32" + with jt.flag_scope(amp_reg=2+8): + a = jt.random((3,4,5,5)).float16() + b = a**a + b.sync() + assert b.dtype == "float16", b.dtype + + def test_module_half(self): + a = jt.nn.Linear(10,10) + assert a.weight.dtype == "float32" + a.half() + assert a.weight.dtype == "float16" + + def test_scalar(self): + a = jt.float16([1,2,3]) + assert (a*1).dtype == "float16" + assert (a*jt.float16([1,2,3])).dtype == "float16" + assert (a*jt.float32([1,2,3])).dtype == "float32" + assert (a*jt.float32([1,2,3]).sum()).dtype == "float16" + assert jt.int([0,1,0]).ternary(a, jt.float32(1)).dtype == "float16" + + def test_amp_level3(self): + with jt.flag_scope(amp_level = 3): + a = jt.float16([1,2,3]) + assert (a.sum()).dtype == "float16" + assert (a.mean()).dtype == "float16" + assert (a.log()).dtype == "float16" + assert (a.exp()).dtype == "float16" + + def test_safe_clip(self): + import math + assert not jt.float16(math.inf).isfinite() + assert jt.safe_clip(jt.float16(math.inf)).isfinite() + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestFP16CUDA(TestFP16): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + + def test_softmax(self): + a = jt.rand((120, 2000, 2000)).float16() + # a = jt.rand((1, 2000, 2000)).float32() + jt.sync_all() + with jt.profile_scope(10, 100): + a.log_softmax(-1).sync() + + def test_transpose(self): + check_share() + # return + a = jt.rand((30, 32, 4, 2000)).float32() + # a = jt.rand((1, 1024, 1, 2000)).float32() + diff = transpose0231(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + # return + jt.sync_all() + # with jt.profile_scope(100, 11000): + with jt.profile_scope(100, 11000): + # a.log_softmax(-1).sync() + transpose0231(a).sync() + + a.transpose((0,2,3,1)).sync() + # a.transpose((0,2,1,3)).sync() + a.fuse_transpose((0,2,1,3)).sync() + (a+1).sync() + jt.sync_all(True) + diff = transpose0231(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + np.testing.assert_allclose(transpose0231(a).data, a.transpose((0,2,3,1)).data) + + def test_transpose2(self): + # check_share() + # return + # a = jt.rand((30, 32, 4, 2000)).float32() + # a = jt.rand((1, 10000, 1, 2000)).float32() + a = jt.rand((1, 10000, 1, 2048)).float32() + print("transpose") + transpose0231_2(a).sync() + print("add") + (a+1).sync() + return + # a = jt.arange(32*16).reshape((1, 32, 1, 16)) + diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + # return + jt.sync_all() + # with jt.profile_scope(100, 11000): + with jt.profile_scope(100, 1100): + # a.log_softmax(-1).sync() + transpose0231_2(a).sync() + + a.transpose((0,2,3,1)).sync() + # a.transpose((0,2,1,3)).sync() + a.fuse_transpose((0,2,1,3)).sync() + (a+1).sync() + jt.sync_all(True) + diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data + print(np.where(diff)) + np.testing.assert_allclose(transpose0231_2(a).data, a.transpose((0,2,3,1)).data) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_function.py b/python/jittor/test/test_function.py new file mode 100644 index 00000000..ae1baa07 --- /dev/null +++ b/python/jittor/test/test_function.py @@ -0,0 +1,308 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from collections.abc import Sequence, Mapping +from .test_core import expect_error +from jittor import Function + +class TestFunction(unittest.TestCase): + def test1(self): + class MyFunc(Function): + def execute(self, x): + return x+1 + + def grad(self, grad): + return grad-2 + a = jt.ones(1) + func = MyFunc() + b = func(a) + da = jt.grad(b, a) + assert da.data == -1 + + def test_apply(self): + class MyFunc(Function): + def execute(self, x): + return x+1 + + def grad(self, grad): + return grad-2 + a = jt.ones(1) + func = MyFunc.apply + b = func(a) + da = jt.grad(b, a) + assert da.data == -1 + + def test2(self): + class MyFunc(Function): + def execute(self, x): + self.x = x + return x+1 + + def grad(self, grad): + return (grad-2) * self.x + a = jt.ones(1) * 10 + func = MyFunc() + b = func(a) + da = jt.grad(b, a) + assert da.data == -10 + + def test_grad_not_match_error(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y + + def grad(self, grad): + return (grad-2) * self.x + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc() + c = func(a, b) + expect_error(lambda: jt.grad(c, [a, b])) + + def test_multi_grads(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y + + def grad(self, grad): + return (grad-2) * self.y, (grad-2) * self.x + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc() + c = func(a, b) + da, db = jt.grad(c, [a, b]) + assert da.data == -4 + assert db.data == -3 + + def test_multi_grads_none(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y + + def grad(self, grad): + return (grad-2) * self.y, None + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc() + c = func(a, b) + da, db = jt.grad(c, [a, b]) + assert da.data == -4 + assert db.data == 0 + + def test_multi_grads_multi_out(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc() + c,d = func(a, b) + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 9 + + def test_multi_grads_multi_out_stop_grad_0(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + return grad0 * self.y, grad1 * self.x + a = jt.array(3.0) + b = jt.array(4.0) + b.stop_grad() + func = MyFunc() + c,d = func(a, b) + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 0 + + def test_multi_grads_multi_out_stop_grad_1(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + assert grad1 is None + return grad0 * self.y, None + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc() + c,d = func(a, b) + d.stop_grad() + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4 + assert db.data == 0 + + def test_multi_grads_multi_out2(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + res = (grad0 * self.y, grad1 * self.x) + print(res) + return res + a = jt.array(3.0) + b = jt.array(4.0) + func = MyFunc() + c,d = func(a, b) + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4, da.data + assert db.data == 9 + + def test_multi_grads_multi_out3(self): + class MyFunc(Function): + def execute(self, x, y): + self.x = x + self.y = y + return x*y, x/y + + def grad(self, grad0, grad1): + res = (grad0 * self.y, grad1 * self.x) + print(res) + return res + a = jt.array(3.0) + b = jt.array(4.0) + c,d = MyFunc()(a, b) + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4, da.data + assert db.data == 9 + + def test_multi_grads_multi_out4(self): + class MyFunc(Function): + def execute(self, x, z, y): + self.x = x + self.y = y + return x*y, "test", x/y + + def grad(self, grad0, _, grad1): + assert _ is None + res = (grad0 * self.y, None, grad1 * self.x) + print(res) + return res + a = jt.array(3.0) + b = jt.array(4.0) + c,_,d = MyFunc()(a, "a", b) + da, db = jt.grad(c+d*3, [a, b]) + assert da.data == 4, da.data + assert db.data == 9 + + + def test_multi_grads_multi_out5(self): + class MyFunc(Function): + def execute(self, x, z, y): + self.x = x.name("x") + self.y = y.name("y") + return x*y, "test", x/y + + def grad(self, grad0, _, grad1): + assert _ is None + res = (grad0 * self.y, 1, grad1 * self.x) + print(res) + return res + a = jt.array(3.0).name('a') + b = jt.array(4.0).name('b') + c,_,d = MyFunc()(a, "a", b) + c.name('c'), d.name('d') + expect_error(lambda : jt.grad(c+d*3, [a, b])) + + def test_zmem_leak(self): + def test(): + self.test_multi_grads_multi_out5() + test() + jt.clean() + self.assertEqual(jt.liveness_info()["lived_vars"], 0) + + def test_zmem_leak2(self): + def test(): + class MyFunc(Function): + def execute(self, x, z, y): + self.x = x.name("x") + self.y = y.name("y") + return x*y, "test", x/y + + def grad(self, grad0, _, grad1): + assert _ is None + res = (grad0 * self.y, None, grad1 * self.x) + return res + a = jt.array(3.0).name('a') + b = jt.array(4.0).name('b') + c,_,d = MyFunc()(a, "a", b) + c.name('c'), d.name('d') + g = jt.grad(c+d*3, [a, b]) + test() + jt.clean() + jt.dump_all_graphs() + self.assertEqual(jt.liveness_info()["lived_vars"], 0) + + @unittest.skipIf(True, "skip memleak test") + def test_zmem_leak3(self): + def test(): + class MyFunc(Function): + def execute(self, x, z, y): + self.x = x + self.y = y + return x*y, "test", x/y + + def grad(self, grad0, _, grad1): + assert _ is None + res = (grad0 * self.y, None, grad1 * self.x) + return res + a = jt.array(3.0) + b = jt.array(4.0) + c,_,d = MyFunc()(a, "a", b) + g = jt.grad(c+d*3, [a, b]) + jt.sync(g) + import resource + t1 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + for i in range(100000): + test() + if i % 10000 == 0: + jt.clean() + t2 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + for i in range(1000000): + test() + if i % 10000 == 0: + jt.clean() + t3 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + print(t1,t2,t3) + assert t3 < t2 + 10, (t1,t2,t3) + self.assertEqual(jt.liveness_info()["lived_vars"], 0) + + +class TestFunctionWithEagerExecution(TestFunction): + @classmethod + def setUpClass(self): + jt.flags.lazy_execution = 0 + @classmethod + def tearDownClass(self): + jt.flags.lazy_execution = 1 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_fused_op.py b/python/jittor/test/test_fused_op.py new file mode 100644 index 00000000..119b423c --- /dev/null +++ b/python/jittor/test/test_fused_op.py @@ -0,0 +1,371 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +os.environ['OPENBLAS_NUM_THREADS'] = '1' + +import unittest +import time +import jittor as jt +from jittor import LOG +import numpy as np +from .test_core import expect_error +import contextlib + +performance_test = os.environ.get("performance_test", "") == "1" +skip_slow_test = not performance_test + +@contextlib.contextmanager +def performance_test_scope(warmup=0, rerun=0, **args): + """ profile scope + example: + with jt.profile_scope() as report: + ...... + print(report) + """ + assert not jt.flags.profiler_enable + if skip_slow_test: + jt.profiler.start(0, 0) + else: + jt.profiler.start(warmup, rerun) + report = [] + try: + with jt.flag_scope(**args): + yield report + finally: + jt.profiler.stop() + if skip_slow_test: + report.extend([[1e30]]*3) + else: + report.extend(jt.profiler.report()) + +def retry(num): + def outer(func): + def inner(*args): + for i in range(num): + if i == num-1: + func(*args) + break + try: + func(*args) + break + except: + pass + LOG.v(f"Retry {i}") + return inner + return outer + +def get_np_matmul_toughtput(size): + # import os + # os.environ['OPENBLAS_NUM_THREADS'] = '1' + # import numpy as np + # import time + a = np.random.randn(size, size).astype("float32") + b = np.random.randn(size, size).astype("float32") + c = np.random.randn(size, size).astype("float32") + warmup = 2 + rerun = 10+1 + for _ in range(warmup): np.matmul(a,b,c) + start_time = time.time() + for _ in range(rerun): np.matmul(a,b,c) + elapsed_time = time.time() - start_time + return (size*size*size*rerun) / elapsed_time + +class TestFusedOp(unittest.TestCase): + def test_add(self): + jt.clean() + def check(hv, lv, lo): + self.assertEqual(( + jt.number_of_hold_vars(), + jt.number_of_lived_vars(), + jt.number_of_lived_ops()), + (hv, lv, lo)) + for i in range(8): + check(0,0,0) + a = jt.array([1.0,1.0]).name('a').stop_fuse() + b = (a+jt.array([1.0,1.0]).name('t1').stop_fuse()).name('b') + c = (b+jt.array([1.0,1.0]).name('t2').stop_fuse()).name('c') + check(3,5,5) + graph = jt.dump_all_graphs() + # for n in graph.nodes_info: + # print(n) + np.testing.assert_allclose(c.data, [3,3]) + graph2 = jt.dump_all_graphs() + print("check", i) + for n in graph2.nodes_info: + print(n) + print(jt.liveness_info()) + check(3,5,2) + graph = jt.dump_all_graphs() + for node in graph.nodes_info: + if node.startswith("Op"): + if 'add->' in node: + assert ':s0' in node, node + else: + assert ':s1' in node, node + elif ',b,' in node: + # b has been fused + assert ':s0' in node, node + else: + assert ':s1' in node + if i&1: del a + if i&2: del b + if i&4: del c + + if i==0: check(3,5,2) + elif i==1: check(2,5,2) + elif i==2: check(2,5,2) + elif i==3: check(1,1,0) + elif i==4: check(2,3,1) + elif i==5: check(1,3,1) + elif i==6: check(1,1,0) + elif i==7: check(0,0,0) + + if not (i&1): a.sync() + if not (i&2): b.sync() + if not (i&4): c.sync() + + if i==0: check(3,5,2) + elif i==1: check(2,3,1) + elif i==2: check(2,5,2) + elif i==3: check(1,1,0) + elif i==4: check(2,3,1) + elif i==5: check(1,1,0) + elif i==6: check(1,1,0) + + if not (i&1): del a + if not (i&2): del b + if not (i&4): del c + check(0,0,0) + + def test_fuse_reduce_and_broadcast(self): + size = 10 + a = jt.random([size,size,1]) + b = jt.random([1,size,size]) + c = (a*b).sum(dim=1) + nc = (a.data*b.data).sum(1) + assert c.shape == [size,size] + assert (np.abs(nc-c.data)<1e-5).all() + + def test_fuse_reduce_and_broadcast2(self): + size = 10 + a = jt.random([1]) + b = a.broadcast([size]).sum() + assert (np.abs(b.data - a.data*size) < 1e-5).all() + a = jt.random([size,1]) + b = a.broadcast([size,size]).sum(1, keepdims=True) + assert (np.abs(b.data - a.data*size) < 1e-5).all() + + def test_fuse_reduce2(self): + size = 10 + a = jt.random([1]).broadcast([size]).name('a') + # a.data + b = a.sum().name('b') + c = a.min().name('c') + d = a.max().name('d') + jt.fetch_sync([b,c,d]) + + graph = jt.dump_all_graphs() + node_a = [ node for node in graph.nodes_info if ",a," in node ] + assert 's0' in node_a[0] + + v = a.data[0] + assert np.allclose(v*10,b.data) and v==c.data and v==d.data, (v, b.data, c.data, d.data) + + def test_profile_fused_op(self): + size = 1000 + r1 = [] + r2 = [] + for size in range(1024, 1025, 1): + with performance_test_scope(2, 10) as report: + a = jt.random([size,size,1]) + b = jt.random([1,size,size]) + c = (a*b).sum(1) + c.sync() + + assert len(report) == 3 + tp_np = get_np_matmul_toughtput(size) + tp_jt = float(report[1][-1]) + r1.append(tp_jt) + r2.append(tp_np) + na = a.data.reshape((size,size)) + nb = b.data.reshape((size,size)) + nc = np.matmul(na,nb) + assert (np.abs(nc-c.data)<1e-2).all(), np.abs(nc-c.data).max() + + # @unittest.skipIf(skip_slow_test, "Skip slow test") + def test_profile_fused_op_transpose(self): + for size in range(1024, 1025, 1): + with performance_test_scope(2, 10): + b = jt.random([size,1,size]) + a = jt.random([1,size,size]) + c = (a*b).sum(2) + c.data + + # @unittest.skipIf(skip_slow_test, "Skip slow test") + def test_profile_fused_op_split(self): + # match v4 + @retry(10) + def check(n, m, k, cs, rs, rtp): + a = jt.random([n,m,1]) + b = jt.random([1,m,k]) + a.data, b.data + with performance_test_scope( + 20, 20000000000//(n*m*k), + compile_options = { + "split0":16,"split1":6,"split2":16, + "order0":0, "order1":1,"order2":1, + "order3":0, "order4":0,"order5":0, + "restride":rs, "unroll":2, "vectorize":2, + "compile_shapes":cs + }) as report: + c = (a*b).sum(1) + c.data + + na = a.data.reshape((n,m)) + nb = b.data.reshape((m,k)) + nc = np.matmul(na,nb) + + assert (np.abs(nc-c.data)<1e-2).all(), np.abs(nc-c.data).max() + tp = float(report[-1][-1]) + assert tp > rtp * 10**9, (tp, rtp) + + check(65, 8, 19, 1, 0, 0) + check(64, 6, 16, 1, 0, 33) # TODO 36 + check(64, 6, 16, 0, 0, 21) + check(64, 60, 16, 1, 0, 44) + check(64, 60, 16, 0, 0, 30) + check(65, 60, 16, 0, 0, 30) + check(65, 61, 16, 0, 0, 27) + check(65, 65, 16, 0, 0, 26) + check(64, 60, 64, 1, 1, 27) + check(64, 60, 64, 1, 0, 42) + check(64, 60, 64, 0, 0, 30) # TODO: why slower? + + def test_array_reindex(self): + a = jt.array([1]) + b = a.reindex([3], ['i0-1']) + np.testing.assert_allclose(b.data, [0,1,0]) + + + @unittest.skipIf(skip_slow_test, "Skip slow test") + def test_profile_fused_op_restride(self): + # match v6 + + @retry(10) + def check(n, m, k, cs, rs, pa, rtp): + a = jt.random([n,m,1]) + b = jt.random([1,m,k]) + a.data, b.data + with performance_test_scope( + 0, 20000000000//(n*m*k), + compile_options = { + "order0":0, "order1":0,"order2":0, + "split0":64,"split1":60,"split2":64, + "split3":16,"split4":6,"split5":16, + "order3":0, "order4":1,"order5":1, + "order6":0, "order7":0,"order8":0, + "restride":rs,"vectorize":2,"unroll":2, + "compile_shapes":cs, "parallel":pa + }) as report: + c = (a*b).sum(1) + c.sync() + + na = a.data.reshape((n,m)) + nb = b.data.reshape((m,k)) + nc = np.matmul(na,nb) + assert (np.abs(nc-c.data)/nc<1e-5).all(), (np.abs(nc-c.data).max(), np.where(np.abs(nc-c.data)>1)) + tp = float(report[-1][-1]) + assert tp > rtp * 10**9, (tp, rtp) + + check(64*1, 60*1, 64*1, 0, 0, 0, 31) + check(64*1, 60*1, 64*1, 0, 1, 0, 25) + check(64*1, 60*1, 64*1, 1, 0, 0, 42) + check(64*55, 60*55, 64*55, 0, 0, 0, 20) + check(64*55, 60*55, 64*55, 1, 1, 0, 37) + check(64*55, 60*55, 64*55, 0, 1, 0, 36) + check(64*55+1, 60*55+1, 64*55, 0, 1, 0, 36) + check(64*55+1, 60*55+1, 64*55+1, 0, 1, 0, 36) + check(64*55+15, 60*55+15, 64*55+15, 0, 1, 0, 34) # TODO: 36 + check(64*16, 60*16, 64*16, 0, 1, 0, 35) + check(64*55, 60*55, 64*55, 0, 1, 0, 36) + + + @unittest.skipIf(skip_slow_test, "Skip slow test") + def test_profile_fused_op_split3(self): + # match v6 + + n, m, k = 64*100, 60*100, 64*100 + + a = jt.random([n,m,1]) + b = jt.random([1,m,k]) + # a = jt.ones([n,m,1]) + # b = jt.ones([1,m,k]) + a.data, b.data + with performance_test_scope( + 0, 400000000000//(n*m*k), + compile_options={ + "order0":0, "order1":0,"order2":0, + "split0":64*4,"split1":60*4,"split2":64*4, + + "order3":0, "order4":1,"order5":1, + "split3":64,"split4":60,"split5":64, + + "order6":0, "order7":1,"order8":1, + "split6":16,"split7":6,"split8":16, + + "order9":0, "order10":0,"order11":0, + + "restride":1, "unroll":2,"vectorize":2 + }): + c = (a*b).sum(1) + c.data + + na = a.data.reshape((n,m)) + nb = b.data.reshape((m,k)) + nc = np.matmul(na,nb) + + assert (np.abs(nc-c.data)/nc<1e-5).all(), (np.abs(nc-c.data).max(), np.where(np.abs(nc-c.data)>1)) + + @unittest.skipIf(skip_slow_test, "Skip slow test") + def test_profile_fused_op_parallel(self): + # match v6 + + @retry(10) + def check(n, m, k, cs, rs, pa, rtp): + a = jt.random([n,m,1]) + b = jt.random([1,m,k]) + a.data, b.data + with performance_test_scope( + 2, 30, + compile_options = { + "order0":0, "order1":0,"order2":0, + "split0":64,"split1":60,"split2":64, + "split3":16,"split4":6,"split5":16, + "order3":0, "order4":1,"order5":1, + "order6":0, "order7":0,"order8":0, + "restride":rs,"vectorize":2,"unroll":2, + "compile_shapes":cs, "parallel":pa + }) as report: + c = (a*b).sum(1) + c.data + + na = a.data.reshape((n,m)) + nb = b.data.reshape((m,k)) + nc = np.matmul(na,nb) + + assert (np.abs(nc-c.data)/nc<1e-5).all(), (np.abs(nc-c.data).max(), np.where(np.abs(nc-c.data)>1)) + tp = float(report[-1][-1]) + assert tp > rtp * 10**9, (tp, rtp) + + check(64*16, 60*16, 64*16, 1, 1, 0, 35) + check(64*16, 60*16, 64*16, 1, 1, 1, 60) + check(64*16, 60*16, 64*16, 0, 1, 0, 35) + check(64*16, 60*16, 64*16, 0, 1, 1, 60) + check(64*16+5, 60*16+5, 64*16+5, 0, 1, 1, 50) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_fuser.py b/python/jittor/test/test_fuser.py new file mode 100644 index 00000000..29b85822 --- /dev/null +++ b/python/jittor/test/test_fuser.py @@ -0,0 +1,49 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + + +class TestFuser(unittest.TestCase): + def test_wrong_fuse(self): + a = jt.array([1]) + b = jt.random([10,]) + c = (a * b).sum() + (a + 1) + print(c) + + def test_wrong_fuse2(self): + a = jt.array([1]) + b = jt.random([10,]) + c = jt.random([100,]) + bb = a*b + cc = a*c + jt.sync([bb,cc]) + np.testing.assert_allclose(b.data, bb.data) + np.testing.assert_allclose(c.data, cc.data) + + def test_for_fuse(self): + arr = [] + x = 0 + for i in range(100): + arr.append(jt.array(1)) + x += arr[-1] + x.sync() + for i in range(100): + # print(arr[i].debug_msg()) + assert ",0)" not in arr[i].debug_msg() + + def test_array_bc(self): + # a = jt.array(1) + with jt.profile_scope() as rep: + b = jt.array(1).broadcast([10]) + b.sync() + assert len(rep) == 2 + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_gamma_distribution.py b/python/jittor/test/test_gamma_distribution.py new file mode 100644 index 00000000..87e936ce --- /dev/null +++ b/python/jittor/test/test_gamma_distribution.py @@ -0,0 +1,42 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np +import unittest + +try: + import torch + from torch.autograd import Variable + has_autograd = True +except: + has_autograd = False + +@unittest.skipIf(not has_autograd or not jt.compiler.has_cuda, "No autograd or cuda found.") +class TestDigamma(unittest.TestCase): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + + def test_digamma(self): + for i in range(30): + concentration = np.random.uniform(1, 3) + rate = np.random.uniform(1, 2) + j_gamma = jt.distributions.GammaDistribution(concentration, rate) + t_gamma = torch.distributions.gamma.Gamma(torch.tensor([concentration]), torch.tensor([rate])) + samples = t_gamma.sample((30, i+5)) + j_samples = jt.array(samples.detach().numpy()) + np.testing.assert_allclose(j_gamma.log_prob(j_samples).data, t_gamma.log_prob(samples).detach().numpy(), rtol=1e-4, atol=1e-6) + samples = j_gamma.sample((30,i+5)) + t_samples = torch.tensor(samples.numpy()) + np.testing.assert_allclose(j_gamma.log_prob(samples).data, t_gamma.log_prob(t_samples).detach().numpy(), rtol=1e-4, atol=1e-6) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_grad.py b/python/jittor/test/test_grad.py new file mode 100644 index 00000000..de328ef2 --- /dev/null +++ b/python/jittor/test/test_grad.py @@ -0,0 +1,184 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error + +def equal_size(x, y): + return x.dtype == y.dtype and x.shape == y.shape + +def ngrad(func, vars, eps): + out = func(vars) + dout = [] + for i in range(len(vars)): + pvar = vars[i].astype("float64") + if type(pvar)==np.ndarray and pvar.size>1: + grad = [] + var_f = pvar.flatten() + for j in range(len(var_f)): + var = pvar.flatten() + var[j] += eps + vars[i] = var.reshape(pvar.shape) + out2 = func(vars) + grad.append((out2-out)/eps) + dout.append(np.array(grad).reshape(pvar.shape)) + else: + vars[i] = vars[i] + eps + out2 = func(vars) + dout.append((out2-out)/eps) + vars[i] = pvar + return out, dout + +class TestGrad(unittest.TestCase): + def test_grad(self): + x = jt.array([1.0, 2.0]) + y = jt.array([3.0, 4.0]) + z = x*y + dx, dy, dz = jt.grad(z, [x,y,z]) + assert equal_size(dx, x) and equal_size(dy, y), f"{x} {y} {dx} {dy}" + assert (dy.data == x.data).all(), f"{dy.data} {x.data}" + assert (dx.data == y.data).all(), f"{dx.data} {y.data}" + assert (dz.data == 1).all() + + def test_check_float(self): + x = jt.array(1) + y = x*x + expect_error(lambda: jt.grad(y, [x])) + + def test_grad2(self): + def test(n): + x = jt.array(2.0) + y = x + for _ in range(n-1): y = y*x + dx, = jt.grad(y, [x]) + assert dx.data == n*2**(n-1), f"{dx.data} {x.data}, {y.data}" + test(5) + test(6) + test(7) + test(8) + + def test_var_index(self): + x = jt.array(2.0) + y = x-x + dx, = jt.grad(y, [x]) + assert dx.data == 0, dx.data + x = jt.array(2.0) + y = x/x + dx, = jt.grad(x, [y]) + assert dx.data == 0 + + def test_random_graph(self): + @jt.flag_scope(auto_convert_64_to_32=0) + def test(num_vars, num_ops, seed): + np.random.seed(seed) + vars = [] + for _ in range(num_vars): + vars.append(np.random.rand(1)) + def random_func(vars): + np.random.seed(seed+1) + vars = list(vars) + for i in range(num_ops): + v1 = len(vars)-1-np.random.randint(num_vars) + v2 = len(vars)-1-np.random.randint(num_vars) + rop = "+-*/"[np.random.randint(4)] + if (rop == '/' or rop == '-') and v1 is v2: + rop = '+' + vout = eval(f"vars[v1]{rop}vars[v2]") + vars.append(vout) + if type(vars[i]) == jt.Var: + for i in range(len(vars)): + vars[i].name("v"+str(i)) + return vout + np_out, np_dout = ngrad(random_func, vars, 1e-7) + + jt_vars = [ jt.array(v) for v in vars ] + jt_out = random_func(jt_vars) + assert (np.abs(jt_out.data-np_out) < 1e-5).all(), (jt_out.data, np_out) + jt_dout = jt.grad(jt_out, jt_vars) + jt_dout = [ v.data for v in jt_dout ] + for jt_d, np_d in zip(jt_dout, np_dout): + assert abs(jt_d - np_d) < 1e-3, f"{jt_d} {np_d}" + test(1,1,0) + # test(3,3,1) + test(3,6,0) + test(10,100,2) + test(30,100,4) + test(50,100,6) + + def test_top_sort(self): + x = jt.array(2.0) + x.name('x') + y1 = x*x # 2 + y1.name('y1') + y2 = x*x # 2 + y2.name('y2') + y3 = y1*y2 # 4 + y3.name('y3') + y4 = y3*y1 # 6 + y4.name('y4') + y5 = y4*y1 # 8 + y5.name('y5') + y6 = y5*y1 # 10 + y6.name('y6') + vars = [x,y1,y2,y3,y4,y5,y6] + grads = [ g.data for g in jt.grad(y6, vars) ] + dx = grads[0] + assert dx == 10*2**9, f"{grads}" + + def test_int_grad(self): + x = jt.array(2.0) + z = x*x*x*x*x + dx, = jt.grad(z, [x]) + self.assertEqual(dx.data, 5*2**4) + + y1 = jt.int(x) + y2 = jt.float(x) + z = x*x*y1*y1*y2 + expect_error(lambda: jt.grad(z, [y1])) + dx, = jt.grad(z, [x]) + self.assertEqual(dx.data, 48) + + def test_int_enable_grad(self): + a = jt.int([1,2,3]) + a.requires_grad = True + a.start_grad() + + def test_nth_grad(self): + x = jt.array(2.0) + y = x*x*x*x + dx = jt.grad(y, x) + ddx = jt.grad(dx, x) + dddx = jt.grad(ddx, x) + self.assertEqual(y.data, 2**4) + self.assertEqual(dx.data, 4*2**3) + self.assertEqual(ddx.data, 4*3*2**2) + self.assertEqual(dddx.data, 4*3*2*2**1) + + def test_no_grad(self): + a = jt.array(1.0) + with jt.no_grad(): + b = a + for i in range(10): + b = b.clone() + 1 + assert b.data == 11 + jt.clean() + assert jt.liveness_info()["lived_vars"] == 2 + + def test_requires_grad(self): + a = jt.array(2.0) + assert a.requires_grad == True + a.requires_grad = False + assert a.requires_grad == False + assert jt.grad(a**2, a) == 0 + a.requires_grad = True + assert a.requires_grad == True + assert jt.grad(a**2, a) == 4 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_group_conv_tuner.py b/python/jittor/test/test_group_conv_tuner.py new file mode 100644 index 00000000..0539c449 --- /dev/null +++ b/python/jittor/test/test_group_conv_tuner.py @@ -0,0 +1,148 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import numpy as np +from jittor import compile_extern +# TODO: compare with pytorch + +from jittor.test.test_log import find_log_with_re +if jt.has_cuda: + from jittor.compile_extern import cublas_ops, cudnn_ops +else: + cublas_ops = cudnn_ops = None + + +def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=1, dilation=1, groups=1, init_method=None, w_=None): + N,C,H,W = x.shape + Kh, Kw = kernel_size, kernel_size + G = groups + CpG = C // G # channels per group + padding = (padding, padding) + dilation = (dilation, dilation) + stride = (stride, stride) + assert C==in_planes + oc = out_planes + oh = (H+padding[0]*2-Kh*dilation[0]+dilation[0]-1)//stride[0]+1 + ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1 + + if w_ is None: + assert 0 + else: + w = w_ + + xx = x.reindex([N,G,oc//G,CpG,oh,ow,Kh,Kw], [ + 'i0', # Nid + f'i1*{CpG}+i3', # Gid + f'i4*{stride[0]}-{padding[0]}+i6*{dilation[0]}', # Hid+Khid + f'i5*{stride[1]}-{padding[1]}+i7*{dilation[1]}', # Wid+KWid + ]) + # w: [oc, CpG, Kh, Kw] + ww = w.reindex([N, G, oc//G, CpG, oh, ow, Kh, Kw], [ + f'i1*{oc//G}+i2', + 'i3', + 'i6', + 'i7' + ]) + + yy = xx*ww + y = yy.reindex_reduce('add', [N, oc, oh, ow], [ + 'i0', + f'i1*{oc//G}+i2', + 'i4', + 'i5' + ]) + return y + + +def test_nchw(x, w, stride, padding, dilation, groups): + _, in_planes, _, _ = x.shape + out_planes, _, kernel_size, _ = w.shape + return conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=stride, dilation=dilation, groups=groups, w_=w) + + +def check_forward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc): + assert nhwc == 0 + test_func = test_nchw + + # only check cudnn + with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1, + log_v=10, log_vprefix="op.cc=100,conv_tuner=1000" + ) as raw_log: + x = jt.random(xshape) + w = jt.random(wshape) + y = test_func(x, w, stride, padding, dilation, groups) + y.sync() + with jt.flag_scope(use_cuda=0, enable_tuner=0): + cy = test_func(x, w, stride, padding, dilation, groups) + cy.sync() + + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)") + assert len(logs)==1 + assert np.allclose(y.data, cy.data) + + +def check_backward(xshape, wshape, stride, padding, dilation, groups, use_cuda, nhwc): + assert nhwc == 0 + test_func = test_nchw + + # only check cudnn + with jt.log_capture_scope(use_cuda=use_cuda, enable_tuner=1, + log_v=10, log_vprefix="op.cc=100,conv_tuner=1000" + ) as raw_log: + x = jt.random(xshape) + w = jt.random(wshape) + y = test_func(x, w, stride, padding, dilation, groups) + y.sync() + dx, dw = jt.grad(y, [x, w]) + jt.sync([y, dx, dw]) + with jt.flag_scope(use_cuda=0, enable_tuner=0, compile_options={"test":233}): + cy = test_func(x, w, stride, padding, dilation, groups) + cdx, cdw = jt.grad(cy, [x, w]) + jt.sync([cy, cdx, cdw]) + + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: .*conv.*)") + assert len(logs)==3 + assert np.allclose(y.data, cy.data) + assert np.allclose(dw.data, cdw.data, 1e-3), (dw.data, cdw.data, np.abs(dw.data - cdw.data).max()) + assert np.allclose(dx.data, cdx.data, 1e-3), (dx.data, cdx.data, np.abs(dx.data - cdx.data).max()) + + +class TestGroupConvTuner(unittest.TestCase): + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_forward_cuda(self): + for groups in [2, 4, 8]: + check_forward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 1, False) + check_forward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False) + check_forward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False) + + def test_forward(self): + for groups in [2, 4, 8]: + check_forward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 0, False) + check_forward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 0, False) + check_forward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 0, False) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_backward_cuda(self): + for groups in [2, 4, 8]: + check_backward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 1, False) + check_backward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 1, False) + check_backward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 1, False) + + def test_backward(self): + for groups in [2, 4, 8]: + check_backward([10,8,100,100], [8,8//groups,3,3], 1, 0, 1, groups, 0, False) + check_backward([10,8,40,50], [16,8//groups,5,5], 1, 1, 2, groups, 0, False) + check_backward([10,8,40,50], [16,8//groups,4,4], 3, 1, 3, groups, 0, False) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_histc.py b/python/jittor/test/test_histc.py new file mode 100644 index 00000000..b4751a2c --- /dev/null +++ b/python/jittor/test/test_histc.py @@ -0,0 +1,43 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np +import unittest + +try: + import torch + from torch.autograd import Variable + import autograd.numpy as anp + from autograd import jacobian + + has_autograd = True +except: + has_autograd = False + +@unittest.skipIf(not has_autograd, "No autograd found.") +class TestHistc(unittest.TestCase): + def test_histc(self): + for i in range(30): + inputs = np.random.uniform(0,10,(40,40)) + tn, tm = np.random.randn(3, 3).astype('float32'), np.random.randn(3, 3).astype('float32') + x = jt.array(inputs) + t_x = torch.from_numpy(inputs) + if i % 2: + min = max = 0 + else: + min = (inputs.min() + inputs.max()) / 3 + max = (inputs.min() + inputs.max()) / 3 * 2 + joup = jt.histc(x, bins=i+1, min=min, max=max) + toup = torch.histc(t_x, bins=i+1, min=min, max=max) + np.testing.assert_allclose(joup.data, toup.cpu().numpy(), rtol=1e-4, atol=1e-6) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_hook.py b/python/jittor/test/test_hook.py new file mode 100644 index 00000000..2ec5c90f --- /dev/null +++ b/python/jittor/test/test_hook.py @@ -0,0 +1,75 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn + import torchvision +except: + torch = None + tnn = None + torchvision = None + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestHook(unittest.TestCase): + def test_bhook(self): + a = jnn.ReLU() + hooked = False + def hook(mod, grad_input, grad_output): + nonlocal hooked + hooked = True + assert len(grad_input) == 1 + assert len(grad_output) == 1 + np.testing.assert_allclose(grad_input[0].numpy(), [0, 1]) + np.testing.assert_allclose(grad_output[0].numpy(), [1, 1]) + return (jt.array([-1.0, -2.0]), ) + a.register_backward_hook(hook) + x = jt.array([-1.0,2]) + y = a(x) + dx = jt.grad(y, x) + assert hooked + np.testing.assert_allclose(dx.numpy(), [-1.0, -2.0]) + + def test_register_hook(self): + x = jt.array([0.0, 0.0]) + y = x * [1,2] + y.register_hook(lambda g: g*2) + dx = jt.grad(y, x) + np.testing.assert_allclose(dx.data, [2,4]) + + def test_requires_grads_(self): + class Mod(jt.nn.Module): + def execute(self, x): + return x*2 + x = jt.random((100,)) + mod = Mod() + mod.requires_grad_(True) + y = mod(x) + y = y*10 + dx = jt.grad(y, x) + np.testing.assert_allclose(dx.data, 20) + + mod.requires_grad_(False) + y = mod(x) + y = y*10 + dx = jt.grad(y, x) + np.testing.assert_allclose(dx.data, 0) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_image_folder.py b/python/jittor/test/test_image_folder.py new file mode 100644 index 00000000..b2711bd2 --- /dev/null +++ b/python/jittor/test/test_image_folder.py @@ -0,0 +1,79 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import jittor as jt +import unittest +import os +import numpy as np +import random + +pass_this_test = False +msg = "" +mid = 0 +if hasattr(os, "uname") and os.uname()[1] == "jittor-ce": + mid = 1 +try: + # check can we run this test + # test code + jt.dirty_fix_pytorch_runtime_error() + import torchvision.datasets as datasets + import torchvision.transforms as transforms + import torch + + traindir = ["/data1/cjld/imagenet/train/","/home/cjld/imagenet/train/"][mid] + check_num_batch = 5 + assert os.path.isdir(traindir) + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + train_dataset = datasets.ImageFolder( + traindir, + transforms.Compose([ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ])) +except Exception as e: + pass_this_test = True + msg = str(e) + +@unittest.skipIf(pass_this_test, f"can not run imagenet dataset test: {msg}") +class TestImageFolder(unittest.TestCase): + def test_imagenet(self): + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=256, shuffle=False) + + random.seed(0) + tc_data = [] + for i, data in enumerate(train_loader): + tc_data.append(data) + print("get", data[0].shape) + if i==check_num_batch: break + + from jittor.dataset.dataset import ImageFolder + import jittor.transform as transform + + dataset = ImageFolder(traindir).set_attrs(batch_size=256, shuffle=False) + + dataset.set_attrs(transform = transform.Compose([ + transform.RandomCropAndResize(224), + transform.RandomHorizontalFlip(), + transform.ImageNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ])) + + random.seed(0) + + for i, (images, labels) in enumerate(dataset): + print("compare", i) + assert np.allclose(images.numpy(), tc_data[i][0].numpy()) + assert np.allclose(labels.numpy(), tc_data[i][1].numpy()) + if i==check_num_batch: break + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_inception.py b/python/jittor/test/test_inception.py new file mode 100644 index 00000000..714fe26b --- /dev/null +++ b/python/jittor/test/test_inception.py @@ -0,0 +1,127 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Meng-Hao Guo +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn, Module +from jittor.models import inception +import numpy as np +import sys, os +import random +import math +import unittest +from jittor.test.test_reorder_tuner import simple_parser +from jittor.test.test_log import find_log_with_re +from jittor.dataset.mnist import MNIST +import jittor.transform as trans +import time + +skip_this_test = False + +class MnistNet(Module): + def __init__(self): + self.model = inception.inception_v3() + self.layer = nn.Linear(1000,10) + def execute(self, x): + x = self.model(x) + x = self.layer(x) + return x + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestInception(unittest.TestCase): + @classmethod + def setUpClass(self): + # hyper-parameters + self.batch_size = 32 + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.1 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(300)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + self.train_loader.num_workers = 4 + self.train_loader.total_len = self.batch_size * 300 + + # setup random seed + def setup_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + jt.seed(seed) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1, use_stat_allocator=1) + def test_inception(self): + self.setup_seed(1) + loss_list=[] + acc_list=[] + mnist_net = MnistNet() + global prev + prev = time.time() + SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) + + for batch_idx, (data, target) in enumerate(self.train_loader): + + # train step + with jt.log_capture_scope( + log_silent=1, + log_v=1, log_vprefix="op.cc=100,exe=10", + ) as logs: + # breakpoint() + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + SGD.step(loss) + def callback(batch_idx, loss, output, target): + # print train info + global prev + pred = np.argmax(output, axis=1) + acc = np.mean(target==pred) + loss_list.append(loss[0]) + acc_list.append(acc) + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' + .format(0, batch_idx, 300,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev)) + # prev = time.time() + jt.fetch(batch_idx, loss, output, target, callback) + + log_conv = find_log_with_re(logs, + "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") + log_matmul = find_log_with_re(logs, + "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*") + if batch_idx > 2: + assert len(log_conv)==283 and len(log_matmul)==6, (len(log_conv), len(log_matmul)) + + mem_used = jt.flags.stat_allocator_total_alloc_byte \ + -jt.flags.stat_allocator_total_free_byte + # assert mem_used < 4e9, mem_used + # TODO: why bigger? + assert mem_used < 15.6e9, mem_used + # example log: + # Train Epoch: 0 [0/100 (0%)] Loss: 2.352903 Acc: 0.110000 + # Train Epoch: 0 [1/100 (1%)] Loss: 2.840830 Acc: 0.080000 + # Train Epoch: 0 [2/100 (2%)] Loss: 3.473594 Acc: 0.100000 + # Train Epoch: 0 [3/100 (3%)] Loss: 3.131615 Acc: 0.200000 + # Train Epoch: 0 [4/100 (4%)] Loss: 2.524094 Acc: 0.230000 + # Train Epoch: 0 [5/100 (5%)] Loss: 7.780025 Acc: 0.080000 + # Train Epoch: 0 [6/100 (6%)] Loss: 3.890721 Acc: 0.160000 + # Train Epoch: 0 [7/100 (7%)] Loss: 6.370137 Acc: 0.140000 + # Train Epoch: 0 [8/100 (8%)] Loss: 11.390827 Acc: 0.150000 + # Train Epoch: 0 [9/100 (9%)] Loss: 21.598564 Acc: 0.080000 + # Train Epoch: 0 [10/100 (10%)] Loss: 23.369165 Acc: 0.130000 + # Train Epoch: 0 [20/100 (20%)] Loss: 4.804510 Acc: 0.100000 + # Train Epoch: 0 [30/100 (30%)] Loss: 3.393924 Acc: 0.110000 + # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000 + # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000 + + assert jt.core.number_of_lived_vars() < 50000, jt.core.number_of_lived_vars() + + jt.sync_all(True) + assert np.mean(loss_list[-20:])<1 + assert np.mean(acc_list[-20:])>0.5 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_index_op.py b/python/jittor/test/test_index_op.py new file mode 100644 index 00000000..e14c3c6b --- /dev/null +++ b/python/jittor/test/test_index_op.py @@ -0,0 +1,65 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +class TestIndexOp(unittest.TestCase): + def test(self): + assert (jt.index([2,2], 0).data==[[0,0],[1,1]]).all() + assert (jt.index([2,2], 1).data==[[0,1],[0,1]]).all() + a = jt.index([2,2], 0) + b = jt.index([2,2], 1) + c = a+b + assert (c.data==[[0,1],[1,2]]).all(), c.data + + def test_multioutput(self): + a,b = jt.index([2,2]) + jt.sync([a,b]) + assert (a.data==[[0,0],[1,1]]).all() + assert (b.data==[[0,1],[0,1]]).all(), b.data + + def test_multioutput2(self): + a,b = jt.index([3,3]) + assert (a.data==[[0,0,0],[1,1,1],[2,2,2]]).all() + assert (b.data==[[0,1,2],[0,1,2],[0,1,2]]).all(), b.data + a,b = jt.index([3,3]) + c = a+b + assert (c.data==[[0,1,2],[1,2,3],[2,3,4]]).all(), c.data + + def test_multioutput3(self): + a,b = jt.index([3,3]) + del a + assert (b.data==[[0,1,2],[0,1,2],[0,1,2]]).all(), b.data + + def test_vary_shape_dep(self): + a, = jt.where([1,0,1]) + b, = a.index_var() + assert (b.data==[0,1]).all() + + def test_vary_shape_dep2(self): + a = jt.array([[1,2,3],[4,5,6],[7,8,9]]) + index0, = jt.where(a.sum(1)>7) # [1,2] + index0 = index0.broadcast([1,3], dims=[1]) # [[1,1,1],[2,2,2]] + index1 = index0.index_var(1) # [[0,1,2],[0,1,2]] + b = a.reindex_var([index0, index1]) + assert (b.data==[[4,5,6],[7,8,9]]).all() + assert (index0.data==[[1,1,1],[2,2,2]]).all() + assert (index1.data==[[0,1,2],[0,1,2]]).all() + + def test_doc(self): + assert "Index Operator" in jt.index.__doc__ + + def test_wrong_fuse(self): + a,b = jt.index([10,10]) + c = jt.zeros([10,10]) + c = c.reindex([b+1,a]) + x = b.clone() + jt.sync([c, x]) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_init.py b/python/jittor/test/test_init.py new file mode 100644 index 00000000..875565a3 --- /dev/null +++ b/python/jittor/test/test_init.py @@ -0,0 +1,96 @@ +# *************************************************************** +# Copyright (c) Jittor 2020, Author: +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import unittest +import numpy as np +from jittor import models + +pass_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torchvision +except Exception as e: + pass_this_test = True + +def get_error(a, b): + return np.abs(a-b) / max(np.abs(a), np.abs(b), 1e-5) , np.abs(a-b) + +def check(jt_mod, torch_mod, rtol=1e-2, atol=1e-5, mean_atol=1e-5): + pa = [ p for p in jt_mod.parameters() if not p.is_stop_grad() ] + pb = list(torch_mod.parameters()) + assert len(pa) == len(pb) + error_count = 0 + for a,b in zip(pa, pb): + assert a.shape == list(b.shape), (a.shape, b.shape, a.name()) + stda, meana = np.std(a.numpy()), np.mean(a.numpy()) + stdb, meanb = np.std(b.detach().numpy()), np.mean(b.detach().numpy()) + + r_err, a_err = get_error(stda, stdb) + if r_err > rtol and a_err > atol: + error_count += 1 + print("compare std error", stda, stdb, r_err, a_err, a.name(), a.shape) + + r_err, a_err = get_error(meana, meanb) + if r_err > rtol and a_err > mean_atol: + error_count += 1 + print("compare mean error", meana, meanb, r_err, a_err, a.name(), a.shape) + assert error_count == 0 + +@unittest.skipIf(pass_this_test, f"pass init check, no torch found") +class TestInit(unittest.TestCase): + @classmethod + def setUpClass(self): + jt.seed(0) + np.random.seed(0) + torch.manual_seed(0) + + def test_conv(self): + check(jt.nn.Conv(64, 256, 3), torch.nn.Conv2d(64, 256, 3), rtol=1e-1, mean_atol=1e-2) + + def test_resnet(self): + check(models.resnet152(), torchvision.models.resnet152(), rtol=5e-2, mean_atol=1e-2) + +from jittor import init +from jittor import nn + +class TestInitFunc(unittest.TestCase): + def test_eye(self): + a = init.eye(2, "float32") + np.testing.assert_allclose(a.data, [[1,0],[0,1]]) + a = init.eye((2,3), "float32") + np.testing.assert_allclose(a.data, [[1,0,0],[0,1,0]]) + + linear = nn.Linear(2,2) + init.eye_(linear.weight) + np.testing.assert_allclose(linear.weight.data, [[1,0],[0,1]]) + + def test_constant(self): + a = init.constant(2, "float32") + np.testing.assert_allclose(a.data, [0,0]) + a = init.constant((2,3), value=1.) + np.testing.assert_allclose(a.data, [[1,1,1],[1,1,1]]) + + linear = nn.Linear(2,2) + init.constant_(linear.weight) + np.testing.assert_allclose(linear.weight.data, [[0,0],[0,0]]) + + def test_uniform(self): + a = init.uniform(5, "float32") + assert ((a>0) & (a<1)).all() + a = init.uniform((2,3), low=-1, high=1) + assert ((a>-1) & (a<1)).all() + + linear = nn.Linear(2,2) + init.uniform_(linear.weight) + assert (linear.weight > 0).all() + linear.weight.uniform_() + assert (linear.weight > 0).all() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_interpolation.py b/python/jittor/test/test_interpolation.py new file mode 100644 index 00000000..87e52c01 --- /dev/null +++ b/python/jittor/test/test_interpolation.py @@ -0,0 +1,34 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn +import numpy as np +import unittest + +try: + import torch + has_torch = True +except: + has_torch = False + +@unittest.skipIf(not has_torch, "No pytorch installation found.") +class TestInterpolation(unittest.TestCase): + def test_interpolation_area(self): + img = np.random.uniform(0, 1, (1, 3, 24, 10)) + output_shape = (12, 5) + jimg = jt.array(img) + timg = torch.from_numpy(img) + joutput = nn.interpolate(jimg, output_shape, mode="area") + toutput = torch.nn.functional.interpolate(timg, output_shape, mode="area") + np.testing.assert_allclose(joutput.numpy(), toutput.numpy(), rtol=1e-7, atol=1e-7) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_jit_tests.py b/python/jittor/test/test_jit_tests.py new file mode 100644 index 00000000..a366c6c3 --- /dev/null +++ b/python/jittor/test/test_jit_tests.py @@ -0,0 +1,31 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor import LOG + +def test(name): + doc = eval(f"jt.tests.{name}.__doc__") + doc = doc[doc.find("From"):].strip() + LOG.i(f"Run test {name} {doc}") + exec(f"jt.tests.{name}()") + +tests = [ name for name in dir(jt.tests) if not name.startswith("__") ] +src = "class TestJitTests(unittest.TestCase):\n" +for name in tests: + doc = eval(f"jt.tests.{name}.__doc__") + doc = doc[doc.find("From"):].strip() + src += f""" + def test_{name}(self): + test("{name}") + """ + +LOG.vvv("eval src\n"+src) +exec(src) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_jtune.py b/python/jittor/test/test_jtune.py new file mode 100644 index 00000000..b9534525 --- /dev/null +++ b/python/jittor/test/test_jtune.py @@ -0,0 +1,54 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import re +import sys + +class TestJtune(unittest.TestCase): + @classmethod + def setUpClass(self): + n, m, k = 2, 6, 16 + a = jt.random((n, m, 1)) + b = jt.random((1, m, k)) + jt.fetch_sync([a,b]) + with jt.profile_scope( + compile_options = {"jtune":1} + ) as rep: + c = (a*b).sum(1) + c.sync() + assert len(rep) == 2 + self.fname = rep[1][1] + self.jtune_path = os.path.join(jt.flags.jittor_path, "utils/jtune.py") + + def run_cmd(self, cmd): + cmd = f"warmup=0 rerun=0 {sys.executable} {self.jtune_path} {self.fname} {cmd}" + return jt.compiler.run_cmd(cmd) + + def test_run_so(self): + res = self.run_cmd("run_so").splitlines() + assert res[0]=="Enter fake_main entry.", res + assert res[1]==" Count TotalTime AvgTime MinTime MaxTime Input Output Compute", res + nums = res[2].split() + assert nums[0]=="1", nums + + def test_cc_to_so(self): + self.run_cmd("cc_to_so") + + def test_cc_to_s(self): + self.run_cmd("cc_to_s") + sname = self.fname[:-2] + 's' + with open(sname) as f: + src = f.read() + fma_ins = re.findall("fma.*", src) + assert len(fma_ins)>=4, f"fma instructions should be used for matmul. {fma_ins}" + self.run_cmd("s_to_so") + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_knn.py b/python/jittor/test/test_knn.py new file mode 100644 index 00000000..a274eaf0 --- /dev/null +++ b/python/jittor/test/test_knn.py @@ -0,0 +1,56 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Zheng-Ning Liu +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + + +import unittest +import jittor as jt +import numpy as np + +def topk(input, k, dim=None, largest=True, sorted=True): + if dim is None: + dim = -1 + if dim < 0: + dim += input.ndim + + transpose_dims = [i for i in range(input.ndim)] + transpose_dims[0] = dim + transpose_dims[dim] = 0 + input = input.transpose(transpose_dims) + index, values = jt.argsort(input, dim=0, descending=largest) + indices = index[:k] + values = values[:k] + indices = indices.transpose(transpose_dims) + values = values.transpose(transpose_dims) + return [values, indices] + +def knn(x, k): + inner = -2 * jt.nn.bmm(x.transpose(0, 2, 1), x) + xx = jt.sum(x ** 2, dim=1, keepdims=True) + distance = -xx - inner - xx.transpose(0, 2, 1) + return topk(distance, k=k, dim=-1) + +class TestKnnOp(unittest.TestCase): + def test_knn(self): + jt_a = jt.randn(32,512,3) + a1, b1 = jt.misc.knn(jt_a, jt_a, 16) + a2, b2 = knn(jt_a.transpose(0,2,1), 16) + a2 *= -1 + np.testing.assert_allclose(a1.data, a2.data, atol=1e-4) + + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + jt_a = jt.randn(32,512,3) + a1, b1 = jt.misc.knn(jt_a, jt_a, 16) + a2, b2 = knn(jt_a.transpose(0,2,1), 16) + a2 *= -1 + np.testing.assert_allclose(a1.data, a2.data, atol=1e-4) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_lazy_execution.py b/python/jittor/test/test_lazy_execution.py new file mode 100644 index 00000000..b92aada9 --- /dev/null +++ b/python/jittor/test/test_lazy_execution.py @@ -0,0 +1,49 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Meng-Hao Guo +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import unittest +import sys, os +from subprocess import getoutput + +class TestLazyExecution(unittest.TestCase): + @unittest.skipIf(not jt.has_cuda, "No cuda found") + def test_lazy_execution(self): + code = """ +import jittor as jt +jt.flags.use_cuda = 1 + +a = jt.zeros(1) +b = jt.code([1], a.dtype, [a], +cuda_header=''' +#include +''', +cuda_src=''' +__global__ void kernel(float32* a, float32* b) { + b[0] = a[0]; + assert(a[0] == 1); +} +kernel<<<1,1>>>(in0_p, out0_p); +''') +c = a+b +print(c) +""" + fpath = os.path.join(jt.flags.cache_path, "lazy_error.py") + with open(fpath, 'w') as f: + f.write(code) + res = getoutput(f"{sys.executable} {fpath}") + assert 'print(c)' in res + res = getoutput(f"lazy_execution=0 {sys.executable} {fpath}") + assert "''')" in res + + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_linalg.py b/python/jittor/test/test_linalg.py new file mode 100644 index 00000000..5ccc4986 --- /dev/null +++ b/python/jittor/test/test_linalg.py @@ -0,0 +1,317 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import numpy as np +import unittest + +try: + import torch + from torch.autograd import Variable + import autograd.numpy as anp + from autograd import jacobian + + has_autograd = True +except: + has_autograd = False + + +@unittest.skipIf(not has_autograd, "No autograd found.") +class TestLinalgOp(unittest.TestCase): + def test_svd(self): + def check_svd(a): + u, s, v = anp.linalg.svd(a, full_matrices=0) + return u, s, v + + def check_u(a): + u, s, v = anp.linalg.svd(a, full_matrices=0) + return u + + def check_s(a): + u, s, v = anp.linalg.svd(a, full_matrices=0) + return s + + def check_v(a): + u, s, v = anp.linalg.svd(a, full_matrices=0) + return v + + for i in range(50): + # not for full-matrices! + a = jt.random((2, 2, 5, 4)) + c_a = anp.array(a.data) + u, s, v = jt.linalg.svd(a) + tu, ts, tv = check_svd(c_a) + assert np.allclose(tu, u.data) + assert np.allclose(ts, s.data) + assert np.allclose(tv, v.data) + ju = jt.grad(u, a) + js = jt.grad(s, a) + jv = jt.grad(v, a) + grad_u = jacobian(check_u) + gu = grad_u(c_a) + gu = np.sum(gu, 4) + gu = np.sum(gu, 4) + gu = np.sum(gu, 2) + gu = np.sum(gu, 2) + grad_s = jacobian(check_s) + gs = grad_s(c_a) + gs = np.sum(gs, 4) + gs = np.sum(gs, 2) + gs = np.sum(gs, 2) + grad_v = jacobian(check_v) + gv = grad_v(c_a) + gv = np.sum(gv, 4) + gv = np.sum(gv, 4) + gv = np.sum(gv, 2) + gv = np.sum(gv, 2) + try: + assert np.allclose(ju.data, gu, atol=1e-5) + except AssertionError: + print(ju.data) + print(gu) + try: + assert np.allclose(js.data, gs, atol=1e-5) + except AssertionError: + print(js.data) + print(gs) + try: + assert np.allclose(jv.data, gv, atol=1e-5) + except AssertionError: + print(jv.data) + print(gv) + + def test_eigh(self): + def check_eigh(a, UPLO='L'): + w, v = anp.linalg.eigh(a, UPLO) + return w, v + + def check_w(a, UPLO='L'): + w, v = anp.linalg.eigh(a, UPLO) + return w + + def check_v(a, UPLO='L'): + w, v = anp.linalg.eigh(a, UPLO) + return v + + for i in range(50): + a = jt.random((2, 2, 3, 3)) + c_a = a.data + w, v = jt.linalg.eigh(a) + tw, tv = check_eigh(c_a) + assert np.allclose(w.data, tw) + assert np.allclose(v.data, tv) + jw = jt.grad(w, a) + jv = jt.grad(v, a) + check_gw = jacobian(check_w) + check_gv = jacobian(check_v) + gw = check_gw(c_a) + gw = np.sum(gw, 4) + gw = np.sum(gw, 2) + gw = np.sum(gw, 2) + assert np.allclose(gw, jw.data, rtol=1, atol=5e-8) + gv = check_gv(c_a) + gv = np.sum(gv, 4) + gv = np.sum(gv, 4) + gv = np.sum(gv, 2) + gv = np.sum(gv, 2) + assert np.allclose(gv, jv.data, rtol=1, atol=5e-8) + + def test_pinv(self): + def check_pinv(a): + w = anp.linalg.pinv(a) + return w + + for i in range(50): + x = jt.random((2, 2, 4, 3)) + c_a = x.data + mx = jt.linalg.pinv(x) + tx = check_pinv(c_a) + np.allclose(mx.data, tx) + jx = jt.grad(mx, x) + check_grad = jacobian(check_pinv) + gx = check_grad(c_a) + np.allclose(gx, jx.data) + + def test_inv(self): + def check_inv(a): + w = anp.linalg.inv(a) + return w + + for i in range(50): + tn = np.random.randn(4, 4).astype('float32') * 5 + while np.allclose(np.linalg.det(tn), 0): + tn = np.random.randn((4, 4)).astype('float32') * 5 + x = jt.array(tn) + x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + c_a = x.data + mx = jt.linalg.inv(x) + tx = check_inv(c_a) + np.allclose(mx.data, tx) + jx = jt.grad(mx, x) + check_grad = jacobian(check_inv) + gx = check_grad(c_a) + np.allclose(gx, jx.data) + + def test_slogdet(self): + def check_ans(a): + s, w = anp.linalg.slogdet(a) + return s, w + + def check_slogdet(a): + s, w = anp.linalg.slogdet(a) + return w + + for i in range(50): + tn = np.random.randn(4, 4).astype('float32') * 10 + while np.allclose(np.linalg.det(tn), 0): + tn = np.random.randn((4, 4)).astype('float32') * 10 + x = jt.array(tn) + x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + s = list(x.shape) + det_s = s[:-2] + if len(det_s) == 0: + det_s.append(1) + sign, mx = jt.linalg.slogdet(x) + ts, ta = check_ans(x.data) + assert np.allclose(sign.data, ts) + assert np.allclose(mx.data, ta) + jx = jt.grad(mx, x) + check_sgrad = jacobian(check_slogdet) + gx = check_sgrad(x.data) + gx = np.sum(gx, 2) + gx = np.sum(gx, 2) + assert np.allclose(gx, jx.data) + + def test_cholesky(self): + def check_cholesky(a): + L = anp.linalg.cholesky(a) + return L + + for i in range(50): + x = jt.array(np.diag((np.random.rand(3) + 1) * 2)) + x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + tx = x.data + L = jt.linalg.cholesky(x) + tL = check_cholesky(tx) + assert np.allclose(tL, L.data) + jx = jt.grad(L, x) + check_grad = jacobian(check_cholesky) + gx = check_grad(tx) + gx = np.sum(gx, 0) + gx = np.sum(gx, 0) + gx = np.sum(gx, 0) + gx = np.sum(gx, 0) + assert np.allclose(jx.data, gx) + + def test_solve(self): + def check_solve(a, b): + ans = anp.linalg.solve(a, b) + return ans + + for i in range(50): + a = jt.random((2, 2, 3, 3)) + b = jt.random((2, 2, 3)) + ans = jt.linalg.solve(a, b) + ta = check_solve(a.data, b.data) + assert np.allclose(ans.data, ta) + jx = jt.grad(ans, a) + check_sgrad = jacobian(check_solve) + gx = check_sgrad(a.data, b.data) + gx = np.sum(gx, 0) + gx = np.sum(gx, 0) + gx = np.sum(gx, 0) + try: + assert np.allclose(gx, jx.data, rtol=1) + except AssertionError: + print(gx) + print(jx.data) + + def test_det(self): + def check_det(a): + de = anp.linalg.det(a) + return de + + for i in range(50): + tn = np.random.randn(3, 3).astype('float32') * 5 + while np.allclose(np.linalg.det(tn), 0): + tn = np.random.randn((3, 3)).astype('float32') * 5 + x = jt.array(tn) + x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + s = list(x.shape) + x_s = s[:-2] + if len(s) == 2: + x_s.append(1) + det = jt.linalg.det(x) + ta = check_det(x.data) + assert np.allclose(det.data, ta) + jx = jt.grad(det, x) + check_sgrad = jacobian(check_det) + gx = check_sgrad(x.data) + gx = np.sum(gx, 2) + gx = np.sum(gx, 2) + assert np.allclose(gx, jx.data) + + def test_qr(self): + for i in range(50): + tn = np.random.randn(3, 3).astype('float32') + while np.allclose(np.linalg.det(tn), 0): + tn = np.random.randn((3, 3)).astype('float32') + x = jt.array(tn) + # x = x.reindex([2, 2, x.shape[0], x.shape[1]], ["i2", "i3"]) + t_x = torch.from_numpy(tn) + t_x = Variable(t_x, requires_grad=True) + jq, jr = jt.linalg.qr(x) + tq, tr = torch.qr(t_x) + try: + assert np.allclose(jq.data, tq.detach().numpy(), rtol=1e-4, atol=1e-6) + assert np.allclose(jr.data, tr.detach().numpy(), rtol=1e-4, atol=1e-6) + except AssertionError: + print("ours' qr results:") + print(jq) + print(jr) + print("pytorch's qr results:") + print(tq) + print(tr) + gq = jt.grad(jq, x).data + gr = jt.grad(jr, x).data + tgq = torch.autograd.grad(tq, t_x, torch.ones_like(tq), retain_graph=True) + tgr = torch.autograd.grad(tr, t_x, torch.ones_like(tr), retain_graph=True) + try: + assert np.allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6) + assert np.allclose(gr, tgr[0].numpy(), rtol=1e-4, atol=1e-6) + except AssertionError: + print("ours' qr grad results:") + print(gq) + print(gr) + print("pytorch's qr grad result") + print(tgq[0]) + print(tgr[0]) + +@unittest.skipIf(not jt.has_cuda, "No cuda found.") +class TestBUG4_2Op(unittest.TestCase): + def test(self): + jt.flags.use_cuda = 1 + x = jt.randn(32, 50, 2) + y = jt.rand(32, 1, 2) + + # MLE + mean = x.mean(dim=1, keepdims=True)# [batch_size, 1, n_feature] + mup = jt.transpose((x - mean), [0, 2, 1])# [batch_size, n_feature, n_particles] + cov = jt.nn.bmm_transpose(mup, mup) / (50 - 1)# [batch_size, n_feature, n_feature] + prec = jt.linalg.inv(cov)# [batch_size, n_feature, n_feature] + # print(prec) + # log_prob + dst = y - mean + log_prob = -1/2 * jt.bmm(dst, jt.bmm_transpose(prec, dst)) + grad = jt.grad(log_prob, x) + grad.sync() + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_load_pth.py b/python/jittor/test/test_load_pth.py new file mode 100644 index 00000000..a70bd9a8 --- /dev/null +++ b/python/jittor/test/test_load_pth.py @@ -0,0 +1,63 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn +from jittor.models import resnet +import numpy as np +import sys, os +import random +import math +import unittest +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +model_test = os.environ.get("model_test", "") == "1" +skip_model_test = not model_test + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torchvision as tv +except: + skip_model_test = True + +@unittest.skipIf(skip_model_test, "Skip model test") +class TestLoadPth(unittest.TestCase): + def test_load_pth(self): + # TODO: load torch model params + # define input img + img = np.random.random((1,3,224,224)).astype("float32") + jt_img = jt.array(img) + torch_img = torch.Tensor(img) + # define pytorch and jittor pretrained model + torch_model = tv.models.resnet18(True) + + jt_model = resnet.Resnet18() + jt_model.load_parameters(torch_model.state_dict()) + # todo: model.train() model.eval() + + # output + jt_out = jt_model(jt_img) + torch_out = torch_model(torch_img) + print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy()))) + assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3 + + pth_name = os.path.join(jt.flags.cache_path, "x.pth") + torch.save(torch_model.state_dict(), pth_name) + jt_model.load(pth_name) + + # output + jt_out = jt_model(jt_img) + # torch_out = torch_model(torch_img) + print(np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy()))) + assert np.max(np.abs(jt_out.fetch_sync() - torch_out.detach().numpy())) < 1e-3 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_lock.py b/python/jittor/test/test_lock.py new file mode 100644 index 00000000..016f043a --- /dev/null +++ b/python/jittor/test/test_lock.py @@ -0,0 +1,28 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os, sys +import jittor as jt +import jittor_utils as jit_utils + +class TestLock(unittest.TestCase): + def test(self): + if os.environ.get('lock_full_test', '0') == '1': + cache_path = os.path.join(jit_utils.home(), ".cache", "jittor", "lock") + assert os.system(f"rm -rf {cache_path}") == 0 + cmd = f"cache_name=lock {sys.executable} -m jittor.test.test_example" + else: + cmd = f"{sys.executable} -m jittor.test.test_example" + print("run cmd twice", cmd) + assert os.system(f"{cmd} & {cmd} & wait %1 && wait %2") == 0 + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_log.py b/python/jittor/test/test_log.py new file mode 100644 index 00000000..b207302c --- /dev/null +++ b/python/jittor/test/test_log.py @@ -0,0 +1,55 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os +import re +import jittor as jt +from jittor import LOG + +def find_log_with_re(logs, pattern=None, **args): + if pattern: + pattern = re.compile(pattern) + flogs = [] + for log in logs: + for arg in args: + if log[arg] != args[arg]: + break + else: + if pattern: + res = re.findall(pattern, log["msg"]) + if len(res): + flogs.append(res[0]) + else: + flogs.append(log["msg"]) + return flogs + +class TestLog(unittest.TestCase): + def test_log_capture(self): + with jt.log_capture_scope(log_v=1000, log_vprefix="") as logs: + LOG.v("1") + LOG.vv("2") + LOG.i("3") + LOG.w("4") + LOG.e("5") + a = jt.zeros([10]) + a.sync() + # TODO: why need manually delete this variable? + del a + logs2 = LOG.log_capture_read() + assert len(logs2)==0 + + for i in range(5): + assert logs[i]['msg'] == str(i+1) + assert logs[i]['level'] == 'iiiwe'[i] + assert logs[i]['name'] == 'test_log.py' + finished_log = [ l["msg"] for l in logs + if l["name"]=="executor.cc" and "return vars:" in l["msg"]] + assert len(finished_log)==1 and "[10,]" in finished_log[0] + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_longest_dis_fuse.py b/python/jittor/test/test_longest_dis_fuse.py new file mode 100644 index 00000000..d18d2c5c --- /dev/null +++ b/python/jittor/test/test_longest_dis_fuse.py @@ -0,0 +1,69 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import sys +import os +import jittor as jt +import unittest +import time +import numpy as np + +def get_init_var(shape, dtype): + return jt.random(shape, dtype) + +def pool(x, size, op, padding, stride = 1): # TODO: stride, padding + N,C,H,W = x.shape + h = (H+padding*2-size)//stride+1 + w = (W+padding*2-size)//stride+1 + xx = x.reindex([N,C,h,w,size,size], [ + "i0", # Nid + "i1", # Cid + f"i2*{stride}-{padding}+i4", # Hid + f"i3*{stride}-{padding}+i5", # Wid + ]) + return xx.reindex_reduce(op, [N,C,h,w], [ + "i0", # Nid + "i1", # Cid + "i2", # Hid + "i3", # Wid + ]) + +def relu(x): return jt.maximum(x, jt.float32(0)) + +def resnet_fake(): + from jittor import nn + net = nn.Sequential( + nn.Conv(3, 64, 7, 2, 3), + nn.BatchNorm(64), + nn.ReLU(), + nn.Pool(3, 2, 1) + ) + return net + +class TestLongestDisFuse(unittest.TestCase): + + def test_longest_dis_fuse(self): + x = jt.array(np.random.rand(1,3,224,224).astype(np.float32)) + net = resnet_fake() + loss = jt.sum(net(x)) + ps = net.parameters() + gs = jt.grad(loss, ps) + jt.sync(gs) + # assert not alloc big tensor + g = jt.dump_all_graphs() + for s in g.nodes_info: + if not s.startswith("Var"): + continue + shape = s.split("[")[1].split("]")[0].split(",") + ptr = s.split("(")[1].split(")")[0].split(",")[-1] + if ptr != '0' and ptr != '0x0': + assert len(shape)<=5, s + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_loss.py b/python/jittor/test/test_loss.py new file mode 100644 index 00000000..168296f9 --- /dev/null +++ b/python/jittor/test/test_loss.py @@ -0,0 +1,174 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestLoss(unittest.TestCase): + def test_l1_loss(self): + jt_loss=jnn.L1Loss() + tc_loss=tnn.L1Loss() + output=np.random.randn(10,100).astype(np.float32) + target=np.random.randn(10,100).astype(np.float32) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_mse_loss(self): + jt_loss=jnn.MSELoss() + tc_loss=tnn.MSELoss() + output=np.random.randn(10,100).astype(np.float32) + target=np.random.randn(10,100).astype(np.float32) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_nll_loss(self): + tc_loss = tnn.functional.nll_loss + jt_loss = jnn.nll_loss + output=np.random.randn(10,10).astype(np.float32) + target=np.random.randint(10, size=(10)) + jt_y=jt_loss(jt.array(output), jt.array(target),reduction='mean') + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),reduction='mean') + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + output=np.random.randn(10,10).astype(np.float32) + target=np.random.randint(10, size=(10)) + weight=np.random.randn(10,).astype(np.float32) + jt_y=jt_loss(jt.array(output), jt.array(target),jt.array(weight),reduction='mean') + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),torch.from_numpy(weight),reduction='mean') + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_cross_entropy_loss(self): + jt_loss=jnn.CrossEntropyLoss() + tc_loss=tnn.CrossEntropyLoss() + output=np.random.randn(10,10).astype(np.float32) + target=np.random.randint(10, size=(10)) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_cross_entropy_loss_v2(self): + B = 100 + C = 5 + for shape in [[100,1],[],[100,20]]: + s1 = [B,C]+shape + s2 = [B]+shape + a = np.random.randn(*s1).astype(np.float32) + b = np.random.randint(0,C,size=s2).astype(np.int32) + weight = np.random.randn(C).astype(np.float32) + + for r in ['mean','sum','none']: + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64)),weight=torch.tensor(weight),reduction=r) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b),weight=jt.array(weight),reduction=r) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + for r in ['mean','sum','none']: + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64)),reduction=r) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b),reduction=r) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64))) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b)) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64)),weight=torch.tensor(weight)) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b),weight=jt.array(weight)) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + for r in ['mean','sum','none']: + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64)),weight=torch.tensor(weight),reduction=r,ignore_index=C//2) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b),weight=jt.array(weight),reduction=r,ignore_index=C//2) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + for r in ['mean','sum','none']: + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64)),reduction=r,ignore_index=C//2) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b),reduction=r,ignore_index=C//2) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64)),ignore_index=C//2) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b),ignore_index=C//2) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + r1 = torch.nn.functional.cross_entropy(torch.tensor(a),torch.tensor(b.astype(np.int64)),weight=torch.tensor(weight),ignore_index=C//2) + r2 = jnn.cross_entropy_loss(jt.array(a),jt.array(b),weight=jt.array(weight),ignore_index=C//2) + np.testing.assert_allclose(r1.numpy(),r2.numpy(),rtol=1e-3, atol=1e-3) + + + def test_cross_entropy_ignore_index(self): + ignore_index = np.random.randint(0, 10) + jt_loss = jnn.CrossEntropyLoss(ignore_index=ignore_index) + tc_loss = tnn.CrossEntropyLoss(ignore_index=ignore_index) + output = np.random.rand(100, 10).astype(np.float32) + target = np.random.randint(10, size=(100)) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_cross_entropy_weight(self): + weight = np.random.rand(10).astype('float32') + jt_loss = jnn.CrossEntropyLoss(weight=jt.array(weight)) + tc_loss = tnn.CrossEntropyLoss(weight=torch.from_numpy(weight)) + output = np.random.rand(100, 10).astype(np.float32) + target = np.random.randint(10, size=(100)) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_cross_entropy_weight_ignore(self): + weight = np.random.rand(4).astype('float32') + jt_loss = jnn.CrossEntropyLoss(weight=jt.array(weight), ignore_index=1) + tc_loss = tnn.CrossEntropyLoss(weight=torch.from_numpy(weight), ignore_index=1) + output = np.random.rand(3, 4, 2,2).astype(np.float32) + target = np.random.randint(4, size=(3, 2,2)) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + np.testing.assert_allclose(jt_y.numpy(), tc_y.numpy()) + + + def test_bce_loss(self): + jt_loss=jnn.BCELoss() + tc_loss=tnn.BCELoss() + jt_sig = jnn.Sigmoid() + tc_sig = tnn.Sigmoid() + output=np.random.randn(100).astype(np.float32) + target=np.random.randint(2, size=(100)).astype(np.float32) + jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target)) + tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + weight=np.random.randn(100).astype(np.float32) + jt_loss=jnn.BCELoss(weight=jt.array(weight), size_average=False) + tc_loss=tnn.BCELoss(weight=torch.Tensor(weight), size_average=False) + jt_y=jt_loss(jt_sig(jt.array(output)), jt.array(target)) + tc_y=tc_loss(tc_sig(torch.from_numpy(output)), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_bce_with_logits_loss(self): + jt_loss=jnn.BCEWithLogitsLoss() + tc_loss=tnn.BCEWithLogitsLoss() + output=np.random.randn(100).astype(np.float32) + target=np.random.randint(2, size=(100)).astype(np.float32) + jt_y=jt_loss(jt.array(output), jt.array(target)) + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_loss3d.py b/python/jittor/test/test_loss3d.py new file mode 100644 index 00000000..3374b069 --- /dev/null +++ b/python/jittor/test/test_loss3d.py @@ -0,0 +1,101 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Zheng-Ning Liu +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import unittest +import numpy as np + +try: + import torch + from emd import earth_mover_distance as TEMD +except: + skip_this_test = True + +import jittor as jt +from jittor.loss3d import chamfer_loss +from jittor.loss3d import earth_mover_distance + + +class TestLoss3d(unittest.TestCase): + def test_chamfer(self): + def test(): + pc1 = np.random.randn(10, 100, 3).astype(np.float32) + pc2 = np.random.randn(10, 100, 3).astype(np.float32) + + Jpc1 = jt.array(pc1) + Jpc2 = jt.array(pc2) + Jcf = chamfer_loss(Jpc1, Jpc2, dims='BNC') + + ppc1 = np.repeat(pc1[:, :, None, :], 100, axis=2) + ppc2 = np.repeat(pc2[:, None, :, :], 100, axis=1) + ncf = np.sqrt(((ppc1 - ppc2) ** 2).sum(axis=-1)).min(axis=-1) + ncf = ncf.mean() + + self.assertTrue(np.allclose(ncf, Jcf.item())) + + test() + + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + test() + + def test_chamfer_dims(self): + def test(): + pc1 = np.random.randn(10, 100, 3).astype(np.float32) + pc2 = np.random.randn(10, 100, 3).astype(np.float32) + + Jpc1 = jt.array(pc1.transpose([0, 2, 1])) + Jpc2 = jt.array(pc2.transpose([0, 2, 1])) + Jcf = chamfer_loss(Jpc1, Jpc2, dims='BCN') + + ppc1 = np.repeat(pc1[:, :, None, :], 100, axis=2) + ppc2 = np.repeat(pc2[:, None, :, :], 100, axis=1) + ncf = np.sqrt(((ppc1 - ppc2) ** 2).sum(axis=-1)).min(axis=-1) + ncf = ncf.mean() + + self.assertTrue(np.allclose(ncf, Jcf.item())) + + test() + + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + test() + + @unittest.skipIf(skip_this_test, "No Pyorch_EMD found") + def test_emd_torch(self): + if jt.has_cuda: + jt.flags.use_cuda = True + + pc1 = np.random.randn(10, 100, 3).astype(np.float32) + pc2 = np.random.randn(10, 50, 3).astype(np.float32) + + Tpc1 = torch.from_numpy(pc1).cuda() + Tpc2 = torch.from_numpy(pc2).cuda() + Tpc1.requires_grad = True + Tpc2.requires_grad = True + Temdcost = TEMD(Tpc1, Tpc2, transpose=False) + Temd = Temdcost.mean() + + Jpc1 = jt.array(pc1) + Jpc2 = jt.array(pc2) + Jemd = earth_mover_distance(Jpc1, Jpc2, dims='BNC') + + Temd.backward() + Tgrad1 = Tpc1.grad.cpu().numpy() + Tgrad2 = Tpc2.grad.cpu().numpy() + + Jgrad1, Jgrad2 = jt.grad(Jemd, [Jpc1, Jpc2]) + + self.assertTrue(np.allclose(Temd.item(), Jemd.item()), Temd.item() - Jemd.item()) + self.assertTrue(np.allclose(Tgrad1, Jgrad1.data, atol=1e-4), np.abs(Tgrad1 - Jgrad1.data).max()) + self.assertTrue(np.allclose(Tgrad2, Jgrad2.data, atol=1e-4), np.abs(Tgrad2 - Jgrad2.data).max()) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_lr_scheduler.py b/python/jittor/test/test_lr_scheduler.py new file mode 100644 index 00000000..b51d979f --- /dev/null +++ b/python/jittor/test/test_lr_scheduler.py @@ -0,0 +1,53 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import random + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + torch = None + skip_this_test = True + +def check_equal(q,k,v,tatt,jatt): + tq=torch.from_numpy(q) + jq=jt.array(q) + tk=torch.from_numpy(k) + jk=jt.array(k) + tv=torch.from_numpy(v) + jv=jt.array(v) + + jatt.load_parameters(tatt.state_dict()) + ty, tw = tatt(tq, tk, tv) + jy, jw = jatt(jq, jk, jv) + assert np.allclose(ty.detach().numpy(), jy.numpy(), rtol=1e-3) + assert np.allclose(tw.detach().numpy(), jw.numpy(), rtol=1e-3) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestAttention(unittest.TestCase): + def test_attention(self): + j_opt = jt.optim.SGD([jt.array([1])], 1.0) + t_opt = torch.optim.SGD([torch.ones([1])], 1.0) + j_scheduler = jt.lr_scheduler.ReduceLROnPlateau(j_opt) + t_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(t_opt) + for i in range(100): + loss=random.random() + j_scheduler.step(loss) + t_scheduler.step(loss) + assert j_opt.lr == t_opt.state_dict()['param_groups'][0]['lr'] + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_lstm.py b/python/jittor/test/test_lstm.py new file mode 100644 index 00000000..d21947a6 --- /dev/null +++ b/python/jittor/test/test_lstm.py @@ -0,0 +1,107 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Zheng-Ning Liu +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import jittor.nn as nn +import numpy as np + + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + + +def check_equal(t_rnn, j_rnn, input, h0, c0): + j_rnn.load_state_dict(t_rnn.state_dict()) + + t_output, (th, tc) = t_rnn(torch.from_numpy(input), + (torch.from_numpy(h0), torch.from_numpy(c0))) + j_output, (jh, jc) = j_rnn(jt.float32(input), + (jt.float32(h0), jt.float32(c0))) + + assert np.allclose(t_output.detach().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + assert np.allclose(th.detach().numpy(), jh.data, rtol=1e-03, atol=1e-06) + assert np.allclose(tc.detach().numpy(), jc.data, rtol=1e-03, atol=1e-06) + + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestLSTM(unittest.TestCase): + def test_lstm_cell(self): + np_h0 = torch.randn(3, 20).numpy() + np_c0 = torch.randn(3, 20).numpy() + + t_rnn = tnn.LSTMCell(10, 20) + input = torch.randn(2, 3, 10) + h0 = torch.from_numpy(np_h0) + c0 = torch.from_numpy(np_c0) + t_output = [] + for i in range(input.size()[0]): + h0, c0 = t_rnn(input[i], (h0, c0)) + t_output.append(h0) + t_output = torch.stack(t_output, dim=0) + + j_rnn = nn.LSTMCell(10, 20) + j_rnn.load_state_dict(t_rnn.state_dict()) + + input = jt.float32(input.numpy()) + h0 = jt.float32(np_h0) + c0 = jt.float32(np_c0) + j_output = [] + for i in range(input.size()[0]): + h0, c0 = j_rnn(input[i], (h0, c0)) + j_output.append(h0) + j_output = jt.stack(j_output, dim=0) + + t_output = t_output.detach().numpy() + j_output = j_output.data + assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06) + + def test_lstm(self): + h0 = np.random.rand(1, 2, 20).astype(np.float32) + c0 = np.random.rand(1, 2, 20).astype(np.float32) + input = np.random.rand(5, 2, 10).astype(np.float32) + + t_rnn = tnn.LSTM(10, 20) + j_rnn = nn.LSTM(10, 20) + check_equal(t_rnn, j_rnn, input, h0, c0) + + proj_size = 13 + h0 = np.random.rand(1, 2, proj_size).astype(np.float32) + c0 = np.random.rand(1, 2, 20).astype(np.float32) + input = np.random.rand(5, 2, 10).astype(np.float32) + t_rnn = tnn.LSTM(10, 20, proj_size=proj_size) + j_rnn = nn.LSTM(10, 20, proj_size=proj_size) + check_equal(t_rnn, j_rnn, input, h0, c0) + + h0 = np.random.rand(2, 4, 20).astype(np.float32) + c0 = np.random.rand(2, 4, 20).astype(np.float32) + input = np.random.rand(5, 4, 10).astype(np.float32) + + t_rnn = tnn.LSTM(10, 20, num_layers=2) + j_rnn = nn.LSTM(10, 20, num_layers=2) + check_equal(t_rnn, j_rnn, input, h0, c0) + + h0 = np.random.rand(2, 4, proj_size).astype(np.float32) + c0 = np.random.rand(2, 4, 20).astype(np.float32) + input = np.random.rand(5, 4, 10).astype(np.float32) + + t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size) + j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size) + check_equal(t_rnn, j_rnn, input, h0, c0) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_matmul.py b/python/jittor/test/test_matmul.py new file mode 100644 index 00000000..9639fae5 --- /dev/null +++ b/python/jittor/test/test_matmul.py @@ -0,0 +1,401 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_log import find_log_with_re +f32 = jt.float32 +from jittor import nn, Module + +def relu(x): return jt.maximum(x, f32(0)) + +class Model(Module): + def __init__(self): + self.linear1 = nn.Linear(1, 10) + self.linear2 = nn.Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = nn.relu(x) + x = self.linear2(x) + return x + +class Model2(Module): + def __init__(self): + self.linear1 = nn.Linear(1, 10) + def execute(self, x): + x = self.linear1(x) + return x + +def test_matmul(s1, s2): + a = jt.random(s1) + b = jt.random(s2) + c = jt.nn.matmul(a, b) + c_ = np.matmul(a.data, b.data) + with jt.log_capture_scope(log_v=0, log_vprefix="op.cc=100") as logs: + c__ = c.data + assert np.allclose(c_, c__) + logs = find_log_with_re(logs, + "Jit op key (not )?found: (mkl)|(cublas)_matmul.*") + assert(len(logs)==1) + +def matmul2(a, b, tp): + assert len(a.shape) >= 2 and len(b.shape) == 2 + if (tp == 0): + shape = [a.shape[0], a.shape[1], b.shape[1]] + sa = 2 + sb = 0 + d = 1 + elif (tp == 1): + shape = [a.shape[0], a.shape[1], b.shape[1]] + sa = 2 + sb = 1 + d = 0 + elif (tp == 2): + shape = [a.shape[0], b.shape[0], a.shape[1]] + sa = 1 + sb = 0 + d = 2 + else: + return + + a = a.broadcast(shape, [sa]) + b = b.broadcast(shape, [sb]) + return (a*b).sum(d) + +def test_matmul2(s1, s2, t1, t2, dtype = 'float32'): + if (not t1) and (not t2): + tp = 0 + if (t1) and (not t2): + tp = 1 + if (not t1) and (t2): + tp = 2 + + if (dtype.startswith('float')): + a = jt.random(s1, dtype = dtype) + b = jt.random(s2, dtype = dtype) + else: + a = jt.random(s1) + b = jt.random(s2) + a = (a * 2000 - 1000).cast(dtype) + b = (b * 2000 - 1000).cast(dtype) + c = matmul2(a, b, tp) + if t1: + a_ = a.data.transpose() + else: + a_ = a.data + if t2: + b_ = b.data.transpose() + else: + b_ = b.data + c_ = np.matmul(a_, b_) + with jt.log_capture_scope(log_v=0, log_vprefix="op.cc=100") as logs: + c__ = c.data + assert np.allclose(c_, c__) + logs = find_log_with_re(logs, + "Jit op key (not )?found: (mkl)|(cublas)_matmul.*") + if (dtype.startswith('float')): + if jt.flags.use_cuda or dtype == 'float32': + assert(len(logs)==1) + +class TestMatmul(unittest.TestCase): + def test_matmul_type(self): + test_matmul2([2,5],[5,8], False, False, 'float32') + test_matmul2([5,2],[5,8], True, False, 'float32') + test_matmul2([2,5],[8,5], False, True, 'float32') + + test_matmul2([2,5],[5,8], False, False, 'float64') + test_matmul2([5,2],[5,8], True, False, 'float64') + test_matmul2([2,5],[8,5], False, True, 'float64') + + test_matmul2([2,5],[5,8], False, False, 'int32') + test_matmul2([5,2],[5,8], True, False, 'int32') + test_matmul2([2,5],[8,5], False, True, 'int32') + + def test_matmul(self): + test_matmul([2,5],[5,8]) + test_matmul([200,500],[500,800]) + test_matmul([500,500],[500,50]) + test_matmul2([2,5],[5,8], False, False) + test_matmul2([5,2],[5,8], True, False) + test_matmul2([2,5],[8,5], False, True) + + def test_backward(self): + np.random.seed(0) + jt.set_seed(3) + model = Model() + for p in reversed(model.parameters()): p.sync(0,0) + SGD = jt.nn.SGD(model.parameters(), 0.05, 0.9, 0) + n = 1000 + batch_size = 50 + base_lr = 0.05 + # we need to stop grad of global value to prevent memory leak + lr = f32(base_lr).name("lr").stop_grad() + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x).name("pred_y") + loss = ((pred_y - y)**f32(2)).name("loss") + loss_mean = loss.mean() + + SGD.step(loss_mean) + if i>2: + assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}" + prev = jt.liveness_info() + if (i % 10 == 9): + print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}") + else: + loss_mean.data.sum() + jt.liveness_info() + + possible_results = [0.00022486248053610325, 0.00020916158973705024, 0.00561215] + loss_mean = loss_mean.data + assert any(abs(loss_mean - r) < 1e-6 for r in possible_results), loss_mean + jt.clean() + + def test_backward_once(self): + np.random.seed(0) + jt.set_seed(3) + model = Model2() + n = 1 + batch_size = 50 + + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x).name("pred_y") + with jt.log_capture_scope(log_v=0, log_vprefix="op.cc=100") as logs: + jt.sync_all() + logs = find_log_with_re(logs, + "Jit op key (not )?found: (mkl)_matmul.*") + assert(len(logs)==1) + with jt.log_capture_scope(log_silent=1, log_v=0, log_vprefix="op.cc=100,exe=1000") as logs_b: + gs = jt.grad(pred_y, x) + gs2 = jt.grad(pred_y, model.linear1.weight) + jt.sync_all() + logs_b = find_log_with_re(logs_b, + "Jit op key (not )?found: (mkl)_matmul.*") + assert len(logs_b)==2, len(logs_b) + jt.clean() + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_matmul_type_cuda(self): + with jt.flag_scope(use_cuda=1): + test_matmul2([2,5],[5,8], False, False, 'float32') + test_matmul2([5,2],[5,8], True, False, 'float32') + test_matmul2([2,5],[8,5], False, True, 'float32') + + test_matmul2([2,5],[5,8], False, False, 'float64') + test_matmul2([5,2],[5,8], True, False, 'float64') + test_matmul2([2,5],[8,5], False, True, 'float64') + + test_matmul2([2,5],[5,8], False, False, 'int32') + test_matmul2([5,2],[5,8], True, False, 'int32') + test_matmul2([2,5],[8,5], False, True, 'int32') + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_matmul_cuda(self): + with jt.flag_scope(use_cuda=1): + test_matmul([2,5],[5,8]) + test_matmul([200,500],[500,800]) + test_matmul([500,500],[500,50]) + test_matmul2([2,5],[5,8], False, False) + test_matmul2([5,2],[5,8], True, False) + test_matmul2([500,200],[500,800], True, False) + test_matmul2([500,500],[500,50], True, False) + test_matmul2([2,5],[8,5], False, True) + test_matmul2([200,500],[800,500], False, True) + test_matmul2([500,500],[50,500], False, True) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_backward_cuda(self): + with jt.flag_scope(use_cuda=1): + np.random.seed(0) + jt.set_seed(3) + model = Model() + SGD = jt.nn.SGD(model.parameters(), 0.05, 0.9, 0) + n = 1000 + batch_size = 50 + base_lr = 0.05 + # we need to stop grad of global value to prevent memory leak + lr = f32(base_lr).name("lr").stop_grad() + + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x).name("pred_y") + # cuda x**2.0 will return nan + loss = ((pred_y - y).sqr()).name("loss") + loss_mean = loss.mean() + + SGD.step(loss_mean) + + if i>2: + assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}" + prev = jt.liveness_info() + if (i % 10 == 9): + print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}") + else: + loss_mean.data.sum() + jt.liveness_info() + + # result is 0.00018236637697555125 + result = 0.00018236637697555125 + assert abs(loss_mean.data - result) < 1e-2 + jt.clean() + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_backward_once_cuda(self): + with jt.flag_scope(use_cuda=1): + np.random.seed(0) + jt.set_seed(3) + model = Model2() + n = 1 + batch_size = 50 + + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + for i,(x,y) in enumerate(get_data(n)): + pred_y = model(x).name("pred_y") + with jt.log_capture_scope(log_v=0, log_vprefix="op.cc=100") as logs: + jt.sync_all() + logs = find_log_with_re(logs, + "Jit op key (not )?found: (cublas)_matmul.*") + assert(len(logs)==1) + with jt.log_capture_scope(log_silent=1, log_v=0, log_vprefix="op.cc=100,exe=1000") as logs_b: + gs = jt.grad(pred_y, x) + gs2 = jt.grad(pred_y, model.linear1.weight) + jt.sync_all() + logs_b = find_log_with_re(logs_b, + "Jit op key (not )?found: (cublas)_matmul.*") + assert len(logs_b)==2, len(logs_b) + jt.clean() + + def test_matmul_example(self): + a = jt.random([3]) + b = jt.random([3]) + c = jt.matmul(a, b) + assert c.shape == [1] + + a = jt.random([3, 4]) + b = jt.random([4]) + c = jt.matmul(a, b) + assert c.shape == [3] + + a = jt.random([10, 3, 4]) + b = jt.random([4]) + c = jt.matmul(a, b) + assert c.shape == [10, 3] + + a = jt.random([10, 3, 4]) + b = jt.random([4, 5]) + c = jt.matmul(a, b) + assert c.shape == [10, 3, 5] + + a = jt.random([10, 3, 4]) + b = jt.random([10, 4, 5]) + c = jt.matmul(a, b) + assert c.shape == [10, 3, 5] + + a = jt.random([8, 1, 3, 4]) + b = jt.random([10, 4, 5]) + c = jt.matmul(a, b) + assert c.shape == [8, 10, 3, 5] + + def test_matmul_example2(self): + def check(a_shape, b_shape): + a = jt.random(a_shape) + b = jt.random(b_shape) + c = jt.matmul(a, b) + cc = np.matmul(a.data, b.data) + assert c.shape == cc.shape or (cc.shape==() and c.shape==[1]), (c.shape, cc.shape) + np.testing.assert_allclose(c.data, cc, atol=1e-5) + da, db = jt.grad(c, [a, b]) + assert da.shape == a.shape + assert db.shape == b.shape + check([3], [3]) + check([3,4], [4]) + check([10,3,4], [4]) + check([10,3,4], [4,5]) + check([10,3,4], [10,4,5]) + check([8,1,3,4], [10,4,5]) + check([5,10,3,4], [5,10,4,5]) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_matmul_example2_cuda(self): + self.test_matmul_example2() + + def test_linear1d(self): + linear = jt.nn.Linear(10,20) + a = jt.random((10,)) + b = linear(a) + assert b.shape == (20,) + + # def test_tensorcore(self): + # import time + # jt.flags.use_cuda = 1 + # # jt.flags.use_tensorcore = 1 + # a = jt.rand(4096, 4096) + # b = jt.rand(4096, 4096) + # for i in range(100): + # c = jt.matmul(a, b) + # c.sync() + # jt.sync_all(True) + + # start = time.time() + # for i in range(1000): + # c = jt.matmul(a, b) + # c.sync() + # jt.sync_all(True) + # end = time.time() - start + # gflops = 4096**3*2 * 1000 / end / 10**9 + # print(end, gflops) + # # 14T vs 37T + + # def test_conv(self): + # import time + # jt.flags.use_cuda = 1 + # # jt.flags.use_tensorcore = 1 + # a = jt.rand(160, 1024, 16, 16) + # b = jt.rand(1024, 1024, 1, 1) + # for i in range(100): + # c = jt.nn.conv2d(a, b) + # c.sync() + # jt.sync_all(True) + + # start = time.time() + # for i in range(1000): + # c = jt.nn.conv2d(a, b) + # c.sync() + # jt.sync_all(True) + # end = time.time() - start + # gflops = a.numel() * b.numel() * 2 / 1024 * 1000 / end / 10**9 + # print(end, gflops) + # # 12T vs 30T + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_matmul_tuner.py b/python/jittor/test/test_matmul_tuner.py new file mode 100644 index 00000000..77c2a0dd --- /dev/null +++ b/python/jittor/test/test_matmul_tuner.py @@ -0,0 +1,44 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import sys +import os +import jittor as jt +import unittest +import time +import numpy as np +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +class TestMatmulTuner(unittest.TestCase): + def test_matmul_tuner(self): + n,m,k = 10,10,10 + a = jt.random([n,m]) + b = jt.random([m,k]) + with jt.log_capture_scope( + log_v=0, log_vprefix="tuner_manager=100,var_relay=100", + compile_options={"test_matmul_tuner":1} + ) as rawlogs: + c = a.broadcast([n,m,k], [2]) * b.broadcast([n,m,k], [0]) + c = c.sum(1) + jc = c.numpy() + nc = np.matmul(a.numpy(), b.numpy()) + assert (np.abs(jc-nc)<1e-3).all() + logs = find_log_with_re(rawlogs, + "Run tuner matmul: confidence\\((.*)\\) candidates\\((.*)\\)$") + assert len(logs) == 1 + assert logs[0][0] == "20", "confidence of reorder should be 20" + candidates = simple_parser(logs[0][1]) + assert candidates == {"relay0":[1,0]}, candidates + logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)") + assert len(logs)==1 + assert "@relay_op" in logs[0] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_mem.py b/python/jittor/test/test_mem.py new file mode 100644 index 00000000..0bcbb350 --- /dev/null +++ b/python/jittor/test/test_mem.py @@ -0,0 +1,44 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +model_test = os.environ.get("model_test", "") == "1" +skip_model_test = not model_test + +class TestMem(unittest.TestCase): + def tearDown(self): + jt.clean() + jt.gc() + + @unittest.skipIf(not jt.has_cuda, "no cuda found") + @unittest.skipIf(skip_model_test, "skip_model_test") + @jt.flag_scope(use_cuda=1) + def test_oom(self): + backups = [] + jt.flags.use_cuda = 1 + + one_g = np.ones((1024*1024*1024//4,), "float32") + + meminfo = jt.get_mem_info() + n = int(meminfo.total_cuda_ram // (1024**3) * 0.6) + + for i in range(n): + a = jt.array(one_g) + b = a + 1 + b.sync() + backups.append((a,b)) + jt.sync_all(True) + backups = [] + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_memory_profiler.py b/python/jittor/test/test_memory_profiler.py new file mode 100644 index 00000000..616c1ef7 --- /dev/null +++ b/python/jittor/test/test_memory_profiler.py @@ -0,0 +1,106 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn, Module +from jittor.models import resnet +import numpy as np +import sys, os +import random +import math +import unittest +from jittor.test.test_reorder_tuner import simple_parser +from jittor.test.test_log import find_log_with_re +from jittor.dataset.mnist import MNIST +import jittor.transform as trans +import time + +skip_this_test = False + +class MnistNet(Module): + def __init__(self): + self.model = resnet.Resnet18() + self.layer = nn.Linear(1000,10) + def execute(self, x): + x = self.model(x) + x = self.layer(x) + return x + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestMemoryProfiler(unittest.TestCase): + @classmethod + def setUpClass(self): + # hyper-parameters + self.batch_size = 100 + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.1 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + self.train_loader.num_workers = 4 + + # setup random seed + def setup_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + jt.seed(seed) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1, use_stat_allocator=1, trace_py_var=3, profile_memory_enable=1) + def test_resnet(self): + self.setup_seed(1) + loss_list=[] + acc_list=[] + mnist_net = MnistNet() + global prev + prev = time.time() + SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) + + iters = 10 + for batch_idx, (data, target) in enumerate(self.train_loader): + if (batch_idx > iters): + break + jt.display_memory_info() + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + SGD.step(loss) + def callback(batch_idx, loss, output, target): + global prev + pred = np.argmax(output, axis=1) + acc = np.mean(target==pred) + loss_list.append(loss[0]) + acc_list.append(acc) + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' + .format(0, batch_idx, iters,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev)) + jt.fetch(batch_idx, loss, output, target, callback) + jt.sync_all(True) + jt.display_max_memory_info() + _, out = jt.get_max_memory_treemap() + out_ = out.split('\n') + assert(out_[0] == 'root()') + assert(out_[3].endswith('(_run_module_as_main)')) + assert(out_[7].endswith('(_run_code)')) + _, out = jt.get_max_memory_treemap(build_by=1) + out_ = out.split('\n') + assert(out_[0] == 'root()') + assert(out_[4].endswith('(_run_module_as_main)')) + assert(out_[8].endswith('(_run_code)')) + + def test_sample(self): + net = jt.models.resnet18() + with jt.flag_scope(trace_py_var=3, profile_memory_enable=1): + imgs = jt.randn((1,3,224,224)) + net(imgs).sync() + jt.get_max_memory_treemap() + + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_merge_loop_var_pass.py b/python/jittor/test/test_merge_loop_var_pass.py new file mode 100644 index 00000000..3a89be80 --- /dev/null +++ b/python/jittor/test/test_merge_loop_var_pass.py @@ -0,0 +1,74 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as numpy + +class TestMergeLoopVarPass(unittest.TestCase): + def test(self): + a = jt.ones([10,10,10,10]) + a.sync() + with jt.profile_scope() as rep: + b = a.sum([2,3]) + b.sync() + with open(rep[1][1]) as f: + src = f.read() + assert "range01" in src + assert "range23" in src + + def test2(self): + a = jt.ones([10,10,10,10]) + a.sync() + with jt.profile_scope() as rep: + b = a + 1 + b.sync() + with open(rep[1][1]) as f: + src = f.read() + assert "range0123" in src + + def test3(self): + a = jt.ones([10,10,10,10]) + x = jt.ones([1,10,1,1]) + a.sync(), x.sync() + with jt.profile_scope() as rep: + b = a + x + b.sync() + with open(rep[1][1]) as f: + src = f.read() + assert "range23" in src + + def test4(self): + # don't optimize reindex like op yet + a = jt.ones([10,10,10,10]) + a.sync() + with jt.profile_scope() as rep: + b = a.reindex_reduce("add", [10,10], ["i0","i1"]) + b.sync() + with open(rep[1][1]) as f: + src = f.read() + assert "range23" not in src + + def test5(self): + a = jt.ones([10,10,10,10]) + a.sync() + with jt.profile_scope() as rep: + b = a.sum([1]) + b.sync() + with open(rep[1][1]) as f: + src = f.read() + assert "range01" not in src + assert "range23" in src + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestMergeLoopVarPassCuda(TestMergeLoopVarPass): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_merge_single_array_op.py b/python/jittor/test/test_merge_single_array_op.py new file mode 100644 index 00000000..f26cc5ba --- /dev/null +++ b/python/jittor/test/test_merge_single_array_op.py @@ -0,0 +1,148 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import random +from .test_log import find_log_with_re +from .test_core import expect_error + +def plus(a, b): + return a + b + +def subtraction(a, b): + return a - b + +def multiplication(a, b): + return a * b + +def division(a, b): + return a / b + +def get_random_op(): + v = random.randint(0, 3) + if (v == 0): + return plus + elif (v == 1): + return subtraction + elif (v == 2): + return multiplication + else: + return division + +def test(shape, op1, op2): + n = 753.1 + a = jt.random(shape) + b = jt.random(shape) + c = op1(a, n) + d = op2(c, b) + with jt.log_capture_scope(log_v=0, log_vprefix="fused_op.cc=100") as logs: + d__ = d.data + logs = find_log_with_re(logs, + "Jit (fused )?op key (not )?found: «opkey0:array«T:float32") + assert(len(logs)==1), logs + + a_ = a.data + b_ = b.data + d_ = op2(op1(a_, n), b_) + assert(np.allclose(d_, d__, atol=1e-4)) + +def gen_data(shape): + num = np.multiply.reduce(shape) + a = np.arange(0, num) + return a.reshape(shape) + +class TestSingleArray(unittest.TestCase): + def test7(self): + a = jt.random([100]) + x = a.reindex_var((a>0.1).where()) + x.data + + def test6(self): + jt.clean() + def check(hv, lv, lo): + self.assertEqual(jt.number_of_hold_vars(), hv) + self.assertEqual(jt.number_of_lived_vars(), lv) + self.assertEqual(jt.number_of_lived_ops(), lo) + check(0,0,0) + a = jt.array(1.0).name('a').stop_fuse() + b = (a+jt.array(1.0).name('t1').stop_fuse()).name('b') + c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c') + check(3,5,5) + graph = jt.dump_all_graphs() + self.assertEqual(c.data, 3) + check(3,5,2) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test5(self): + with jt.flag_scope(use_cuda=1): + f32 = jt.float32 + np.random.seed(0) + jt.set_seed(3) + + x = f32(np.random.rand(1, 1)) + w = (jt.random([x.shape[-1], 10])-f32(0.5)) / f32(x.shape[-1])**f32(0.5) + jt.nn.matmul(x, w).data + + def test4(self): + jt.array(1).data + + def test_concat(self): + def check(shape, dim, n): + num = np.prod(shape) + arr1 = [] + arr2 = [] + for i in range(n): + a = (np.array(range(num)) + i*num).reshape(shape) + arr1.append(a) + arr2.append(jt.array(a)) + x = np.concatenate(tuple(arr1), dim) + y = jt.concat(arr2, dim) + assert (x==y.data).all() + check([1], 0, 20) + + def test3(self): + def check(shape1, shape2): + a = gen_data(shape1) + b = gen_data(shape2) + aa,bb = np.broadcast_arrays(a, b) + ja = jt.ops.broadcast_var(a, b).data + assert ja.shape == aa.shape and (ja==aa).all(), f"{ja}, {aa}" + check([1], [3]) + + def test2(self): + a = jt.random([5]) + a = a * 2000 - 1000 + a.data + + def test_main(self): + test_n = 10 + test([50, 50, 50, 50], multiplication, subtraction) + for i in range(test_n): + n = random.randint(1,4) + shape = [] + for j in range(n): + shape.append(random.randint(1,50)) + test(shape, get_random_op(), get_random_op()) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_main_cuda(self): + with jt.flag_scope(use_cuda=1): + test_n = 10 + test([50, 50, 50, 50], multiplication, subtraction) + for i in range(test_n): + n = random.randint(1,4) + shape = [] + for j in range(n): + shape.append(random.randint(1,50)) + test(shape, get_random_op(), get_random_op()) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_misc_issue.py b/python/jittor/test/test_misc_issue.py new file mode 100644 index 00000000..848df544 --- /dev/null +++ b/python/jittor/test/test_misc_issue.py @@ -0,0 +1,217 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import numpy as np +import sys + +class TestMiscIssue(unittest.TestCase): + def test_issue4(self): + try: + jt.dirty_fix_pytorch_runtime_error() + import torch + except: + return + # import with pytorch cause segfault + src = """N = 100 +import jittor as jt +a = jt.random([N, N]) +b = a.broadcast([N,N,N], dims=[0]) * a.broadcast([N,N,N], dims=[2]) +b = b.sum(1) +b.sync() + +import torch +A = torch.rand(N, N) +torch.matmul(A, A) +""" + assert os.system(f"{sys.executable} -c '{src}'")==0 + src = """N = 100 +import torch +A = torch.rand(N, N) +torch.matmul(A, A) + +import jittor as jt +a = jt.random([N, N]) +b = a.broadcast([N,N,N], dims=[0]) * a.broadcast([N,N,N], dims=[2]) +b = b.sum(1) +b.sync() +""" + assert os.system(f"{sys.executable} -c '{src}'")==0 + + def test_mkl_conflict1(self): + try: + jt.dirty_fix_pytorch_runtime_error() + import torch + except: + return + if jt.mkl_ops is None: + return + # import with pytorch cause segfault + src = """ +nchw = [2, 3, 100, 100] +oihw = [4, 3, 5, 5] +import jittor as jt +x = jt.random(nchw) +w = jt.random(oihw) +jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync() + +jt.dirty_fix_pytorch_runtime_error() + +import torch +m = torch.nn.Conv2d(3, 4, 5, 1, 2) +m(torch.rand(*nchw)) + +""" + assert os.system(f"{sys.executable} -c '{src}'")==0 + + def test_mkl_conflict2(self): + try: + jt.dirty_fix_pytorch_runtime_error() + import torch + except: + return + if jt.mkl_ops is None: + return + # import with pytorch cause segfault + src = """ +nchw = [2, 3, 100, 100] +oihw = [4, 3, 5, 5] + +import torch +m = torch.nn.Conv2d(3, 4, 5, 1, 2) +m(torch.rand(*nchw)) + +import jittor as jt +x = jt.random(nchw) +w = jt.random(oihw) +jt.mkl_ops.mkl_conv(x, w, 1, 1, 2, 2).sync() + + +""" + assert os.system(f"{sys.executable} -c '{src}'")==0 + + def test_cuda_lowsm(self): + if not jt.has_cuda: return + src = """ +import jittor +from jittor.nn import matmul_transpose + +a = jittor.ones((3,4,2), dtype="float32") +b = jittor.ones((5, 2), dtype="float32") +print(matmul_transpose(a, b)) + +jittor.flags.use_cuda = 1 +a = jittor.ones((3,4,2), dtype="float32") +b = jittor.ones((5, 2), dtype="float32") +print(matmul_transpose(a, b)) +""" + assert os.system(f"cuda_archs=52 {sys.executable} -c '{src}'")==0 + + def test_parallel(self): + a = jt.code([4], "int", cpu_src=""" + #pragma omp parallel num_threads(4) + @out(omp_get_thread_num()) = 456; + """, cpu_header='#include ').data + assert (a==[456]*4).all(), a + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_reduce_opt(self): + a = jt.random((16,512,38,38)) + b = jt.random((16,512,38,38)) + jt.sync([a, b]) + with jt.profile_scope(rerun=10, warmup=10) as rep: + norm = a.sqr().sum(1, keepdims=True).sqrt() + c = a / norm + da = jt.grad(c*b, a) + jt.sync([c, da]) + gpu_c = c.numpy() + gpu_da = da.numpy() + with jt.flag_scope(use_cuda=0): + norm = a.sqr().sum(1, keepdims=True).sqrt() + c = a / norm + da = jt.grad(c*b, a) + assert np.allclose(gpu_c, c.data, 1e-3) + assert (np.abs(gpu_da-da.data).max() < 1e-6) + + assert float(rep[1][3]) < 15e6, float(rep[1][3]) # 15ms(about 8ms) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cuda_min_max(self): + a = jt.random((10,)) - 2 + assert a.min().data == a.data.min(), (a.min(), a.data.min()) + assert a.max().data == a.data.max(), (a.max(), a.data.max()) + a = jt.random((10,)) + 2 + assert a.min().data == a.data.min(), (a.min(), a.data.min()) + assert a.max().data == a.data.max(), (a.max(), a.data.max()) + + a = jt.random((10,)).float64() - 2 + assert a.min().data == a.data.min(), (a.min(), a.data.min()) + assert a.max().data == a.data.max(), (a.max(), a.data.max()) + a = jt.random((10,)).float64() + 2 + assert a.min().data == a.data.min(), (a.min(), a.data.min()) + assert a.max().data == a.data.max(), (a.max(), a.data.max()) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cuda_pow_grad_nan(self): + a = jt.float32([1,-1, -1000.1]) + da = jt.grad(a**2, a) + assert np.isnan(da.data).sum()==0, da.data + + def test_tanh_nan(self): + m=jt.nn.Tanh() + a = m(jt.array([1000])) + assert np.isnan(a.data).sum()==0, a + + def test_sigmoid_nan(self): + a = jt.float32([1,-1, -1000.1]) + da = jt.grad(a.sigmoid(), a) + assert np.isnan(da.data).sum()==0, da.data + + def test_sequential(self): + x = jt.nn.Sequential(lambda x:x, lambda x:x) + n = 0 + for a in x: + n += 1 + assert n == 2 + assert list(x.keys()) == [0,1] + p = x.parameters() + assert len(p)==0 + + def test_self_update(self): + from jittor.models import resnet18 + m = resnet18() + x = m.state_dict() + m.load_state_dict(x) + + def test_res2net(self): + import jittor.models + net = jittor.models.res2net50(True) + img = jt.random((2,3,224,224)) + out = net(img) + print(out.shape, out.sum()) + jt.display_memory_info() + jt.display_memory_info() + assert out.shape == [2,1000] + + def test_argmax_memleak(self): + a = jt.random([10]) + _, m = jt.argmax(a, 0) + del _ + m.sync() + g = jt.grad(m*10, a) + g.sync() + del a, g, m + jt.display_memory_info() + assert jt.liveness_info()["lived_ops"] == 0 + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_misc_op.py b/python/jittor/test/test_misc_op.py new file mode 100644 index 00000000..5b95f49c --- /dev/null +++ b/python/jittor/test/test_misc_op.py @@ -0,0 +1,406 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn + import torchvision +except: + torch = None + tnn = None + torchvision = None + skip_this_test = True + +def check_equal(res1, res2, eps=1e-5): + assert np.allclose(res1.detach().numpy(), res2.numpy(), eps) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestPad(unittest.TestCase): + def test_index_add_(self): + x = np.ones((5,3)) + a1 = torch.Tensor(x) + a1.index_add_(0, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)) + a2 = jt.array(x) + a2.index_add_(0, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + check_equal(a1, a2) + + x = np.ones((3,5)) + a1 = torch.Tensor(x) + a1.index_add_(1, torch.tensor([0,4,2]), torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)) + a2 = jt.array(x) + a2.index_add_(1, jt.array([0,4,2]), jt.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + check_equal(a1, a2) + print('pass index_add_ test ...') + + def test_repeat(self): + arr = np.random.randn(16,3,224,224) + check_equal(torch.Tensor(arr).repeat(1,2,3,4), jt.array(arr).repeat(1,2,3,4)) + check_equal(torch.Tensor(arr).repeat(4,2,3,4), jt.array(arr).repeat(4,2,3,4)) + print('pass repeat test ...') + + def test_chunk(self): + arr = np.random.randn(16,3,224,224) + check_equal(torch.Tensor(arr).chunk(2,0)[0], jt.array(arr).chunk(2,0)[0]) + check_equal(torch.Tensor(arr).chunk(2,0)[1], jt.array(arr).chunk(2,0)[1]) + print('pass chunk test ...') + + def test_stack(self): + arr1 = np.random.randn(16,3,224,224) + arr2 = np.random.randn(16,3,224,224) + check_equal(torch.stack([torch.Tensor(arr1), torch.Tensor(arr2)], 0), jt.stack([jt.array(arr1), jt.array(arr2)], 0)) + print('pass stack test ...') + + def test_flip(self): + arr = np.random.randn(16,3,224,224) + check_equal(torch.Tensor(arr).flip(0), jt.array(arr).flip(0)) + check_equal(torch.Tensor(arr).flip(1), jt.array(arr).flip(1)) + check_equal(torch.Tensor(arr).flip(2), jt.array(arr).flip(2)) + check_equal(torch.Tensor(arr).flip(3), jt.array(arr).flip(3)) + check_equal(torch.Tensor(arr).flip([2,3]), jt.array(arr).flip([2,3])) + print('pass flip test ...') + + def test_cross(self): + def check_equal(a, b, tol): + np.testing.assert_allclose(a.detach().numpy(), b.numpy(), atol=1e-5) + arr1 = np.random.randn(16,3,224,224,3) + arr2 = np.random.randn(16,3,224,224,3) + check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=1), jt.array(arr1).cross(jt.array(arr2), dim=1), 1e-1) + check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=-4), jt.array(arr1).cross(jt.array(arr2), dim=-4), 1e-1) + check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=-1), jt.array(arr1).cross(jt.array(arr2), dim=-1), 1e-1) + check_equal(torch.Tensor(arr1).cross(torch.Tensor(arr2), dim=4), jt.array(arr1).cross(jt.array(arr2), dim=4), 1e-1) + print('pass cross test ...') + + def test_normalize(self): + arr = np.random.randn(16,3,224,224,3) + check_equal(tnn.functional.normalize(torch.Tensor(arr)), jt.normalize(jt.array(arr))) + check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=0), jt.normalize(jt.array(arr), dim=0), 1e-1) + check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=1), jt.normalize(jt.array(arr), dim=1), 1e-1) + check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=-1), jt.normalize(jt.array(arr), dim=-1), 1e-1) + check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=2), jt.normalize(jt.array(arr), dim=2), 1e-1) + check_equal(tnn.functional.normalize(torch.Tensor(arr), dim=3), jt.normalize(jt.array(arr), dim=3), 1e-1) + print('pass normalize test ...') + + def test_make_grid(self): + arr = np.random.randn(16,3,10,10) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr))) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=2), jt.make_grid(jt.array(arr), nrow=2)) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3), jt.make_grid(jt.array(arr), nrow=3)) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4), jt.make_grid(jt.array(arr), nrow=3, padding=4)) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, padding=4, pad_value=-1)) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1)) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100)), jt.make_grid(jt.array(arr), nrow=3, normalize=True, padding=4, pad_value=-1, range=(-100,100))) + print('pass make_grid test ...') + + def test_make_grid2(self): + def check(shape): + arr = np.random.randn(*shape) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr))) + check((3,100,200)) + check((1,100,200)) + check((100,200)) + check((1,3,100,200)) + check((4,3,100,200)) + check((10,3,100,200)) + + def test_make_grid3(self): + arr=np.random.randn(3,10,10) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr)), jt.make_grid(jt.array(arr))) + check_equal(torchvision.utils.make_grid(torch.Tensor(arr), normalize=True), jt.make_grid(jt.array(arr), normalize=True)) + + def test_save_image(self): + arr = jt.array(np.random.randn(16,3,10,10)) + jt.save_image(arr, jt.flags.cache_path+"/tmp/a.jpg") + + def test_unbind(self): + arr = np.random.randn(2,3,4) + for dim in range(len(arr.shape)): + t_res = torch.unbind(torch.Tensor(arr), dim=dim) + j_res = jt.unbind(jt.array(arr), dim=dim) + for idx in range(len(t_res)): + assert np.allclose(t_res[idx].numpy(), j_res[idx].numpy()) + print('pass unbind test ...') + + def test_expand(self): + a = jt.zeros((3,1)) + b = a.expand(3, 4) + assert b.shape == (3,4) + b = a.expand(-1, 4) + assert b.shape == (3,4) + b = a.expand((3, 4)) + assert b.shape == (3,4) + b = a.expand((-1, 4)) + assert b.shape == (3,4) + + def test_bilinear(self): + from jittor import nn + m = nn.Bilinear(20, 30, 40) + input1 = jt.randn(128, 20) + input2 = jt.randn(128, 30) + output = m(input1, input2) + assert output.shape == [128,40] + + m2 = torch.nn.Bilinear(20, 30, 40) + m2.weight = torch.nn.Parameter(torch.Tensor(m.weight.data)) + m2.bias = torch.nn.Parameter(torch.Tensor(m.bias.data)) + in1 = torch.Tensor(input1.data) + in2 = torch.Tensor(input2.data) + out = m2(in1, in2) + np.testing.assert_allclose( + out.detach().numpy(), output.data, + atol=1e-4) + + def test_ctc_loss(self): + def check(T,C,N,S,S_min): + jt.set_global_seed(0) + + # Initialize random batch of input vectors, for *size = (T,N,C) + input = jt.randn(T, N, C).log_softmax(2) + # input = -jt.ones((T, N, C)) + # input[0,0,1] += 0.01 + + # Initialize random batch of targets (0 = blank, 1:C = classes) + target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int) + _input_jt = input + + input_lengths = jt.full((N,), T, dtype=jt.int) + target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int) + # ctc_loss = nn.CTCLoss() + loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none') + _loss_jt = loss + + loss_jt = loss.numpy() + + input = torch.Tensor(input.numpy()).detach().requires_grad_() + input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long) + target_lengths = torch.LongTensor(target_lengths.numpy()) + input_lengths = torch.LongTensor(input_lengths.numpy()) + target = torch.LongTensor(target.numpy()) + loss = tnn.CTCLoss(reduction='none')(input, target, input_lengths, target_lengths) + np.testing.assert_allclose(loss.detach().numpy(), loss_jt, rtol=1e-5, atol=1e-5) + + dinput_jt = jt.grad(_loss_jt, _input_jt) + dinput_jt.sync() + + loss.sum().backward() + # print(input.grad) + # print(dinput_jt) + # print(loss) + + def check_gpu_with_cpu(T,C,N,S,S_min): + jt.set_global_seed(1) + + # Initialize random batch of input vectors, for *size = (T,N,C) + input = jt.randn(T, N, C).log_softmax(2) + # input = -jt.ones((T, N, C)) + # input[0,0,1] += 0.01 + + # Initialize random batch of targets (0 = blank, 1:C = classes) + target = jt.randint(low=1, high=C, shape=(N, S), dtype=jt.int) + _input_jt = input + + input_lengths = jt.full((N,), T, dtype=jt.int) + target_lengths = jt.randint(low=S_min, high=S+1, shape=(N,), dtype=jt.int) + # ctc_loss = nn.CTCLoss() + loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none') + _loss_jt = loss + + loss_jt = loss.numpy() + + dinput_jt = jt.grad(_loss_jt, _input_jt) + dinput_jt.sync() + + with jt.flag_scope(use_cuda=1): + input = input.copy() + target = target.copy() + input_lengths = input_lengths.copy() + target_lengths = target_lengths.copy() + loss = jt.ctc_loss(input, target, input_lengths, target_lengths, reduction='none') + grad = jt.grad(loss, input) + np.testing.assert_allclose(_loss_jt.numpy(), loss.numpy(), atol=1e-5, rtol=1e-5) + np.testing.assert_allclose(dinput_jt.numpy(), grad.numpy(), atol=1e-5, rtol=1e-5) + + + check(2,2,1,1,1) + check(50,20,16,30,10) + + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + check(2,2,1,1,1) + check(50,20,16,30,10) + check_gpu_with_cpu(50,20,16,30,10) + +class TestOther(unittest.TestCase): + def test_save(self): + pp = [1,2,jt.array([1,2,3]), {"a":[1,2,3], "b":jt.array([1,2,3])}] + name = jt.flags.cache_path+"/xx.pkl" + jt.save(pp, name) + x = jt.load(name) + assert x[:2] == [1,2] + assert (x[2] == np.array([1,2,3])).all() + assert x[3]['a'] == [1,2,3] + assert (x[3]['b'] == np.array([1,2,3])).all() + + def test_arctan2(self): + x = jt.float32([1,1,-1,-1, 1,-1,0,0,0]) + y = jt.float32([-1,1,-1,1, 0,0,1,-1,0]) + z = jt.arctan2(y, x) + z2 = np.arctan2(y.data, x.data) + np.testing.assert_allclose(z.data, z2, atol=1e-6) + + y = jt.random((100,)) * 2 - 1 + x = jt.random((100,)) * 2 - 1 + z = jt.arctan2(y, x) + z2 = np.arctan2(y.data, x.data) + np.testing.assert_allclose(z.data, z2, atol=1e-6) + + np.testing.assert_allclose(jt.array([1]).arctan().item(), 0.7853982) + + def test_softmax_precision(self): + # jt.flags.use_cuda = 1 + a = -jt.array([1.0,2.0,1e5]) + b = a.log_softmax(0) + assert b.isfinite().all().item() + print("test_softmax_precision cpu ok") + if not jt.has_cuda: return + jt.flags.use_cuda = 1 + a = -jt.array([1.0,2.0,1e5]) + b = a.log_softmax(0) + assert b.isfinite().all().item() + print("test_softmax_precision gpu ok") + + def test_code_softmax(self): + if not jt.has_cuda: return + + def softmax(x, dim = None, log=False): + if dim is None: + x = (x - x.max()).exp() + ret = x / x.sum() + else: + x = (x-x.max(dim, keepdims=True)).exp() + ret = x / x.sum(dim, keepdims=True) + if log: return ret.log() + return ret + from jittor.other.code_softmax import softmax_v1 + + with jt.flag_scope(use_cuda = 1): + shape = (120, 2000, 2000) + shape = (3,3) + for log in [0,1]: + for shape in [(3,3), + (12, 200, 2000), + (12, 200, 2048), + (12, 200, 2049)]: + print(shape) + a = jt.rand(shape) + c = jt.rand(shape) + b = softmax(a, -1, log=log) + bb = softmax_v1(a, log=log) + + err = (bb - b).abs().max() + assert err.item() < 1e-5, (err, bb, b) + + d1 = jt.grad(b*c, a) + d2 = jt.grad(bb*c, a) + err = (d1 - d2).abs().max() + + if log: + assert err.item() < 1e-2, (err.item()) + else: + assert err.item() < 1e-5, (err.item()) + + def test_nan(self): + a = np.array([1.0,0.0,1.0,-1.0], "float32") / np.array([1.0,0.0,0.0,0.0], "float32") + np.testing.assert_allclose(jt.isnan(jt.array(a)).data, [0,1,0,0]) + np.testing.assert_allclose(jt.isfinite(jt.array(a)).data, [1,0,0,0]) + np.testing.assert_allclose(jt.isinf(jt.array(a)).data, [0,0,1,1]) + np.testing.assert_allclose(jt.isneginf(jt.array(a)).data, [0,0,0,1]) + np.testing.assert_allclose(jt.isposinf(jt.array(a)).data, [0,0,1,0]) + + def test_nan_cuda(self): + if not jt.has_cuda: return + with jt.flag_scope(use_cuda=1): + self.test_nan() + + def test_dropout2d(self): + m = jt.nn.Dropout2d(p=0.2) + m.train() + input = jt.randn(1, 10, 4, 3) + output = m(input) + output.sync() + + def test_tri(self): + a = jt.ones(3, 3) + b = jt.triu(a) + assert jt.all_equal(b, [[1,1,1],[0,1,1],[0,0,1]]) + + b = jt.triu(a, diagonal=1) + assert jt.all_equal(b, [[0,1,1],[0,0,1],[0,0,0]]) + + b = jt.triu(a, diagonal=-1) + assert jt.all_equal(b, [[1,1,1],[1,1,1],[0,1,1]]) + + a = jt.ones(3, 3) + b = jt.tril(a) + assert jt.all_equal(b, [[1,0,0],[1,1,0],[1,1,1]]) + + b = jt.tril(a, diagonal=1) + assert jt.all_equal(b, [[1,1,0],[1,1,1],[1,1,1]]) + + b = jt.tril(a, diagonal=-1) + assert jt.all_equal(b, [[0,0,0],[1,0,0],[1,1,0]]) + + def test_ones(self): + a = jt.ones(10, "int32") + a.sync() + assert a.shape == (10,) + assert a.dtype == "int32" + a = jt.ones((10,), "int32") + a.sync() + assert a.shape == (10,) + assert a.dtype == "int32" + + a = jt.ones(10,10) + assert a.shape == (10,10) + + a = jt.ones_like(jt.ones([10], "int16")) + assert a.dtype == "int16" + + a = jt.ones_like(jt.ones([10], "bool")) + assert a.dtype == "bool" + + def test_index_select(self): + x = jt.randn(3, 4) + indices = torch.tensor([2, 1]) + y = jt.index_select(x, 0, indices) + assert jt.all_equal(y, x[indices]) + y = jt.index_select(x, 1, indices) + assert jt.all_equal(y, x[:, indices]) + + def test_multinorm(self): + weights = jt.float32([0, 10, 3, 0]) + x = jt.multinomial(weights, 2) + assert jt.all_equal(x, [1, 2]) or jt.all_equal(x, [2, 1]) + x = jt.multinomial(weights, 4, replacement=True) + assert x.shape == (4, ) + + weights = jt.float32([[0,0,2],[0,1,0], [0.5,0,0]]) + x = jt.multinomial(weights, 1) + assert jt.all_equal(x, [[2],[1],[0]]) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_mkl_conv_op.py b/python/jittor/test/test_mkl_conv_op.py new file mode 100644 index 00000000..a6805248 --- /dev/null +++ b/python/jittor/test/test_mkl_conv_op.py @@ -0,0 +1,206 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import timeit +import os +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +def conv(x, w, padding, stride = 1): + out_planes, in_planes, kernel_size, _ = w.shape + Kw = kernel_size + Kh = kernel_size + _C = in_planes + Kc = out_planes + N,C,H,W = x.shape + assert C==_C + xx = x.reindex([N,Kc,C,(H+padding*2-kernel_size)//stride+1,(W+padding*2-kernel_size)//stride+1,Kh,Kw], [ + 'i0', # Nid + 'i2', # Cid + f'i3*{stride}-{padding}+i5', # Hid+Khid + f'i4*{stride}-{padding}+i6', # Wid+KWid + ]) + ww = w.broadcast(xx.shape, [0,3,4]) + yy = xx*ww + y = yy.sum([2,5,6]) # Kc, Kh, Kw + return y + + +def conv_nhwc_hwio(x, w, stride=1, padding=0): + assert type(stride)==int and type(padding)==int + N,H,W,C = x.shape + Kh,Kw,C2,c = w.shape + oh, ow = (H-Kh+padding*2)//stride+1, (W-Kw+padding*2)//stride+1 + assert C2==C or C2==1 + x = x.reindex([N,oh,ow,Kh,Kw,C2,c], [ + 'i0', # Nid = Nid + f'i1*{stride}+i3-{padding}', # Hid = ohid*stride+Khid + f'i2*{stride}+i4-{padding}', # Wid = owid*stride+Kwid + 'i6' if C2==1 else 'i5', # depthwise or normal + ]) + y = (x*w).sum([3,4,5]) # Kh, Kw, C + return y + +@unittest.skipIf(not jt.compile_extern.use_mkl, "Not use mkl, Skip") +class TestMklConvOp(unittest.TestCase): + + def test_forward(self): + a = np.random.rand(1,3,224,224).astype(np.float32) + b = np.random.rand(64,3,7,7).astype(np.float32) + c = jt.mkl_ops.mkl_conv(a,b,2,2,3,3).data + + a_jt = jt.array(a) + b_jt = jt.array(b) + with jt.flag_scope(enable_tuner=0,compile_options={"test_mkl_conv":1}): + c_jt = conv(a_jt, b_jt, 3, 2).data + with jt.log_capture_scope( + enable_tuner=1, + compile_options={"test_mkl_conv":2}, + log_v=0, log_vprefix="tuner_manager=100,conv_tuner=1000", + ) as raw_logs: + c_jt_tune = conv(a_jt, b_jt, 3, 2).data + + assert np.max(c_jt-c)<1e-4 and np.max(c_jt_tune-c)<1e-4 + logs = find_log_with_re(raw_logs, + "Run tuner conv: confidence\\((.*)\\) candidates\\((.*)\\)$") + assert len(logs)==1 + assert logs[0][0] == '20' + assert simple_parser(logs[0][1]) == {'relay0':[1,0]} + + def test_forward_nhwc_hwio(self): + uid = [123] + def check(xshape, wshape, stride, pad): + a = np.random.rand(*xshape).astype(np.float32) + b = np.random.rand(*wshape).astype(np.float32) + c = jt.mkl_ops.mkl_conv(a,b,stride,stride,pad,pad,1,1,xformat="acdb",wformat="hwio").data + + a_jt = jt.array(a) + b_jt = jt.array(b) + with jt.flag_scope(enable_tuner=0, + compile_options={"test_mkl_conv":uid[0]}): + c_jt = conv_nhwc_hwio(a_jt, b_jt, stride, pad).data + with jt.log_capture_scope( + enable_tuner=1, + compile_options={"test_mkl_conv":uid[0]+1}, + log_v=0, log_vprefix="tuner_manager=100,conv_tuner=1000", + ) as raw_logs: + c_jt_tune = conv_nhwc_hwio(a_jt, b_jt, stride, pad).data + uid[0] += 2 + + assert np.max(c_jt-c)<1e-4 and np.max(c_jt_tune-c)<1e-4 + logs = find_log_with_re(raw_logs, + "Run tuner conv: confidence\\((.*)\\) candidates\\((.*)\\)$") + assert len(logs)==1, raw_logs + assert logs[0][0] == '20' + assert simple_parser(logs[0][1]) == {'relay0':[1,0]} + + check([1,100,100,3], [1,1,3,64], 1, 0) + check([1,100,100,3], [3,3,3,16], 1, 0) + check([1,100,100,3], [3,3,3,16], 2, 1) + # TODO: check([1,100,100,1], [3,3,1,1], 2, 1) + + def test_backward(self): + n,c,H,W = 2,3,5,5 + o,i,h,w = 4,c,3,3 + a = np.random.rand(n,c,H,W).astype(np.float32) + b = np.random.rand(o,i,h,w).astype(np.float32) + da = np.random.rand(n,o,H,W).astype(np.float32) + dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,1,1,1).data + dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,1,1,1).data + a_jt = jt.array(a) + b_jt = jt.array(b) + + with jt.flag_scope( + enable_tuner=0, + # compile_options={"test_mkl_conv":1} + ): + c_jt = conv(a_jt, b_jt, 1, 1) * da + gs=jt.grad(c_jt,[a_jt,b_jt]) + gs.append(c_jt) + jt.fetch_sync(gs) + dx_jt=gs[0].data + dw_jt=gs[1].data + with jt.log_capture_scope( + log_v=10, + log_vprefix="tuner_manager=100,var_relay=100", + enable_tuner=1, + compile_options={"test_mkl_conv":2} + ) as rawlogs: + gs_tune=jt.grad(c_jt,[a_jt,b_jt]) + jt.fetch_sync(gs_tune) + dx_jt_tune=gs_tune[0].data + dw_jt_tune=gs_tune[1].data + logs = find_log_with_re(rawlogs, + "Run tuner conv: confidence\\((20)\\) candidates\\((.*)\\)$") + assert len(logs) == 2, len(logs) + assert logs[0][0] == "20", "confidence of reorder should be 20" + candidates = simple_parser(logs[0][1]) + assert candidates == {"relay0":[1,0]}, candidates + + logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)") + assert len(logs)==2 + assert "@relay_op" in logs[0] + assert "@relay_op" in logs[1] + + assert np.max(dx_jt-dx)<1e-5 and np.max(dw_jt-dw)<1e-5 + assert np.max(dx_jt_tune-dx)<1e-5 and np.max(dw_jt_tune-dw)<1e-5 + + def test_backward_nhwc_hwio(self): + n,c,H,W = 2,3,5,5 + o,i,h,w = 4,c,3,3 + a = np.random.rand(n,H,W,c).astype(np.float32) + b = np.random.rand(h,w,i,o).astype(np.float32) + da = np.random.rand(n,H,W,o).astype(np.float32) + jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb") + dx = jt.mkl_ops.mkl_conv_backward_x(b,da,H,W,1,1,1,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data + dw = jt.mkl_ops.mkl_conv_backward_w(a,da,h,w,1,1,1,1,1,1,xformat="acdb",wformat="hwio",yformat="acdb").data + a_jt = jt.array(a) + b_jt = jt.array(b) + + with jt.flag_scope( + enable_tuner=0, + # compile_options={"test_mkl_conv":1} + ): + c_jt = conv_nhwc_hwio(a_jt, b_jt, 1, 1) * da + gs=jt.grad(c_jt,[a_jt,b_jt]) + gs.append(c_jt) + jt.fetch_sync(gs) + dx_jt=gs[0].data + dw_jt=gs[1].data + with jt.log_capture_scope( + log_v=10, + log_vprefix="tuner_manager=100,var_relay=100", + enable_tuner=1, + compile_options={"test_mkl_conv":2} + ) as rawlogs: + gs_tune=jt.grad(c_jt,[a_jt,b_jt]) + jt.fetch_sync(gs_tune) + dx_jt_tune=gs_tune[0].data + dw_jt_tune=gs_tune[1].data + logs = find_log_with_re(rawlogs, + "Run tuner conv: confidence\\((20)\\) candidates\\((.*)\\)$") + assert len(logs) == 2 + assert logs[0][0] == "20", "confidence of reorder should be 20" + candidates = simple_parser(logs[0][1]) + assert candidates == {"relay0":[1,0]}, candidates + # assert candidates == {"relay0":[1,0],"relay1":[1,0]}, candidates + + logs = find_log_with_re(rawlogs, r"get_relay_src([\s\S]*)") + assert len(logs)==2 + assert "@relay_op" in logs[0] + assert "@relay_op" in logs[1] + + assert np.max(dx_jt_tune-dx)<1e-5 and np.max(dw_jt_tune-dw)<1e-5 + assert np.max(dx_jt-dx)<1e-5 and np.max(dw_jt-dw)<1e-5 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_mkl_test_op.py b/python/jittor/test/test_mkl_test_op.py new file mode 100644 index 00000000..f75c4a6e --- /dev/null +++ b/python/jittor/test/test_mkl_test_op.py @@ -0,0 +1,17 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os + +@unittest.skipIf(not jt.compile_extern.use_mkl, "Not use mkl, Skip") +class TestMklTestOp(unittest.TestCase): + def test(self): + assert jt.mkl_ops.mkl_test().data==123 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_models.py b/python/jittor/test/test_models.py new file mode 100644 index 00000000..93340a64 --- /dev/null +++ b/python/jittor/test/test_models.py @@ -0,0 +1,113 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.models as jtmodels + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torchvision.models as tcmodels + from torch import nn +except: + torch = None + skip_this_test = True + +@unittest.skipIf(skip_this_test, "skip_this_test") +class test_models(unittest.TestCase): + @classmethod + def setUpClass(self): + self.models = [ + 'squeezenet1_0', + 'squeezenet1_1', + 'alexnet', + 'resnet18', + 'resnet34', + 'resnet50', + 'resnet101', + 'resnet152', + 'resnext50_32x4d', + 'resnext101_32x8d', + 'vgg11', + 'vgg11_bn', + 'vgg13', + 'vgg13_bn', + 'vgg16', + 'vgg16_bn', + 'vgg19', + 'vgg19_bn', + 'wide_resnet50_2', + 'wide_resnet101_2', + 'googlenet', + 'mobilenet_v2', + 'mnasnet0_5', + 'mnasnet0_75', + 'mnasnet1_0', + 'mnasnet1_3', + 'shufflenet_v2_x0_5', + 'shufflenet_v2_x1_0', + 'shufflenet_v2_x1_5', + 'shufflenet_v2_x2_0', + "densenet121", + "densenet161", + "densenet169", + 'inception_v3', + ] + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_models(self): + with torch.no_grad(): + self.run_models() + + def run_models(self): + def to_cuda(x): + if jt.has_cuda: + return x.cuda() + return x + threshold = 1e-2 + # Define numpy input image + bs = 1 + test_img = np.random.random((bs,3,224,224)).astype('float32') + # test_img = np.random.random((bs,3,280,280)).astype('float32') + # Define pytorch & jittor input image + pytorch_test_img = to_cuda(torch.Tensor(test_img)) + jittor_test_img = jt.array(test_img) + for test_model in self.models: + if test_model == "inception_v3": + test_img = np.random.random((bs,3,300,300)).astype('float32') + pytorch_test_img = to_cuda(torch.Tensor(test_img)) + jittor_test_img = jt.array(test_img) + # Define pytorch & jittor model + pytorch_model = to_cuda(tcmodels.__dict__[test_model]()) + jittor_model = jtmodels.__dict__[test_model]() + # Set eval to avoid dropout layer + pytorch_model.eval() + jittor_model.eval() + # Jittor loads pytorch parameters to ensure forward alignment + jittor_model.load_parameters(pytorch_model.state_dict()) + # Judge pytorch & jittor forward relative error. If the differece is lower than threshold, this test passes. + pytorch_result = pytorch_model(pytorch_test_img) + jittor_result = jittor_model(jittor_test_img) + x = pytorch_result.detach().cpu().numpy() + 1 + y = jittor_result.data + 1 + relative_error = abs(x - y) / abs(y) + diff = relative_error.mean() + assert diff < threshold, f"[*] {test_model} forward fails..., Relative Error: {diff}" + print(f"[*] {test_model} forword passes with Relative Error {diff}") + jt.clean() + jt.gc() + torch.cuda.empty_cache() + print('all models pass test.') + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_mpi.py b/python/jittor/test/test_mpi.py new file mode 100644 index 00000000..f0d4081e --- /dev/null +++ b/python/jittor/test/test_mpi.py @@ -0,0 +1,82 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os, sys +import jittor as jt +import numpy as np +mpi = jt.compile_extern.mpi + +@unittest.skipIf(not jt.in_mpi, "no inside mpirun") +class TestMpi(unittest.TestCase): + def test_mpi_test_op(self): + assert jt.compile_extern.mpi_ops.mpi_test("").data == 123 + + @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl") + @jt.flag_scope(use_cuda=1) + def test_nccl_with_mpi(self): + assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123 + + def test_mpi_broadcast(self): + for i in range(mpi.world_size()): + a = np.zeros(100) + mpi.world_rank() + mpi.broadcast(a, i) + assert (a == i).all() + + def test_mpi_dataset(self): + from jittor.dataset.dataset import Dataset + class ToyDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(batch_size=21, total_len=211) + + def __getitem__(self, index): + return index, index*index + + toy = ToyDataset() + offset = ((toy.batch_size-1) // mpi.world_size() + 1) * mpi.world_rank() + + for _ in range(2): + for i,(a,b) in enumerate(toy): + assert (a.data*a.data == b.data).all() + if mpi.world_rank() == 0: + if i == len(toy)-1: + assert a.shape[0] == 1 + c = np.array([210]) + else: + assert toy.real_batch_size == 11 + c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size)) + else: + if i == len(toy)-1: + assert a.shape[0] == 1 + c = np.array([210]) + else: + assert toy.real_batch_size == 10 + c = np.array(range(offset+i*toy.batch_size, offset+i*toy.batch_size + toy.real_batch_size)) + + assert (c==a.data).all(), (c, a.data) + +def run_mpi_test(num_procs, name): + if not jt.compile_extern.inside_mpi(): + mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun") + cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v" + print("run cmd:", cmd) + assert os.system(cmd)==0, "run cmd failed: "+cmd + +@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") +class TestMpiEntry(unittest.TestCase): + def test_entry(self): + run_mpi_test(2, "test_mpi") + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + def test_mpi_resnet_entry(self): + run_mpi_test(2, "test_resnet") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_mpi_batchnorm.py b/python/jittor/test/test_mpi_batchnorm.py new file mode 100644 index 00000000..b601f0e8 --- /dev/null +++ b/python/jittor/test/test_mpi_batchnorm.py @@ -0,0 +1,112 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os, sys +import jittor as jt +from jittor import init +from jittor import nn +import numpy as np +from jittor.test.test_mpi import run_mpi_test + +mpi = jt.compile_extern.mpi +if mpi: + n = mpi.world_size() + +class FakeMpiBatchNorm(nn.Module): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True): + assert affine == None + + self.num_features = num_features + self.is_train = is_train + self.eps = eps + self.momentum = momentum + self.weight = init.constant((num_features,), "float32", 1.0) + self.bias = init.constant((num_features,), "float32", 0.0) + self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad() + self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad() + + def execute(self, x, global_x): + if self.is_train: + xmean = jt.mean(global_x, dims=[0,2,3], keepdims=1) + x2mean = jt.mean(global_x*global_x, dims=[0,2,3], keepdims=1) + + xvar = x2mean-xmean*xmean + norm_x = (x-xmean)/jt.sqrt(xvar+self.eps) + self.running_mean.update(self.running_mean + + (xmean.sum([0,2,3])-self.running_mean)*self.momentum) + self.running_var.update(self.running_var + + (xvar.sum([0,2,3])-self.running_var)*self.momentum) + else: + running_mean = self.running_mean.broadcast(x, [0,2,3]) + running_var = self.running_var.broadcast(x, [0,2,3]) + norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps) + w = self.weight.broadcast(x, [0,2,3]) + b = self.bias.broadcast(x, [0,2,3]) + return norm_x * w + b + +@unittest.skipIf(not jt.in_mpi, "no inside mpirun") +class TestMpiBatchnorm(unittest.TestCase): + @classmethod + def setUpClass(self): + np.random.seed(0) + jt.seed(3) + + def test_batchnorm(self): + mpi = jt.compile_extern.mpi + data = np.random.rand(30,3,10,10).astype("float32") + x1 = jt.array(data) + stride = 30//n + x2 = jt.array(data[mpi.world_rank()*stride:(mpi.world_rank()+1)*stride,...]) + + bn1 = nn.BatchNorm(3, sync=False) + bn2 = nn.BatchNorm(3, sync=True) + bn3 = FakeMpiBatchNorm(3) + y1 = bn1(x1).data + y2 = bn2(x2).data + y3 = bn3(x2,x1).data + + assert np.allclose(y2, y3, atol=1e-4), (y2, y3) + assert np.allclose(bn1.running_mean.data, bn2.running_mean.data), \ + (bn1.running_mean.data, bn2.running_mean.data) + assert np.allclose(bn1.running_var.data, bn2.running_var.data) + + def test_batchnorm_backward(self): + mpi = jt.compile_extern.mpi + data = np.random.rand(30,3,10,10).astype("float32") + global_x = jt.array(data) + stride = 30//n + x = jt.array(data[mpi.world_rank()*stride:(mpi.world_rank()+1)*stride,...]) + + bn1 = nn.BatchNorm(3, sync=True) + bn2 = FakeMpiBatchNorm(3) + y1 = bn1(x) + y2 = bn2(x,global_x) + gs1 = jt.grad(y1,bn1.parameters()) + gs2 = jt.grad(y2,bn2.parameters()) + + assert np.allclose(y1.data, y2.data, atol=1e-5),(mpi.world_rank(),y1.data, y2.data, y1.data-y2.data) + assert len(gs1) == len(gs2) + for i in range(len(gs1)): + assert np.allclose(gs1[i].data, gs2[i].data, rtol=1e-2),(mpi.world_rank(),gs1[i].data, gs2[i].data,gs1[i].data-gs2[i].data) + + @unittest.skipIf(not jt.has_cuda, "no cuda") + @jt.flag_scope(use_cuda=1) + def test_batchnorm_cuda(self): + self.test_batchnorm() + self.test_batchnorm_backward() + + +@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") +class TestMpiBatchnormEntry(unittest.TestCase): + def test(self): + run_mpi_test(2, "test_mpi_batchnorm") + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_mpi_in_py.py b/python/jittor/test/test_mpi_in_py.py new file mode 100644 index 00000000..cae3acf4 --- /dev/null +++ b/python/jittor/test/test_mpi_in_py.py @@ -0,0 +1,66 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os, sys +import jittor as jt +import numpy as np +from jittor import nn +from jittor import dataset +mpi = jt.compile_extern.mpi + + +class Model(nn.Module): + def __init__(self, input_size): + self.linear1 = nn.Linear(input_size, 10) + self.relu1 = nn.ReLU() + self.linear2 = nn.Linear(10, 10) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + +def fork_with_mpi(num_procs=4): + import sys + if jt.in_mpi: + # you can mult other process output + if jt.rank != 0: + sys.stdout = open("/dev/null", "w") + return + else: + print(sys.argv) + cmd = " ".join(["mpirun", "-np", str(num_procs), sys.executable] + sys.argv) + print("[RUN CMD]:", cmd) + os.system(cmd) + exit(0) + +def main(): + mnist = dataset.MNIST() + model = Model(mnist[0][0].size) + sgd = jt.optim.SGD(model.parameters(), 1e-3) + fork_with_mpi() + + for data, label in mnist: + pred = model(data.reshape(data.shape[0], -1)) + # print(data.shape, label.shape, pred.shape) + loss = nn.cross_entropy_loss(pred, label) + sgd.step(loss) + print(jt.rank, mnist.epoch_id, mnist.batch_id, loss) + # break + + + +# class TestMpiInPy(unittest.TestCase): +# def test(self): +# main() + + +if __name__ == "__main__": + # unittest.main() + main() \ No newline at end of file diff --git a/python/jittor/test/test_mpi_op.py b/python/jittor/test/test_mpi_op.py new file mode 100644 index 00000000..81ebff4a --- /dev/null +++ b/python/jittor/test/test_mpi_op.py @@ -0,0 +1,74 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os, sys +import jittor as jt +import numpy as np +from jittor.test.test_mpi import run_mpi_test + +mpi = jt.compile_extern.mpi +if mpi: + n = mpi.world_size() + +@unittest.skipIf(not jt.in_mpi, "no inside mpirun") +class TestMpiOps(unittest.TestCase): + @classmethod + def setUpClass(self): + np.random.seed(0) + jt.seed(3) + + def test_all_reduce(self): + x = jt.random([5, 5]) + y = x.mpi_all_reduce() + np.testing.assert_allclose(y.data, (x*n).data) + g = jt.grad(y,x) + np.testing.assert_allclose(g.data, np.ones([5,5])*n) + + def test_all_reduce_mean(self): + x = jt.random([5, 5]) + y = x.mpi_all_reduce("mean") + np.testing.assert_allclose(y.data, x.data) + g = jt.grad(y,x) + np.testing.assert_allclose(g.data, np.ones([5,5])) + + def test_broadcast(self): + data = jt.random([5, 5]) + if mpi.world_rank() == 0: + x = data + else: + x = jt.zeros([5, 5]) + y = x.mpi_broadcast(0) + np.testing.assert_allclose(y.data, data.data) + g = jt.grad(y,x) + if mpi.world_rank() == 0: + np.testing.assert_allclose(g.data, np.ones([5,5])*n) + else: + np.testing.assert_allclose(g.data, np.zeros([5,5])) + + def test_reduce(self): + x = jt.random([5, 5]) + y = x.mpi_reduce(root=0) + y.sync() + if mpi.world_rank() == 0: + np.testing.assert_allclose(y.data, (x*n).data) + else: + np.testing.assert_allclose(y.data, np.zeros([5,5])) + g = jt.grad(y,x) + print(mpi.world_rank(), g) + np.testing.assert_allclose(g.data, np.ones([5,5])) + + +@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") +class TestMpiOpsEntry(unittest.TestCase): + def test(self): + run_mpi_test(2, "test_mpi_op") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_nano_string.py b/python/jittor/test/test_nano_string.py new file mode 100644 index 00000000..1c72bcc1 --- /dev/null +++ b/python/jittor/test/test_nano_string.py @@ -0,0 +1,75 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import time +from .test_core import expect_error +import os + +mid = 0 +if hasattr(os, "uname") and "jittor" in os.uname()[1]: + mid = 1 + +class TestNanoString(unittest.TestCase): + def test(self): + dtype = jt.NanoString + t = time.time() + n = 1000000 + for i in range(n): + dtype("float") + t = (time.time() - t)/n + # t is about 0.01 for 100w loop + # 92ns one loop + print("nanostring time", t) + assert t < [1.5e-7, 1.9e-7][mid], t + + assert (jt.hash("asdasd") == 4152566416) + assert str(jt.NanoString("float"))=="float32" + assert jt.NanoString("float")=="float32" + # py_bind11: 7 + # Tuple call: 1.3 + # fast call (with or with not): 0.9 + # init call 1.5 + # int init: 1.2 + # dtype init(cache): 0.75 + # final: 1.0 + + def test_type(self): + import numpy as np + assert str(jt.NanoString(float)) == "float32" + assert str(jt.NanoString(np.float)) == "float32" + assert str(jt.NanoString(np.float32)) == "float32" + assert str(jt.NanoString(np.float64)) == "float64" + assert str(jt.NanoString(np.int8)) == "int8" + assert str(jt.NanoString(np.array([1,2,3]).dtype)) == "int64" + + assert str(jt.NanoString(jt.float)) == "float32" + assert str(jt.NanoString(jt.float32)) == "float32" + assert str(jt.NanoString(jt.float64)) == "float64" + assert str(jt.NanoString(jt.int8)) == "int8" + assert str(jt.NanoString(jt.array([1,2,3]).dtype)) == "int32" + assert str(jt.NanoString(jt.sum)) == "add" + + def get_error_str(call): + es = "" + try: + call() + except Exception as e: + es = str(e) + return es + + e = get_error_str(lambda: jt.code([1,], {}, [1], cpu_header="")) + assert "help(jt.ops.code)" in e + assert "cpu_header=str" in e + e = get_error_str(lambda: jt.NanoString([1,2,3], fuck=1)) + assert "fuck=int" in str(e) + assert "(list, )" in str(e) + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_nano_vector.py b/python/jittor/test/test_nano_vector.py new file mode 100644 index 00000000..783396f8 --- /dev/null +++ b/python/jittor/test/test_nano_vector.py @@ -0,0 +1,49 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import time +from .test_core import expect_error + +class TestNanoVector(unittest.TestCase): + def test(self): + nvector = jt.NanoVector + nv = nvector() + nv.append(1) + nv.append(2) + nv.append(3) + nv.append(1<<40) + assert nv[3] == (1<<40) + assert str(nv) == "[1,2,3,1099511627776,]" + assert nv == [1,2,3,1099511627776,] + expect_error(lambda : nv.append(1<<40)) + assert len(nv)==4, nv + s = 0 + for a in nv: + s += a + assert (s==1+2+3+(1<<40)) + s = max(nv) + assert s == (1<<40) + a, b, c, d = nv + assert [a,b,c,d] == nv + assert nv[-1] == (1<<40) + assert nv[:2] == [1,2] + assert nv[:-2] == [1,2] + assert nv[::-1] == list(nv)[::-1], (list(nv)[::-1], nv[::-1]) + assert (nvector([1,2]) + nvector([3,4])) == [1,2,3,4] + a = nvector([1,2]) + a += [3,4] + assert a == [1,2,3,4], a + + def test_slice_bug(self): + a = jt.NanoVector([2,3,4,5]) + assert a[:] == [2,3,4,5] + assert a[1:] == [3,4,5] + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_nccl.py b/python/jittor/test/test_nccl.py new file mode 100644 index 00000000..04f02a5e --- /dev/null +++ b/python/jittor/test/test_nccl.py @@ -0,0 +1,19 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import unittest + +@unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl found") +class TestNccl(unittest.TestCase): + @jt.flag_scope(use_cuda=1) + def test_nccl(self): + assert jt.compile_extern.nccl_ops.nccl_test("").data == 123 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_nccl_ops.py b/python/jittor/test/test_nccl_ops.py new file mode 100644 index 00000000..dc7cc45d --- /dev/null +++ b/python/jittor/test/test_nccl_ops.py @@ -0,0 +1,144 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os, sys +import jittor as jt +import numpy as np +from jittor import nn +from jittor import nn, Module +import copy +from jittor.test.test_log import find_log_with_re +from jittor.test.test_mpi import run_mpi_test +from jittor.compile_extern import mpi, nccl_ops +n = 2 + +@unittest.skipIf(nccl_ops is None, "nccl not found") +class TestNcclOps(unittest.TestCase): + @classmethod + def setUpClass(self): + np.random.seed(0) + jt.seed(3) + + @jt.flag_scope(use_cuda=1) + def test_all_reduce(self): + with jt.log_capture_scope(enable_tuner=1, log_silent=1, + log_v=1, log_vprefix="op.cc=100,exe=1000" + ) as raw_log: + x = jt.random([5, 5]) + y = x.mpi_all_reduce() + assert np.allclose(y.data, (x*n).data) + g = jt.grad(y,x) + assert np.allclose(g.data, np.ones([5,5])*n) + + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)") + assert len(logs)==2, len(logs) + + @jt.flag_scope(use_cuda=1) + def test_broadcast(self): + with jt.log_capture_scope(enable_tuner=1, log_silent=1, + log_v=1, log_vprefix="op.cc=100,exe=1000" + ) as raw_log: + data = jt.random([5, 5]) + if mpi.world_rank() == 0: + x = data + else: + x = jt.zeros([5, 5]) + y = x.mpi_broadcast(0) + assert np.allclose(y.data, data.data) + g = jt.grad(y.sum(),x) + g_ = g.data + if mpi.world_rank() == 0: + assert np.allclose(g_, np.ones([5,5])*n) + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_broadcast.*)") + assert len(logs)==1, len(logs) + + @jt.flag_scope(use_cuda=1) + def test_reduce(self): + with jt.log_capture_scope(enable_tuner=1, log_silent=1, + log_v=1, log_vprefix="op.cc=100,exe=1000" + ) as raw_log: + x = jt.random([5, 5]) + y = x.mpi_reduce(root=0) + y_ = y.data + x_ = (x*n).data + if mpi.world_rank() == 0: + assert np.allclose(y_, x_) + g = jt.grad(y,x) + assert np.allclose(g.data, np.ones([5,5])) + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)") + assert len(logs)==1, len(logs) + + @jt.flag_scope(use_cuda=1) + def test_sync(self): + + class Model(Module): + def __init__(self): + self.linear1 = nn.Linear(3, 3) + self.linear2 = nn.Linear(3, 1024, False) + + def execute(self, x): + x = self.linear1(x) + x = nn.relu(x) + return self.linear2(x) + + net = Model() + if mpi.world_rank() == 0: + net.linear1.weight *= 0 + net.linear2.weight *= 0 + net.linear1.bias *= 0 + net.linear1.weight += 1 + net.linear2.weight += 1 + net.linear1.bias += 1 + net.mpi_param_broadcast() + assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data) + assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data) + assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data) + + @jt.flag_scope(use_cuda=1) + def test_optimizer(self): + + class Model2(Module): + def __init__(self, input_size): + self.linear1 = nn.Linear(input_size, 10) + self.relu1 = nn.Relu() + self.linear2 = nn.Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + + def get_data(n): + for i in range(n): + x = np.random.rand(50, 1) + y = x*x + yield jt.float32(x), jt.float32(y) + + num = 2000 + model = Model2(1) + model.mpi_param_broadcast() + optimizer = nn.SGD(model.parameters(), 0.1) + dataset = list(enumerate(get_data(num))) + for i in range(mpi.world_rank(), num, n): + id, (x, y) = dataset[i] + pred_y = model(x) + loss = (pred_y - y)**2 + loss_mean = loss.mean() + optimizer.step(loss_mean) + assert loss_mean.data < 0.0025, loss_mean.data + jt.clean() + +@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") +class TestNcclOpsEntry(unittest.TestCase): + def test(self): + run_mpi_test(2, "test_nccl_ops") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_new_fused_op.py b/python/jittor/test/test_new_fused_op.py new file mode 100644 index 00000000..c625fd34 --- /dev/null +++ b/python/jittor/test/test_new_fused_op.py @@ -0,0 +1,44 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import sys +import os +import jittor as jt +import unittest +import time +import numpy as np +from .test_log import find_log_with_re + +class TestNewFuse(unittest.TestCase): + @classmethod + def setUpClass(self): + return + + def check(self, h, w, cs, rs, pa, rtp, dim): + a = jt.random([h,w]) + a.sync() + + with jt.log_capture_scope( + log_v=0, log_vprefix="tuner_manager=100", + # this value is used for force compile + compile_options={"test_new_fused_op":1} + ) as logs: + amean=jt.mean(a, dims=[dim], keepdims=1) + a2mean=jt.mean(a*a, dims=[dim], keepdims=1) + norm_aa=(a-amean.broadcast_var(a))/(jt.sqrt(a2mean-amean*amean).broadcast_var(a)) + norm_aa.sync() + logs = find_log_with_re(logs, + "Run tuner reduce: confidence\\((.*)\\) candidates\\((.*)\\)$") + assert len(logs) == 3, logs + + def test_new_fuse(self): + self.check(8192,8192, 0, 0, 0, 5, 0) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_node.py b/python/jittor/test/test_node.py new file mode 100644 index 00000000..fd392203 --- /dev/null +++ b/python/jittor/test/test_node.py @@ -0,0 +1,159 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from jittor_utils import LOG +import time, os + +def check(hv, lv, lo): + import gc + gc.collect() + jt.graph_check() + a, b, c = jt.number_of_hold_vars(), jt.number_of_lived_vars(), jt.number_of_lived_ops() + assert (a,b,c)==(hv,lv,lo), (a, b, c, jt.dump_all_graphs().nodes_info) + +def get_xorshf96(seed=0): + '''Marsaglia's xorshf generator''' + a = [ + np.uint64(123456789+seed), + np.uint64(362436069+seed), + np.uint64(521288629+seed), + ] + def xorshf96(): + a[0] ^= a[0] << np.uint64(16) + a[0] ^= a[0] >> np.uint64(5) + a[0] ^= a[0] << np.uint64(1) + t = a[0] + a[0] = a[1] + a[1] = a[2] + a[2] = t ^ a[0] ^ a[1] + return int(a[2]) + # for _ in range(10): xorshf96() + return xorshf96 + +class TestNode(unittest.TestCase): + def test_lived(self): + jt.clean() + check(0,0,0) + a = jt.array(1.0).stop_fuse() + a.name('a') + b = jt.array(1.0).stop_fuse() + b.name('b') + check(2,2,2) + c = a * b + c.name('c') + check(3,3,3) + vc = c.numpy() + check(3,3,1) + da, db = jt.grad(c, [a, b]) + da.name('da') + db.name('db') + check(5,6,4) # dc, 3, da, 1, db, 1 + del a, b, c + check(2,5,3) + da.sync(), db.sync() + check(2,2,0) + del da, db + check(0,0,0) + + def test_pending(self): + a = jt.float([1,2,3]) + b = jt.float([1,2,3]) + c = a.float().float().float() * b.float().float().float() + del a + c.data + assert (c.data==[1,4,9]).all(), c.data + d, = jt.grad(c, [b]) + d.data + assert (d.data==[1,2,3]).all(), d.data + + def test_node_performance(self): + mode = os.environ.get("test_node_performance") + if mode==None or mode not in "12": + return + if mode=="1": + bc = lambda x: jt.broadcast(x, [1,1,1,1],[0,1,2]) + rd = lambda x: jt.sum(x) + else: + bc = lambda x: jt.reindex(x, [1,1,1,1],["i0+i1+i2+i3"]) + rd = lambda x: jt.reindex_reduce(x, "add", [1], ["i0+i1+i2+i3"]) + if jt.compiler.is_debug: return + def run(): + start_time = time.time() + fop_num = 10000 + fop_input_num = (2, 3) # (i,j) -> [i,i+j] -> [2, 5] + # fop_output_num = (1, 0) # [1,1] + inner_op_num = (0, 3) + fop_type_num = 63 # how many different fuse op + input_queue_num = 15 + queue = [1.0]*(input_queue_num+1) + x = get_xorshf96() + rand = lambda x, l, r: l+((x())&r) + ops = ["add", "subtract", "multiply", "divide"] + get_op = lambda x: ops[(x())&3] + for i in range(fop_num): + prev = bc(queue[rand(x,0,input_queue_num)]) + y = get_xorshf96(x()&fop_type_num) + inum = rand(y, *fop_input_num) + q = [prev] + for i in range(inum-1): + n = bc(queue[rand(x,0,input_queue_num)]) + prev = jt.binary(prev, n, get_op(y)) + q.append(prev) + innum = rand(y,*inner_op_num) + for _ in range(innum): + j = rand(y,0,len(q)-1) + n = q[j] + prev = jt.binary(prev, n, get_op(y)) + q[j] = prev + prev = rd(prev) + queue[rand(x,0,input_queue_num)] = prev + a = jt.array(0.0) + for x in queue: + a += x + LOG.i("build graph", time.time()-start_time, jt.liveness_info().values()) + start_time = time.time() + a.sync() + LOG.i("execute", time.time()-start_time) + # debug mode: build(0.68), execute(0.44) + # normal mode: build(0.56), execute(0.25) + # cast opt: build(0.50), execute(0.25) + # dtype opt: build(0.49), execute(0.25) + # pyjt opt: build(0.48), execute(0.25) + # ns opt: build(0.46), execute(0.24) + # nv opt: build(0.42), execute(0.23) + # nv opt: build(0.415),execute(0.225) + # jit_key opt: build(0.415),execute(0.15) + # jit_key opt: build(0.415),execute(0.11) + # sv opt: build(0.42), execute(0.12) + # noded opt: build(0.42), execute(0.10) + + # tcm opt: build(0.40), execute(0.10) + + # mode2: reindex + # jit_key opt: build(0.46),execute(0.12) + # noded opt: build(0.44),execute(0.11) + # for i in range(20): + # run() + + # version 1.3.2.6 retest(laptop) + # mode1: + # origin 0.296 exec(0.11) + # int32flag 0.298 exec(0.11) + # add order 0.299 exec(0.11) + # rm p1 rule 0.299 exec(0.11) + for i in range(20): + run() + import gc + gc.collect() + run() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_notebooks.py b/python/jittor/test/test_notebooks.py new file mode 100644 index 00000000..a90cca2d --- /dev/null +++ b/python/jittor/test/test_notebooks.py @@ -0,0 +1,54 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest, os +import jittor as jt +from jittor import LOG +import sys +import jittor_utils as jit_utils + +dirname = os.path.join(jt.flags.jittor_path, "notebook") +notebook_dir = os.path.join(jit_utils.home(), ".cache","jittor","notebook") +tests = [] +for mdname in os.listdir(dirname): + if not mdname.endswith(".src.md"): continue + # temporary disable model_test + if "GAN" in mdname: continue + tests.append(mdname[:-3]) + +try: + jt.compiler.run_cmd("ipython --help") + has_ipython = True +except: + has_ipython = False + +def test(name): + LOG.i(f"Run test {name} from {dirname}") + ipynb_name = os.path.join(notebook_dir, name+".ipynb") + jt.compiler.run_cmd("ipython "+ipynb_name) + +def init(): + cmd = sys.executable+" "+os.path.join(dirname, "md_to_ipynb.py") + LOG.i("init notebooks:", cmd) + jt.compiler.run_cmd(cmd) + +src = """class TestNodebooks(unittest.TestCase): + @classmethod + def setUpClass(self): + init() +""" +for name in tests: + src += f""" + @unittest.skipIf(not has_ipython, "No IPython found") + def test_{name.replace(".src","")}(self): + test("{name}") + """ + +LOG.vvv("eval src\n"+src) +exec(src) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_numpy_code_op.py b/python/jittor/test/test_numpy_code_op.py new file mode 100644 index 00000000..41ea1b35 --- /dev/null +++ b/python/jittor/test/test_numpy_code_op.py @@ -0,0 +1,196 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +from jittor import Function +import jittor as jt +import numpy +import ctypes +import sys + +try: + import cupy +except: + pass + +class TestCodeOp(unittest.TestCase): + def test_func(self): + class Func(Function): + def forward_code(self, np, data): + a = data["inputs"][0] + b = data["outputs"][0] + if (jt.flags.use_cuda==0): + assert isinstance(a,numpy.ndarray) + else: + assert isinstance(a,cupy.ndarray) + np.add(a,a,out=b) + + def backward_code(self, np, data): + a, dout = data["inputs"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + + def execute(self, a): + self.save_vars = a + return jt.numpy_code( + a.shape, + a.dtype, + [a], + self.forward_code, + ) + + def grad(self, grad_a): + a = self.save_vars + return jt.numpy_code( + a.shape, + a.dtype, + [a, grad_a], + self.backward_code, + ) + + def check(): + a = jt.random((5,1)) + func = Func() + b = func(a) + assert numpy.allclose(b.data,(a+a).data) + da = jt.grad(b,a) + one=numpy.ones(a.shape) + assert numpy.allclose(da.data,one*2.0) + + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + check() + check() + + def test(self): + def forward_code(np, data): + a = data["inputs"][0] + b = data["outputs"][0] + if (jt.flags.use_cuda==0): + assert isinstance(a,numpy.ndarray) + else: + assert isinstance(a,cupy.ndarray) + np.add(a,a,out=b) + + def backward_code(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout*2.0) + + def check(): + a = jt.random((5,1)) + b = jt.numpy_code( + a.shape, + a.dtype, + [a], + forward_code, + [backward_code], + ) + assert numpy.allclose(b.data,(a+a).data) + da = jt.grad(b,a) + one=numpy.ones(a.shape) + assert numpy.allclose(da.data,one*2.0) + + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + check() + check() + + def test_multi_input(self): + def forward_code(np, data): + a,b = data["inputs"] + c,d = data["outputs"] + np.add(a,b,out=c) + np.subtract(a,b,out=d) + + def backward_code1(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout) + + def backward_code2(np, data): + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + if out_index==0: + np.copyto(out, dout) + else: + np.negative(dout, out) + + def check(): + a = jt.random((5,1)) + b = jt.random((5,1)) + c, d = jt.numpy_code( + [a.shape, a.shape], + [a.dtype, a.dtype], + [a, b], + forward_code, + [backward_code1,backward_code2], + ) + assert numpy.allclose(c.data,(a+b).data) + assert numpy.allclose(d.data,(a-b).data) + dca, dcb = jt.grad(c,[a,b]) + dda, ddb = jt.grad(d,[a,b]) + one=numpy.ones(a.shape) + mone=one*-1.0 + assert numpy.allclose(dca.data,one) + assert numpy.allclose(dcb.data,one) + assert numpy.allclose(dda.data,one) + assert numpy.allclose(ddb.data,mone) + + if jt.has_cuda: + with jt.flag_scope(use_cuda=1): + check() + check() + + @unittest.skipIf(True, "Memory leak testing is not in progress, Skip") + def test_memory_leak(self): + def forward_code(np, data): + a,b = data["inputs"] + c,d = data["outputs"] + np.add(a,b,out=c) + np.subtract(a,b,out=d) + + def backward_code1(np, data): + dout = data["dout"] + out = data["outputs"][0] + np.copyto(out, dout) + + def backward_code2(np, data): + dout = data["dout"] + out_index = data["out_index"] + out = data["outputs"][0] + if out_index==0: + np.copyto(out, dout) + else: + np.negative(dout, out) + + for i in range(1000000): + a = jt.random((10000,1)) + b = jt.random((10000,1)) + c, d = jt.numpy_code( + [a.shape, a.shape], + [a.dtype, a.dtype], + [a, b], + forward_code, + [backward_code1,backward_code2], + ) + assert numpy.allclose(c.data,(a+b).data) + assert numpy.allclose(d.data,(a-b).data) + dca, dcb = jt.grad(c,[a,b]) + dda, ddb = jt.grad(d,[a,b]) + one=numpy.ones(a.shape) + mone=one*-1.0 + assert numpy.allclose(dca.data,one) + assert numpy.allclose(dcb.data,one) + assert numpy.allclose(dda.data,one) + assert numpy.allclose(ddb.data,mone) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_op_compiler.py b/python/jittor/test/test_op_compiler.py new file mode 100644 index 00000000..50001497 --- /dev/null +++ b/python/jittor/test/test_op_compiler.py @@ -0,0 +1,161 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +from jittor import LOG +import numpy as np +from .test_core import expect_error + +jit_eval = jt.core.op_compiler.eval +jit_precompile = jt.core.op_compiler.precompile + +class TestOpCompiler(unittest.TestCase): + def test_eval(self): + def check(expr, vars={}): + for k,v in vars.items(): + locals()[k] = int(v) + _v1 = None + _v2 = None + try: + _v1 = jit_eval(expr, vars) + except: + pass + try: + _v2 = eval(expr) + except: + pass + LOG.vv(f"check {expr} = {_v1}, {_v2}, {_v1 == _v2}") + assert _v1 == _v2 + check("10+2*6") + check("100 * 2 + 12") + check("100*2+12") + check("100 * ( 2 + 12 )") + check("100*(2+12)") + check("100 * ( 2 + 12 ) / 14") + check("100*(2+12)/14") + check("-1") + check("- 1") + vars = {"a":"123", "b":"2"} + check("a", vars) + check("a+b", vars) + # python divide is different with c++ + # check("a/b", vars) + check("-1 +a *b", vars) + check("*****", vars) + + def test_precompile_ifdef(self): + vars = {"JIT_a":"1"} + check = lambda expr, result: \ + self.assertEqual(jit_precompile(vars, expr), result) + check("#ifdef JIT_a\nxxx\n#endif", "xxx\n") + check("#ifdef JIT_a\nxxx\n#else\nyyy\n #endif", "xxx\n") + check("#ifndef JIT_a\nxxx\n#else\nyyy\n #endif", "yyy\n ") + check("#ifdef JIT_b\nxxx\n#else\nyyy\n #endif", "yyy\n ") + check("#ifdef b\nxxx\n#else\nyyy\n #endif", + "#ifdef b\nxxx\n#else\nyyy\n #endif") + for va in [0,1]: + for vb in [0,1]: + vars["JIT_a"] = "1" + vars["JIT_b"] = "1" + if not va: del vars["JIT_a"] + if not vb: del vars["JIT_b"] + check(( + "#ifdef JIT_a\n" + "#ifdef JIT_b\n" + "0\n" + "#else\n" + "1\n" + "#endif\n" + "#else\n" + "#ifdef JIT_b\n" + "2\n" + "#else\n" + "3\n" + "#endif\n" + "#endif\n" + ), f"{3 - (va*2+vb)}\n") + + def test_precompile(self): + vars = {"a":"2", "b":"5", "a1":"1", "a2":"2", "OP":"mean"} + check = lambda expr, result: \ + self.assertEqual(jit_precompile(vars, expr), result) + check("@", "@") + check("@a", "2") + # check("//@a\n@a", "//@a\n2") + check("//@a\n@a", "\n2") + # check("@a//@a", "2//@a") + check("@a//@a", "2") + check("@{-a +b* 2}", "8") + # check("@{-a +b* 2}/*@{-a +b* 2}*/", "8/*@{-a +b* 2}*/") + check("@{-a +b* 2}/*@{-a +b* 2}*/", "8") + check("@for(i,a,b,+@i)", "+2+3+4") + check("@for(i, a+1, b*2-3, -@{i*2})", " -6 -8 -10 -12") + check("@for(i, b, a,-1,@i)", "543") + check("@for(i, b, a,-1,@for(j,0,i,@i@j))", "505152535440414243303132") + check("@{a@{a-1}+10}", "11") + check("@{a@a}", "2") + check("@if(0,1,0)", "0") + check("@if(1,1,0)", "1") + check("@if(0,1)", "") + check("@if(1,1)", "1") + check("@for(i,0,8,@if(i%2,+@i))", "+1+3+5+7") + check("@{1<1}", "0") + check("@{!1}", "0") + check("@{!!1}", "1") + check("@{!!1<<2}", "4") + check("@{a1\n@Tx\n#else\n@Tx@@1\n#endif", "#if aa>1\nfloat\n#else\nfloat1\n#endif") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_opt_state_dict.py b/python/jittor/test/test_opt_state_dict.py new file mode 100644 index 00000000..fe8ad6c1 --- /dev/null +++ b/python/jittor/test/test_opt_state_dict.py @@ -0,0 +1,47 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import torch +import torch.nn as tnn + +class Net(tnn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = tnn.Conv2d(3, 6, 5) + self.pool = tnn.MaxPool2d(2, 2) + self.conv2 = tnn.Conv2d(6, 16, 5) + self.fc1 = tnn.Linear(16 * 5 * 5, 120) + self.fc2 = tnn.Linear(120, 84) + self.fc3 = tnn.Linear(84, 10) + + def forward(self, x): + x = self.pool((self.conv1(x))) + x = self.pool((self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = (self.fc1(x)) + x = (self.fc2(x)) + x = self.fc3(x) + return x + + +class TestOptStateDict(unittest.TestCase): + def test_opt_state_dict(self): + return + net = Net() + optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + # print(optimizer.state_dict()) + img = torch.rand((2,3,40,40)) + pred = net(img) + optim.zero_grad() + pred.sum().backward() + optim.step() + # print(optimizer.state_dict()) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_optimizer.py b/python/jittor/test/test_optimizer.py new file mode 100644 index 00000000..0eae049e --- /dev/null +++ b/python/jittor/test/test_optimizer.py @@ -0,0 +1,55 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import nn + +class TestOptimizer(unittest.TestCase): + def test_param_groups(self): + pa = jt.ones((1,)) + pb = jt.ones((1,)) + data = jt.ones((1,)) + opt = nn.SGD([ + {"params":[pa], "lr":0.1}, + {"params":[pb]}, + ], 1) + opt.step(pa*data+pb*data) + assert pa.data == 0.9 and pb.data == 0, (pa, pb) + + def test_clip_grad_norm(self): + a = jt.ones(2) + opt = jt.optim.SGD([a], 0.1) + + loss = a*a + opt.zero_grad() + opt.backward(loss) + opt.clip_grad_norm(0.01, 2) + assert np.allclose(opt.param_groups[0]['grads'][0].norm(), 0.01) + opt.step() + + def test_state_dict(self): + a = jt.ones(2) + opt = jt.optim.SGD([a], 0.1) + s = opt.state_dict() + # print(s) + opt.load_state_dict(s) + + def test_opt_grad(self): + a = jt.ones(2) + opt = jt.optim.SGD([a], 0.1) + opt.backward(a**2) + g = a.opt_grad(opt) + np.testing.assert_allclose(g.data, 2) + + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_optimizer_save_load.py b/python/jittor/test/test_optimizer_save_load.py new file mode 100644 index 00000000..05b7813d --- /dev/null +++ b/python/jittor/test/test_optimizer_save_load.py @@ -0,0 +1,66 @@ +import unittest +import jittor as jt +import numpy as np +from jittor import nn +import os + + +def compare(x, y, shape=4): + assert((x == y).sum() == shape) + + +def test_optim(optimzer_type, **kwargs): + # input + + x = jt.rand(20, 2, 2) + y1 = [] + y2 = [] + + # model & optimizer 1 for save + linear1 = nn.Linear(2, 2) + opt = optimzer_type(linear1.parameters(), **kwargs) + for i in range(10): + y = linear1(x[i]) + y1.append(y) + opt.step(y) + opt_dict = opt.state_dict() + linear_dict = linear1.state_dict() + jt.save({'opt': opt_dict, 'linear': linear_dict}, "./optim_test.tar") + for i in range(10, 20, 1): + y = linear1(x[i]) + y1.append(y) + opt.step(y) + + # model & optimizer 2 for load + linear2 = nn.Linear(2, 2) + opt2 = optimzer_type(linear2.parameters(), **kwargs) + opt2_dict = jt.load("./optim_test.tar") + opt2.load_state_dict(opt2_dict['opt']) + linear2.load_state_dict(opt2_dict['linear']) + for i in range(10, 20, 1): + y = linear2(x[i]) + y2.append(y) + opt2.step(y) + + for i in range(10): + compare(y1[10+i], y2[i]) + + +class TestOptimizerSaveLoad(unittest.TestCase): + def test(self): + optims = [ + {'opt': jt.nn.SGD, 'kwargs': {'lr': 0.1, 'momentum': 1e-2, + 'weight_decay': 1e-2, 'dampening': 1e-3, 'nesterov': True}}, + {'opt': jt.nn.RMSprop, 'kwargs': {'lr': 0.1}}, + {'opt': jt.nn.Adam, 'kwargs': {'lr': 0.1, 'weight_decay': 1e-2}}, + {'opt': jt.nn.AdamW, 'kwargs': {'lr': 0.1, 'weight_decay': 1e-2}}, + ] + for optim in optims: + test_optim(optim['opt'], **optim['kwargs']) + + def tearDown(self): + os.remove("./optim_test.tar") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/jittor/test/test_pad.py b/python/jittor/test/test_pad.py new file mode 100644 index 00000000..eb82f51d --- /dev/null +++ b/python/jittor/test/test_pad.py @@ -0,0 +1,87 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +def check_equal(arr, j_layer, p_layer): + jittor_arr = jt.array(arr) + pytorch_arr = torch.Tensor(arr) + jittor_result = j_layer(jittor_arr) + pytorch_result = p_layer(pytorch_arr) + assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy()) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestPad(unittest.TestCase): + def test_pad(self): + # *************************************************************** + # Test ReplicationPad2d Layer + # *************************************************************** + arr = np.random.randn(16,3,224,224) + check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10)) + check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5))) + check_equal(arr, jnn.ReplicationPad2d((1,0,1,5)), tnn.ReplicationPad2d((1,0,1,5))) + check_equal(arr, jnn.ReplicationPad2d((100)), tnn.ReplicationPad2d((100))) + + # *************************************************************** + # Test ConstantPad2d Layer + # *************************************************************** + arr = np.random.randn(16,3,224,224) + check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2)) + check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2)) + + arr = np.random.randn(16,3,224,10,10) + check_equal(arr, jnn.ConstantPad2d(10,-2), tnn.ConstantPad2d(10,-2)) + check_equal(arr, jnn.ConstantPad2d((2,3,34,1),10.2), tnn.ConstantPad2d((2,3,34,1),10.2)) + + # *************************************************************** + # Test ZeroPad2d Layer + # *************************************************************** + arr = np.random.randn(16,3,224,224) + check_equal(arr, jnn.ZeroPad2d(1), tnn.ZeroPad2d(1)) + check_equal(arr, jnn.ZeroPad2d((2,3,34,1)), tnn.ZeroPad2d((2,3,34,1))) + + # *************************************************************** + # Test ReflectionPad2d Layer + # *************************************************************** + arr = np.random.randn(16,3,224,224) + check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20)) + check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1))) + check_equal(arr, jnn.ReflectionPad2d((10,123,34,1)), tnn.ReflectionPad2d((10,123,34,1))) + check_equal(arr, jnn.ReflectionPad2d((100)), tnn.ReflectionPad2d((100))) + + # *************************************************************** + # Test function pad + # *************************************************************** + arr = np.random.randn(16,3,224,224) + padding = (10,11,2,3) + for mode in ['constant','replicate','reflect','circular']: + j_data = jt.array(arr) + t_data = torch.tensor(arr) + t_output = tnn.functional.pad(t_data,padding,mode=mode).detach().numpy() + j_output = jnn.pad(j_data,padding,mode).numpy() + assert np.allclose(t_output,j_output) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_parallel_pass.py b/python/jittor/test/test_parallel_pass.py new file mode 100644 index 00000000..8076cbba --- /dev/null +++ b/python/jittor/test/test_parallel_pass.py @@ -0,0 +1,220 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import numpy as np + +class SimpleAsmParser: + def __init__(self, src): + funcs = [] + for s in src.split(".globl"): + funcs.append(s.splitlines()) + self.funcs = funcs + + def count_instructions(self, func_name, ins_name): + f = None + for func in self.funcs: + if func_name in func[0]: + assert f is None, f"Duplicate func name {func_name}" + f = func + assert not (f is None), f"function {func_name} not found" + count = 0 + for ins in f: + if ins_name in ins: + count += 1 + return count + + +class TestParallelPass(unittest.TestCase): + def check(self, use_int32): + n = 1024 + a = jt.random((n, n)) + b = jt.random((n, n)) + a.data, b.data + with jt.profile_scope(compile_options = { + "compile_shapes":1, "parallel":2, "try_use_32bit_index":use_int32 + }, try_use_32bit_index = use_int32) as rep: + c = a + b + nc = c.data + assert len(rep) == 2 + assert (a.data+b.data==nc).all() + fname = rep[1][1] + with open(fname) as f: + src = f.read() + assert "thread_id" in src + with open(fname.replace(".cc", ".s")) as f: + asm = SimpleAsmParser(f.read()) + func_name = "run" + ca = asm.count_instructions(func_name, "vmova") + cu = asm.count_instructions(func_name, "vmovu") + return ca, cu + + def test_int32_align(self): + ca, cu = self.check(1) + if jt.flags.cc_type=="clang": + assert ca>1 and cu<=1, (ca, cu) + + def test_int64_align(self): + ca, cu = self.check(0) + if jt.flags.cc_type=="clang": + assert ca>1 and cu<=1, (ca, cu) + +class TestParallelPass2(TestParallelPass): + def check(self, use_int32): + n = 1024 + a = jt.random((n, n*8)) + b = jt.random((n*8,)) + a.data, b.data + with jt.profile_scope(compile_options = { + "compile_shapes":1, "parallel":1, "split1":n, "order1":1 + }, try_use_32bit_index = use_int32) as rep: + c = a - b + # def func(a, b, c, tid, num): + # for i in range(tid*1024, 1024*8, num*1024): + # for j in range(n): + # for k in range(n): + # c[j*1024*8 + i+k] = a[j*1024*8 + i+k] - b[i+k] + nc = c.data + assert len(rep) == 2 + assert (a.data-b.data==nc).all() + fname = rep[1][1] + with open(fname) as f: + src = f.read() + assert "thread_id" in src + with open(fname.replace(".cc", ".s")) as f: + asm = SimpleAsmParser(f.read()) + func_name = "run" + ca = asm.count_instructions(func_name, "vmova") + cu = asm.count_instructions(func_name, "vmovu") + return ca, cu + +class TestParallelPass3(unittest.TestCase): + def test(self): + def check(ndim, depth, tdim): + a = jt.random([16]*ndim) + a.sync() + compile_options = {"parallel":1, "merge_loop_var": self.merge_loop_var} + if depth is not None: + compile_options["max_parallel_depth"] = depth + with jt.profile_scope(compile_options=compile_options) as rep: + b = (a+a).data + assert np.allclose(a.data*2, b) + assert len(rep) == 2 + fname = rep[1][1] + with open(fname) as f: + src = f.read() + for i in range(tdim): + assert f"tnum{i}" in src + assert f"tnum{tdim}" not in src + self.merge_loop_var = 0 + check(1, None, 0) + check(2, None, 1) + check(3, None, 2) + check(4, None, 2) + check(5, None, 2) + check(5, 3, 3) + check(5, 4, 4) + check(5, 5, 5) + if jt.compiler.has_cuda: + with jt.flag_scope(use_cuda=1): + check(1, 2, 1) + check(2, 2, 2) + check(3, 2, 2) + check(4, 2, 2) + check(5, 2, 2) + check(5, 3, 3) + check(5, 4, 4) + check(5, 5, 5) + + def reduce_check(self, ndim, depth, tdim, rdim, has_atomic, order=[], split=[], **args): + shape = [8]*ndim + a = jt.random(shape) + a.sync() + config = { + "parallel":1, "max_parallel_depth":depth, "merge_loop_var": self.merge_loop_var + } + for k in args: + config[k] = args[k] + if not isinstance(rdim, list): + rdim = [rdim] + rdim = tuple(rdim) + nshape = [1024, 256, 128][len(rdim)] + for d in rdim: shape[d] = nshape + for i,o in enumerate(order): + config[f"order{i}"] = o + for i,o in enumerate(split): + config[f"split{i}"] = o + with jt.profile_scope( + compile_options = config, + enable_tuner = 0 + ) as rep: + b = a.sum(rdim).data + assert len(rep) == 2 + fname = rep[1][1] + with open(fname) as f: + src = f.read() + for i in range(tdim): + assert f"tnum{i}" in src + assert f"tnum{tdim}" not in src, f"tnum{tdim}" + src_has_atomic = "atomic_add" in src or "atomicAdd" in src + assert has_atomic == src_has_atomic + assert np.allclose(a.data.sum(rdim), b), (b.sum(), a.data.sum()) + + def test_reduce(self): + self.merge_loop_var = 0 + check = lambda *a, **kw: self.reduce_check(*a, **kw) + check(1, 2, 1, 0, 1) + check(2, 1, 1, 1, 0) + check(2, 1, 1, 0, 1) + check(2, 1, 1, 0, 1, [0,0]) + check(2, 1, 1, 0, 0, [0,1]) + check(2, 1, 1, 0, 0, [0,1], [0,64]) + check(2, 1, 1, [0,1], 1, [0,1]) + check(3, 1, 1, [1,2], 0) + check(3, 1, 1, [0,1], 1) + check(3, 1, 1, [0,1], 0, [0,0,2]) + check(3, 2, 2, [2], 0) + if jt.flags.use_cuda: + # loop is not merged so parallel depth 2 + check(3, 2, 2, [1], 1) + else: + check(3, 2, 1, [1], 0) + check(3, 2, 2, [1], 1, merge=0) + check(4, 2, 2, [2,3], 0) + check(4, 2, 2, [0,3], 1) + + def test_reduce_with_merge_loop_var(self): + self.merge_loop_var = 1 + check = lambda *a, **kw: self.reduce_check(*a, **kw) + check(1, 2, 1, 0, 1) + check(2, 1, 1, 1, 0) + check(2, 1, 1, 0, 1) + check(2, 1, 1, 0, 1, [0,0]) + check(2, 1, 1, 0, 0, [0,1]) + check(2, 1, 1, 0, 0, [0,1], [0,64]) + check(2, 1, 1, [0,1], 1, [0,1]) + check(3, 1, 1, [1,2], 0) + check(3, 1, 1, [0,1], 1) + check(3, 1, 1, [0,1], 0, [0,0,2]) + check(3, 2, 1, [2], 0) + if jt.flags.use_cuda: + # loop is not merged so parallel depth 2 + check(3, 2, 2, [1], 1) + else: + check(3, 2, 1, [1], 0) + check(3, 2, 2, [1], 1, merge=0) + check(4, 2, 1, [2,3], 0) + check(4, 2, 2, [0,3], 1) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + def test_reduce_cuda(self): + with jt.flag_scope(use_cuda=1): + self.test_reduce() + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_param_list.py b/python/jittor/test/test_param_list.py new file mode 100644 index 00000000..1076d15e --- /dev/null +++ b/python/jittor/test/test_param_list.py @@ -0,0 +1,31 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + + + +class TestParamList(unittest.TestCase): + def test_param_list(self): + ps = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)]) + assert len(ps.parameters()) == 2 + assert list(ps.state_dict().keys()) == ['0', '1'], ps.state_dict().keys() + + def test_with_module(self): + class Net(jt.nn.Module): + def __init__(self): + self.ps1 = jt.nn.ParameterList([jt.array([1,2,3]), jt.rand(10)]) + self.ps2 = jt.nn.ParameterDict({ + "aaa":jt.array([1,2,3]), + "bbb": jt.rand(10) + }) + net = Net() + assert list(net.state_dict().keys()) == ['ps1.0', 'ps1.1', 'ps2.aaa', 'ps2.bbb'] + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_profiler.py b/python/jittor/test/test_profiler.py new file mode 100644 index 00000000..d2ae6bc5 --- /dev/null +++ b/python/jittor/test/test_profiler.py @@ -0,0 +1,42 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +class TestProfiler(unittest.TestCase): + def test_profiler(self): + a = jt.rand(1000,1000) + b = jt.rand(1000,1000) + jt.sync_all() + with jt.profile_scope(10, 100, profiler_record_peek=1) as rep: + jt.matmul(a, b).sync() + x = float(rep[-1][4]) + y = float(rep[-2][4]) + assert abs(x-y)/x < 1e-3 + + def test_marks(self): + a = jt.rand(1000,1000) + b = jt.rand(1000,1000) + jt.sync_all() + results = [] + with jt.profile_scope() as rep: + results.append(jt.matmul(a, b)) + with jt.profile_mark("mark1"): + results.append(jt.matmul(a, b)) + with jt.profile_mark("mark2"): + results.append(jt.matmul(a, b)) + with jt.profile_mark("mark3"): + results.append(jt.matmul(a, b)) + results.append(jt.matmul(a, b)) + jt.sync_all() + assert len(rep) == 6 + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_pytorch_converter.py b/python/jittor/test/test_pytorch_converter.py new file mode 100644 index 00000000..f80539ac --- /dev/null +++ b/python/jittor/test/test_pytorch_converter.py @@ -0,0 +1,238 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor.utils.pytorch_converter import convert +import os + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + from torch import nn +except: + torch = None + +code=""" +import torch.nn as nn +import torch.utils.model_zoo as model_zoo + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + '''3x3 convolution with padding''' + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU() + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU() + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + out = self.conv3(out) + out = self.bn3(out) + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7, stride=1) + self.fc = nn.Linear(512 * block.expansion, num_classes) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x + +def resnet18(pretrained=False, **kwargs): + '''Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + return model + + +def resnet34(pretrained=False, **kwargs): + '''Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + return model + + +def resnet50(pretrained=False, **kwargs): + '''Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + return model + + +def resnet101(pretrained=False, **kwargs): + '''Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + return model + + +def resnet152(pretrained=False, **kwargs): + '''Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + ''' + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + if pretrained: + model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + return model +""" + +@unittest.skipIf(torch is None, "pytorch not found.") +class TestPytorchConverter(unittest.TestCase): + def test_pytorch_converter(self): + name1 = os.path.join(jt.flags.cache_path, 'test_pytorch_converter_1.py') + print(f"save source code into {name1}") + with open(name1, 'w') as f: + f.write(code) + + ret = convert(code) + + name2 = os.path.join(jt.flags.cache_path, 'test_pytorch_converter_2.py') + print(f"save destination code into {name2}") + with open(name2, 'w') as f: + f.write(ret) + + from test_pytorch_converter_1 import resnet18 as torch_resnet18 + from test_pytorch_converter_2 import resnet18 as jittor_resnet18 + model_torch = torch_resnet18(False) + model_jittor = jittor_resnet18(False) + model_jittor.load_parameters(model_torch.state_dict()) + + img = np.random.randn(1,3,224,224).astype("float32") + img_torch = torch.Tensor(img) + img_jittor = jt.array(img) + + out_torch = model_torch(img_torch) + out_jittor = model_jittor(img_jittor) + assert abs((out_torch.cpu().detach().numpy() - out_jittor.data)).mean() < 1e-4 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_random_op.py b/python/jittor/test/test_random_op.py new file mode 100644 index 00000000..120a7516 --- /dev/null +++ b/python/jittor/test/test_random_op.py @@ -0,0 +1,107 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn, Module +from jittor.models import vgg, resnet +import numpy as np +import sys, os +import random +import math +import unittest +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + skip_this_test = True + + +class TestRandomOp(unittest.TestCase): + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test(self): + jt.set_seed(3) + with jt.log_capture_scope( + log_silent=1, + log_v=0, log_vprefix="op.cc=100" + ) as raw_log: + t = jt.random([5,5]) + t.data + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)") + assert len(logs)==1 + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_float64(self): + jt.set_seed(3) + with jt.log_capture_scope( + log_silent=1, + log_v=0, log_vprefix="op.cc=100" + ) as raw_log: + t = jt.random([5,5], dtype='float64') + t.data + logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "curand_random" + ".*)") + assert len(logs)==1 + + @unittest.skipIf(skip_this_test, "No Torch Found") + def test_normal(self): + from jittor import init + n = 10000 + r = 0.155 + a = init.gauss([n], "float32", 1, 3) + data = a.data + + assert (np.abs((data<(1-3)).mean() - r) < 0.1) + assert (np.abs((data<(1)).mean() - 0.5) < 0.1) + assert (np.abs((data<(1+3)).mean() - (1-r)) < 0.1) + + np_res = np.random.normal(1, 0.1, (100, 100)) + jt_res = jt.normal(1., 0.1, (100, 100)) + assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1) + assert (np.abs(np_res.std() - jt_res.data.std()) < 0.1) + + np_res = torch.normal(torch.arange(1., 10000.), 1) + jt_res = jt.normal(jt.arange(1, 10000), 1) + assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1) + assert (np.abs(np_res.std() - jt_res.data.std()) < 1) + + np_res = np.random.randn(100, 100) + jt_res = jt.randn(100, 100) + assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1) + assert (np.abs(np_res.std() - jt_res.data.std()) < 0.1) + + np_res = np.random.rand(100, 100) + jt_res = jt.rand(100, 100) + assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1) + assert (np.abs(np_res.std() - jt_res.data.std()) < 0.1) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1) + def test_normal_cuda(self): + self.test_normal() + + def test_other_rand(self): + a = jt.array([1.0,2.0,3.0]) + b = jt.rand_like(a) + c = jt.randn_like(a) + assert b.shape == c.shape + assert b.shape == a.shape + print(b, c) + assert jt.randint(10, 20, (2000,)).min() == 10 + assert jt.randint(10, 20, (2000,)).max() == 19 + assert jt.randint(10, shape=(2000,)).max() == 9 + assert jt.randint_like(a, 10).shape == a.shape + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_reduce_op.py b/python/jittor/test/test_reduce_op.py new file mode 100644 index 00000000..711256b4 --- /dev/null +++ b/python/jittor/test/test_reduce_op.py @@ -0,0 +1,123 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error + +def gen_data(shape): + num = np.multiply.reduce(shape) + a = np.arange(0, num) + return a.reshape(shape).astype("int32") + + +class TestReduceOp(unittest.TestCase): + def setUp(self): + self.keepdims = False + + def test1(self): + def check(a, op, dims): + if ("logical" in op) and jt.flags.use_cuda: + # TODO: atomic bool operation for cuda not + # supported yet + return + np_dims = jt_dims = dims + if dims == (): + np_dims = tuple(range(len(a.shape))) + x = eval(f"np.{op}.reduce(a, {np_dims}, keepdims={self.keepdims})") + y = eval(f"jt.reduce_{op}(a, {jt_dims}, keepdims={self.keepdims}).data") + if len(x.shape) == 0: x = np.array([x]).astype(a.dtype) + x = x.astype(a.dtype) + y = y.astype(a.dtype) + assert x.dtype == y.dtype and x.shape == y.shape and (x==y).all(), \ + f"\n{a.shape}\n{op}\n{dims}\n{x}\n{y}\n{x.dtype}\n{y.dtype}\n{a.dtype}" + + ia = [gen_data([2,3,4,5]), gen_data([5,3])] + idims = [(), (0,), (1,), (2,), (3,), (0, 2), (1,3), (1,2,3), 2, 3] + + iop = [ op[7:] for op in dir(jt) if op.startswith("reduce_")] + assert len(iop) >= 10, iop + for a in ia: + check(a, iop[0], idims[0]) + for op in iop: + check(ia[0], op, idims[0]) + for dims in idims: + check(ia[0], iop[0], dims) + expect_error(lambda: jt.reduce_add([1,2,3], 2)) + + def test_bool_reduce(self): + x = (jt.bool([1,0,1]) | jt.bool([0,1,0])).all().item() + assert x + x = (jt.bool([1,0,1]) & jt.bool([0,1,0])).any().item() + assert not x + + def test_bool_reduce2(self): + def gen_data(shape): + num = np.multiply.reduce(shape) + a = np.random.randint(2, size=[num]).astype(bool) + return a.reshape(shape).astype("int32") + + def check(a, op, dims): + if ("logical" in op) and jt.flags.use_cuda: + # TODO: atomic bool operation for cuda not + # supported yet + return + np_dims = jt_dims = dims + if dims == (): + np_dims = tuple(range(len(a.shape))) + x = eval(f"np.{op}.reduce(a, {np_dims}, keepdims={self.keepdims})") + y = eval(f"jt.reduce_{op}(a, {jt_dims}, keepdims={self.keepdims}).data") + if len(x.shape) == 0: x = np.array([x]).astype(a.dtype) + x = x.astype(a.dtype) + y = y.astype(a.dtype) + assert x.dtype == y.dtype and x.shape == y.shape and (x==y).all(), \ + f"\n{a.shape}\n{op}\n{dims}\n{x}\n{y}\n{x.dtype}\n{y.dtype}\n{a.dtype}" + + ia = [gen_data([2,3,4,5]), gen_data([5,3])] + idims = [(), (0,), (1,), (2,), (3,), (0, 2), (1,3), (1,2,3), 2, 3] + + iop = [ op[7:] for op in dir(jt) if op.startswith("reduce_")] + assert len(iop) >= 10, iop + for a in ia: + check(a, iop[0], idims[0]) + for op in iop: + check(ia[0], op, idims[1]) + for dims in idims: + check(ia[0], iop[0], dims) + expect_error(lambda: jt.reduce_add([1,2,3], 2)) + + +class TestReduceOp2(TestReduceOp): + def setUp(self): + self.keepdims = True + + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestReduceOpCuda(TestReduceOp): + def setUp(self): + jt.flags.use_cuda = 2 + self.keepdims = False + def tearDown(self): + jt.flags.use_cuda = 0 + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestReduceOpCuda2(TestReduceOp): + def setUp(self): + jt.flags.use_cuda = 2 + self.keepdims = True + def tearDown(self): + jt.flags.use_cuda = 0 + + +class TestReduceOpMisc(unittest.TestCase): + def test_negtive_dim(self): + a = jt.array([[1,2],[3,4]]) + assert (a.sum(-1).data == [3,7]).all() + assert (a.sum(-2).data == [4,6]).all() + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_reduce_tuner.py b/python/jittor/test/test_reduce_tuner.py new file mode 100644 index 00000000..07cc2f11 --- /dev/null +++ b/python/jittor/test/test_reduce_tuner.py @@ -0,0 +1,48 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import sys +import os +import jittor as jt +import unittest +import time +import numpy as np +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + +class TestReduceTuner(unittest.TestCase): + @classmethod + def setUpClass(self): + return + + def check(self, h, w, cs, rs, pa, rtp, dim): + a = jt.random([h,w]) + a.data + + with jt.log_capture_scope( + log_v=0, log_vprefix="tuner_manager=100", + # this value is used for force compile + compile_options={"test_reduce_tuner":1} + ) as logs: + amean=jt.mean(a, dims=[dim], keepdims=1) + a2mean=jt.mean(a*a, dims=[dim], keepdims=1) + norm_aa=(a-amean.broadcast_var(a))/(jt.sqrt(a2mean-amean*amean).broadcast_var(a)) + norm_aa.data + logs = find_log_with_re(logs, + "Run tuner reduce: confidence\\((20)\\) candidates\\((.*)\\)$") + assert len(logs) == 1 , logs + assert logs[0][0] == "20", "confidence of reorder should be 20" + candidates = simple_parser(logs[0][1]) + assert candidates == {"order0": [0,], "order1": [1,], "order2": [0,], "split1": [2048,], } + + def test_reduce_tuner(self): + self.check(8192,8192, 0, 0, 0, 5, 0) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_reindex_op.py b/python/jittor/test/test_reindex_op.py new file mode 100644 index 00000000..c2ff4124 --- /dev/null +++ b/python/jittor/test/test_reindex_op.py @@ -0,0 +1,330 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from .test_grad import ngrad + +def conv(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + xx = x.reindex([N,H+Kh-1,W+Kw-1,Kh,Kw,C,Kc], [ + 'i0', # Nid + 'i1-i3', # Hid+Khid + 'i2-i4', # Wid+KWid + 'i5', # Cid + ]) + ww = w.broadcast_var(xx) + yy = xx*ww + y = yy.sum([3,4,5]) # Kh, Kw, C + return y, yy + +def conv_naive(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + y = np.zeros([N,H+Kh-1,W+Kw-1,Kc]) + for i0 in range(N): + for i1 in range(H+Kh-1): + for i2 in range(W+Kw-1): + for i3 in range(Kh): + for i4 in range(Kw): + for i5 in range(C): + for i6 in range(Kc): + if i1-i3<0 or i2-i4<0 or i1-i3>=H or i2-i4>=W: continue + y[i0, i1, i2, i6] += x[i0, i1-i3, i2-i4, i5] * w[i3,i4,i5,i6] + return y + +def conv_transpose(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + xx = x.reindex([N,H*2+Kh-1,W*2+Kw-1,Kh,Kw,C,Kc], [ + 'i0', # Nid + '(i1-i3)/2', # Hid+Khid + '(i2-i4)/2', # Wid+KWid + 'i5', # Cid + ], 0, ['(i1-i3)%2', '(i2-i4)%2']) + ww = w.broadcast_var(xx) + yy = xx*ww + y = yy.sum([3,4,5]) # Kh, Kw, C + return y, yy + +def conv_transpose_naive(x, w): + N,H,W,C = x.shape + Kh, Kw, _C, Kc = w.shape + assert C==_C + y = np.zeros([N,H*2+Kh-1,W*2+Kw-1,Kc]) + for i0 in range(N): + for i1 in range(H*2+Kh-1): + for i2 in range(W*2+Kw-1): + for i3 in range(Kh): + for i4 in range(Kw): + for i5 in range(C): + for i6 in range(Kc): + if (i1-i3)//2<0 or (i2-i4)//2<0 or (i1-i3)//2>=H or (i2-i4)//2>=W: continue + if (i1-i3)%2 or (i2-i4)%2: continue + y[i0, i1, i2, i6] += x[i0, (i1-i3)//2, (i2-i4)//2, i5] * w[i3,i4,i5,i6] + return y + + +def is_fused(x): + return 's0' in x.debug_msg() + +def check_fused(dim): + jt.clean() + graph = jt.dump_all_graphs() + fused = True + has_v = False + for node in graph.nodes_info: + shape = node.split('[')[-1].split(',') + ndim = len(shape)-1 + if ndim>dim: + has_v = True + if 's0' not in node: + fused = False + assert fused and has_v, graph.nodes_info + +def resize_and_crop(x, bbox, interpolation="nearest"): + N, k = bbox.shape + H, W = x.shape + assert k==4 + shape = [N,H,W] + # fx x cx + # +------------> + # fy | a dx | b + # | dy + # y | - o - + # | + # cy | c | d + # v + img = x + bb = [ bbox.reindex(shape, ["i0", str(i)]) for i in range(4) ] + hid = jt.index(shape, 1) + wid = jt.index(shape, 2) + one = jt.float(1).broadcast(shape) + x = bb[0]*jt.float(H-1)+hid*(bb[2]-bb[0]) + y = bb[1]*jt.float(W-1)+wid*(bb[3]-bb[1]) + if interpolation=="nearest": + return img.reindex_var([x.round_int(), y.round_int()]) + if interpolation=="bilinear": + fx, fy = x.floor_int(), y.floor_int() + cx, cy = fx+one, fy+one + dx, dy = x-fx, y-fy + a = img.reindex_var([fx, fy]) + b = img.reindex_var([cx, fy]) + c = img.reindex_var([fx, cy]) + d = img.reindex_var([cx, cy]) + dnx, dny = one-dx, one-dy + ab = dx*b + dnx*a + cd = dx*d + dnx*c + o = ab*dny + cd*dy + return o + raise(f"Not support {interpolation}") + + +def resize_and_crop_naive(x, bbox, interpolation="nearest"): + N, k = bbox.shape + H, W = x.shape + assert k==4 + y = np.zeros([N,H,W]) + if interpolation=="nearest": + for i in range(N): + for j in range(H): + for k in range(W): + nj = int(round(bbox[i,0]*(H-1)+j*(bbox[i,2]-bbox[i,0]))) + nk = int(round(bbox[i,1]*(W-1)+k*(bbox[i,3]-bbox[i,1]))) + if nk<0 or nk>=W or nj<0 or nj>=H: + y[i,j,k] = 0 + else: + y[i,j,k] = x[nj,nk] + return y + else: # bilinear + # fx x cx + # +------------> + # fy | a dx | b + # | dy + # y | - o - + # | + # cy | c | d + # v + from math import floor, ceil + data = x + output = y + sample = lambda nj, nk: 0 if nk<0 or nk>=W or nj<0 or nj>=H else data[nj,nk] + for i in range(N): + for j in range(H): + for k in range(W): + x = bbox[i,0]*(H-1)+j*(bbox[i,2]-bbox[i,0]) + y = bbox[i,1]*(W-1)+k*(bbox[i,3]-bbox[i,1]) + fx, fy = floor(x), floor(y) + cx, cy = fx+1, fy+1 + a = sample(fx, fy) + b = sample(cx, fy) + c = sample(fx, cy) + d = sample(cx, cy) + dx, dy = x-fx, y-fy + dnx, dny = 1-dx, 1-dy + ab = dx*b + dnx*a + cd = dx*d + dnx*c + o = ab*dny + cd*dy + output[i,j,k] = o + return output + +class TestReindexOp(unittest.TestCase): + def test_pad(self): + size = 10 + lpad = 3 + rpad = 4 + a = jt.random([size]) + b = a.reindex([size+lpad+rpad], [f"i0-{lpad}"], -1) + na, nb = jt.fetch_sync([a, b]) + assert (nb[lpad:lpad+size]==na).all() + assert (nb[:lpad]==-1).all() + assert (nb[-rpad:]==-1).all() + + def test_matmul(self): + size = 10 + a = jt.random([size,size]) + b = jt.random([size,size]) + cc = a.reindex([size,size,size],["i0","i1"]) * \ + b.reindex([size,size,size],["i1","i2"]) + c = cc.sum(dim=1) + na, nb, nc = jt.fetch_sync([a, b, c]) + assert is_fused(cc) + assert not is_fused(c) + check_fused(len(a.shape)) + npc = np.matmul(na,nb) + assert np.allclose(npc, nc) + + def test_conv(self): + N,H,W,C = 3,10,10,3 + Kh, Kw, Kc = 3, 3, 4 + x = jt.random([N,H,W,C]) + w = jt.random([Kh,Kw,C,Kc]) + y, yy = conv(x, w) + ny = y.data + assert ny.shape == (N, H+Kh-1, W+Kw-1, Kc), (ny.shape, [N, H+Kh-1, W+Kw-1, Kc]) + assert is_fused(yy) + check_fused(len(x.shape)) + npy = conv_naive(x.data, w.data) + assert np.allclose(npy, ny) + + def test_conv_transpose(self): + N,H,W,C = 3,10,10,3 + Kh, Kw, Kc = 3, 3, 4 + x = jt.random([N,H,W,C]) + w = jt.random([Kh,Kw,C,Kc]) + y, yy = conv_transpose(x, w) + ny = y.data + assert is_fused(yy) + check_fused(len(x.shape)) + npy = conv_transpose_naive(x.data, w.data) + assert np.allclose(npy, ny), (np.where(np.abs(npy-ny)>1e-4), npy[0,:4,:4,0], ny[0,:4,:4,0]) + + + def test_conv_transpose_group(self): + N,C,H,W = 3,6,10,10 + i,o,h,w = 6,2,3,3 + g = 2 + x = jt.random([N,C,H,W]) + ww = jt.random([i,o,h,w]) + ct = jt.nn.ConvTranspose(i,o*g,(h,w), groups=2, bias=False) + assert ct.weight.shape == ww.shape, (ct.weight.shape, ww.shape) + ct.weight = ww + y = ct(x) + y2 = jt.nn.conv_transpose(x, ww, groups=2) + np.testing.assert_allclose(y.data, y2.data) + + def test_conv_transpose_grad(self): + N,H,W,C = 1,5,5,2 + Kh, Kw, Kc = 3, 3, 2 + x = jt.random([N,H,W,C]) + w = jt.random([Kh,Kw,C,Kc]) + y, yy = conv_transpose(x, w) + mask = jt.random(y.shape) + loss = (y*mask).sum() + dx, dw = jt.grad(loss, [x, w]) + jdx, jdw = jt.fetch_sync([dx, dw]) + check_fused(len(x.shape)) + nmask = mask.data + _, (ndx, ndw) = ngrad(lambda args: \ + (conv_transpose_naive(args[0], args[1])*nmask).sum(), + [np.float64(x.data), np.float64(w.data)], 1e-7) + assert np.allclose(ndx, jdx), (ndx, jdx, ndx-jdx) + assert np.allclose(ndw, jdw), (ndw, jdw) + + def test_resize_and_crop(self): + jt.set_seed(3) + N, H, W = 4, 5, 5 + for interpolation in ["bilinear", "nearest"]: + x = jt.random([H, W]) + # x = jt.ones([H, W]) + bbox = jt.random([N, 4]) + # bbox = jt.float([[0.51,0.71,0.61,0.81]]) + # bbox = jt.float([[0,0,1,1]]) + y = resize_and_crop(x, bbox, interpolation) + ny = resize_and_crop_naive(x.data, bbox.data, interpolation) + assert np.allclose(y.data, ny), (y.data, ny, x.data) + + # test grad + mask = jt.random(y.shape) + # mask = jt.ones(y.shape) + nmask = mask.data + import gc; gc.collect() + loss = y*mask + dx, dbbox = jt.grad(loss, [x, bbox]) + _, (ndx, ndbbox) = ngrad(lambda args: \ + (resize_and_crop_naive(args[0], args[1], interpolation)*nmask).sum(), + [np.float64(x.data), np.float64(bbox.data)], 1e-7) + assert np.allclose(y.data, ny), (y.data, ny, x.data) + assert np.allclose(ndx, dx.data, 1e-2), (ndx, dx.data) + assert np.allclose(ndbbox, dbbox.data, 1e-2), (ndbbox, dbbox.data) + + + + def test_doc(self): + assert "Reindex Operator" in jt.reindex.__doc__ + + + + def test_reindex_fuse_error(self): + a = jt.zeros([10,10]) + b = jt.array([1]) + c = a.reindex([8,8], ["@e0(0)", "@e1(0,i0 / @e0(0))"], extras=[b, jt.ones([10,10])]) + c.sync() + # print(c) + + def test_reindex_wrong_op(self): + a = jt.zeros([10,10]) + b = jt.array([1]) + c = a.reindex([8,8], ["@e0(0) // 1", "@e0(0)"], extras=[b, b]) + expect_error(lambda: c.sync()) + + def test_reindex_memopt(self): + a = jt.zeros([10,10]) + b = jt.array([1,2,3]).name("b") + c = a.reindex([8,8], ["@e0(0) / 1", "@e0(0)"], extras=[b, b]) + del b + c.sync() + da = jt.grad(c, a) + da.sync() + + + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestReindexOpCuda(TestReindexOp): + def setUp(self): + # TODO: replace to 2 + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_reindex_reduce_op.py b/python/jittor/test/test_reindex_reduce_op.py new file mode 100644 index 00000000..54fc1b17 --- /dev/null +++ b/python/jittor/test/test_reindex_reduce_op.py @@ -0,0 +1,108 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from .test_grad import ngrad + +def pool(x, size, op): + N,H,W,C = x.shape + h = (H+size-1)//size + w = (W+size-1)//size + return x.reindex_reduce(op, [N,h,w,C], [ + "i0", # Nid + f"i1/{size}", # Hid + f"i2/{size}", # Wid + "i3", # Cid + ]) + +def pool_naive(x, size, op): + N,H,W,C = x.shape + h = (H+size-1)//size + w = (W+size-1)//size + y = np.zeros([N,h,w,C], dtype="float64") + x = np.float64(x) + if op=="maximum": + y[:] = -1e100 + fop = lambda x,y: np.maximum(x,y) + elif op=="minimum": + y[:] = 1e100 + fop = lambda x,y: np.minimum(x,y) + elif op=="multiply": + y[:] = 1 + fop = lambda x,y: x*y + else: + assert op=="add" + fop = lambda x,y: x+y + for i0 in range(N): + for i1 in range(H): + for i2 in range(W): + for i3 in range(C): + y[i0,i1//size,i2//size,i3] = \ + fop(y[i0,i1//size,i2//size,i3], x[i0,i1,i2,i3]) + return y + +ops = ["maximum", "minimum", "multiply", "add"] + +class TestReindexReduceOp(unittest.TestCase): + def test_pool(self): + N,H,W,C = 3,10,10,4 + size=3 + for op in ops: + x = jt.random([N,H,W,C]) + y = pool(x, size, op) + ny = pool_naive(x.data, size, op) + assert np.allclose(y.data, ny), (op, y.data, ny) + + def test_pool_grad(self): + jt.set_seed(1) + N,H,W,C = 2,7,7,2 + size=3 + # ops = ["maximum"] + for op in ops: + x = jt.random([N,H,W,C]) + y = pool(x, size, op) + mask = jt.random(y.shape) + loss = (y*mask).sum() + dx = jt.grad(loss, x) + jdx = dx.data + nx = x.data + nmask = mask.data + _, (ndx,) = ngrad(lambda args: (pool_naive(args[0], size, op)*nmask).sum(), [nx], 1e-6) + assert np.allclose(jdx, ndx), (op, jdx[0,:,:,0], ndx[0,:,:,0]) + + def test_fuse_error(self): + a = jt.array([1,2,3,4]) + b = jt.zeros((3,3)) + jt.sync_all() + c = b.reindex_reduce("add", [4,4], ["@e0(i0)", "@e0(i1)"], extras=[-a]) + c.sync() + + a = jt.zeros((3,3)) + b = jt.zeros((3,3)) + jt.sync_all() + c = b.reindex_reduce("add", [4,4], ["@e0(i0,i1)", "@e0(i1,i0)"], extras=[-a]) + c.sync() + + def test_error(self): + jt.random([3]).reindex_reduce("add", [3], ["i0"]) + expect_error(lambda: jt.random([3]).reindex_reduce("add", [3], [])) + expect_error(lambda: jt.random([3]).reindex_reduce("add", [3], ["i0","i0"])) + expect_error(lambda: jt.random([3]).reindex_reduce("???", [3], ["i0"])) + expect_error(lambda: jt.random([3]).reindex_reduce("add", [-1], ["i0"])) + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestReindexReduceOpCuda(TestReindexReduceOp): + def setUp(self): + # TODO: replace to 2 + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_relu.py b/python/jittor/test/test_relu.py new file mode 100644 index 00000000..5addd85b --- /dev/null +++ b/python/jittor/test/test_relu.py @@ -0,0 +1,91 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +def check_equal(arr, j_layer, p_layer): + jittor_arr = jt.array(arr) + pytorch_arr = torch.Tensor(arr) + jittor_result = j_layer(jittor_arr) + pytorch_result = p_layer(pytorch_arr) + assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(),rtol=1e-5,atol=1e-5) + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestRelu(unittest.TestCase): + def test_relu(self): + # *************************************************************** + # Test ReLU Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.ReLU(), tnn.ReLU()) + + # *************************************************************** + # Test PReLU Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.PReLU(), tnn.PReLU()) + check_equal(arr, jnn.PReLU(10, 99.9), tnn.PReLU(10, 99.9)) + check_equal(arr, jnn.PReLU(10, 2), tnn.PReLU(10, 2)) + check_equal(arr, jnn.PReLU(10, -0.2), tnn.PReLU(10, -0.2)) + + # *************************************************************** + # Test ReLU6 Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.ReLU6(), tnn.ReLU6()) + + # *************************************************************** + # Test LeakyReLU Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU()) + check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2)) + check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9)) + + # *************************************************************** + # Test ELU Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.ELU(), tnn.ELU()) + check_equal(arr, jnn.ELU(0.3), tnn.ELU(0.3)) + check_equal(arr, jnn.ELU(2), tnn.ELU(2)) + check_equal(arr, jnn.ELU(99.9), tnn.ELU(99.9)) + + # *************************************************************** + # Test GELU Layer + # *************************************************************** + if hasattr(tnn, "GELU"): + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.GELU(), tnn.GELU()) + + # *************************************************************** + # Test Softplus Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.Softplus (), tnn.Softplus ()) + check_equal(arr, jnn.Softplus (2), tnn.Softplus (2)) + check_equal(arr, jnn.Softplus (2, 99.9), tnn.Softplus (2, 99.9)) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_reorder_tuner.py b/python/jittor/test/test_reorder_tuner.py new file mode 100644 index 00000000..4b7d90c3 --- /dev/null +++ b/python/jittor/test/test_reorder_tuner.py @@ -0,0 +1,107 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +from .test_log import find_log_with_re +from .test_fused_op import retry + +# simple parser for parsing {a:1,b:2} +def simple_parser(s): + ss = s.split(":") + res = [] + for x in ss[:-1]: + j = len(x)-1 + if j<0: + res.append("") + continue + while j>=0 and x[j] in " \n": + j -= 1 + k = j + while k>=0 and x[k] not in " \n{},": + k -= 1 + res.append(f'{x[:k+1]}"{x[k+1:j+1]}"{x[j+1:]}') + res.append(ss[-1]) + res = ":".join(res) + return eval(res) + +gid = 0 + +class TestReorderTuner(unittest.TestCase): + def test(self): + a = jt.ones((8,8,8)) + a.data + with jt.log_capture_scope( + log_v=0, log_vprefix="tuner_manager=100" + ) as logs: + b = a + a + b.data + + logs = find_log_with_re(logs, + "Run tuner reorder: confidence\\((.*)\\) candidates\\((.*)\\)$") + assert len(logs) == 1 + assert logs[0][0] == "1", "confidence of reorder should be 1" + candidates = simple_parser(logs[0][1]) + assert candidates == { + "order0":[0,], "order1":[0,1,], "order2":[0,1,2,] + } + + def test_with_split(self): + a = jt.ones((8,8,8)) + a.data + global gid + gid+=1 + with jt.log_capture_scope( + log_v=0, log_vprefix="tuner_manager=100", + compile_options={ + "split0": 4, "split1": 4, "split2": 4, + "test_reorder_tuner":gid + } + ) as logs: + b = a + a + b.data + + logs = find_log_with_re(logs, + "Run tuner reorder: confidence\\((.*)\\) candidates\\((.*)\\)$") + assert len(logs) == 1 + assert logs[0][0] == "1", "confidence of reorder should be 1" + candidates = simple_parser(logs[0][1]) + assert candidates == { + "order0":[0,], "order1":[0,1,], "order2":[0,1,2,], + "order3":[0,1,2,], "order4":[0,1,2,], "order5":[0,1,2,], + }, candidates + + @retry(10) + def test_searcher(self): + a = jt.ones((80,80,80)) + a.data + global gid + gid+=1 + with jt.log_capture_scope( + log_v=0, log_vprefix="jit_searcher=1000", + jit_search_kernel=1, + compile_options={ + "compile_shape":1, + "test_reorder_tuner":gid + } + ) as logs: + b = a + a + b.data + ls = find_log_with_re(logs, "Choices") + assert len(ls) == 6, (ls, logs) + ls = find_log_with_re(logs, "Best choices\\(.*\\): (.*)$") + assert len(ls) == 1 + best = simple_parser(ls[0]) + assert best == { + "compile_shape": 1, "order0": 0, "order1": 0, "order2": 0, + "test_reorder_tuner":gid + } + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_repeat.py b/python/jittor/test/test_repeat.py new file mode 100644 index 00000000..7cef9566 --- /dev/null +++ b/python/jittor/test/test_repeat.py @@ -0,0 +1,61 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Zheng-Ning Liu +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + + +import unittest +import jittor as jt +import numpy as np + + +class TestRepeatOp(unittest.TestCase): + def test_repeat(self): + np_a = np.arange(5) + jt_a = jt.array(np_a) + + np_b = np.tile(np_a, (2, 3)) + jt_b = jt.repeat(jt_a, (2, 3)) + + assert np.allclose(np_b, jt_b.data) + + np_b = np.tile(np_a, (2, 3, 1)) + jt_b = jt.repeat(jt_a, (2, 3, 1)) + + assert np.allclose(np_b, jt_b.data) + + np_a = np.arange(24).reshape(2, 3, 4) + jt_a = jt.array(np_a) + + np_b = np.tile(np_a, (2, 3)) + jt_b = jt.repeat(jt_a, (2, 3)) + + assert np.allclose(np_b, jt_b.data) + + + def test_highdim(self): + np_a = np.arange(64).reshape(2, 2, 2, 2, 2, 2) + jt_a = jt.array(np_a) + + np_b = np.tile(np_a, (2, 3)) + jt_b = jt.repeat(jt_a, (2, 3)) + + assert np.allclose(np_b, jt_b.data) + + np_b = np.tile(np_a, (2, 1, 1, 3)) + jt_b = jt.repeat(jt_a, (2, 1, 1, 3)) + + assert np.allclose(np_b, jt_b.data) + + np_b = np.tile(np_a, (2, 1, 1, 1, 3, 1)) + jt_b = jt.repeat(jt_a, (2, 1, 1, 1, 3, 1)) + + assert np.allclose(np_b, jt_b.data) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_reshape.py b/python/jittor/test/test_reshape.py new file mode 100644 index 00000000..25bf7f75 --- /dev/null +++ b/python/jittor/test/test_reshape.py @@ -0,0 +1,87 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_grad import ngrad +from .test_cuda import test_cuda + +def get_node_info(s): + mem_ptr = s.split(')')[0].split(',')[-1] + name = s.split(')')[0].split(',')[-2] + return name, mem_ptr + +def get_info(graph): + bop = [ node for node in graph.nodes_info if node.startswith("Var")] + node_dict = {} + for bop_ in bop: + name, mem_ptr = get_node_info(bop_) + node_dict[name] = mem_ptr + return node_dict + +def check_equal(a, b): + eps = 1e-1 # icc error almost reaches 1e-1 + return abs(a - b) < eps + +class TestReshapeOp(unittest.TestCase): + def test_reshape(self): + a = jt.random([123, 456, 789]).name("a") + b = jt.reshape(a, [123 * 2, int(789 * 456 / 2)]).name("b") + c = jt.reshape(b, [123 * 456 * 789]).name("c") + d = jt.reshape(c, [2, int(123 / 3), 789, int(456 / 2), 3]).name("d") + e = jt.reshape(d, [2, int(123 / 3), 789, -1, 3]).name("e") + assert b.shape == [123 * 2, int(789 * 456 / 2)] + assert c.shape == [123 * 456 * 789] + assert d.shape == [2, int(123 / 3), 789, int(456 / 2), 3] + assert e.shape == [2, int(123 / 3), 789, int(456 / 2), 3] + a_mean = a.mean().data + b_mean = b.mean().data + c_mean = c.mean().data + d_mean = d.mean().data + e_mean = e.mean().data + a = (a + 1).name("new_a") + new_a_mean = a.mean().data + new_b_mean = b.mean().data + node_dict = get_info(jt.dump_all_graphs()) + assert check_equal(a_mean, b_mean), f"{a_mean} != {b_mean}" + assert check_equal(a_mean, c_mean), f"{a_mean} != {c_mean}" + assert check_equal(a_mean, d_mean), f"{a_mean} != {d_mean}" + assert check_equal(a_mean, e_mean), f"{a_mean} != {e_mean}" + assert check_equal(b_mean, new_b_mean), f"{b_mean} != {new_b_mean}" + assert not check_equal(a_mean, new_a_mean), f"{a_mean} == {new_a_mean}" + assert node_dict['a'] == node_dict['b'] + assert node_dict['a'] == node_dict['c'] + assert node_dict['a'] == node_dict['d'] + assert node_dict['a'] == node_dict['e'] + + def test_view(self): + a = jt.ones([2,3,4]) + assert a.view(2,-1).shape == [2,12] + + def test_flatten(self): + a = jt.ones([2,3,4]) + assert a.flatten().shape == [24] + assert a.flatten(1).shape == [2,12] + assert a.flatten(0,-2).shape == [6,4] + + def test_reshape_var(self): + a = jt.zeros(10) + b = a.reshape(a.shape) + + def test_reshape_empty(self): + a = jt.array([]) + b = a.reshape(0, 1, 2) + assert b.shape == [0, 1, 2] + b = a.reshape(0, -1) + assert b.shape == [0, 0] + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_resize_and_crop.py b/python/jittor/test/test_resize_and_crop.py new file mode 100644 index 00000000..bbe96c61 --- /dev/null +++ b/python/jittor/test/test_resize_and_crop.py @@ -0,0 +1,142 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import random +import os + +import numpy as np +import jittor.nn as jnn +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +mid = 0 +if hasattr(os, "uname") and "jittor" in os.uname()[1]: + mid = 1 + +def resize_and_crop(x, bbox, interpolation="nearest", out_size=[224,224]): + N, k = bbox.shape + H, W, C = x.shape + assert k==4 + shape = [N, out_size[0], out_size[1], C] + # shape = [N,H,W] + # fx x cx + # +------------> + # fy | a dx | b + # | dy + # y | - o - + # | + # cy | c | d + # v + img = x + bb = [ bbox.reindex(shape, ["i0", str(i)]) for i in range(4) ] + hid = jt.index(shape, dim=1) + wid = jt.index(shape, dim=2) + cid = jt.index(shape, dim=3) + one = jt.array(1.0).broadcast(shape) + x = bb[0]*(H-1.0)+hid*((H-1)*1.0/(shape[1]-1))*(bb[2]-bb[0]) + y = bb[1]*(W-1.0)+wid*((W-1)*1.0/(shape[2]-1))*(bb[3]-bb[1]) + if interpolation=="nearest": + return img.reindex([x.round_int(), y.round_int(), cid]) + if interpolation=="bilinear": + fx, fy = x.floor_int(), y.floor_int() + cx, cy = fx+one, fy+one + dx, dy = x-fx, y-fy + a = img.reindex_var([fx, fy, cid]) + b = img.reindex_var([cx, fy, cid]) + c = img.reindex_var([fx, cy, cid]) + d = img.reindex_var([cx, cy, cid]) + dnx, dny = one-dx, one-dy + ab = dx*b + dnx*a + cd = dx*d + dnx*c + o = ab*dny + cd*dy + return o + raise(f"Not support {interpolation}") + +def test_case(box_num, out_size, time_limit): + boxes = [] + for i in range(box_num): + t = [random.random() * 0.9, random.random() * 0.9, random.random() * 0.9, random.random() * 0.9] + t2 = [min(t[0], t[2]), min(t[1], t[3]), max(t[0], t[2]) + 0.1, max(t[1], t[3]) + 0.1] + boxes.append(t2) + img = jt.random([121, 121, 3]) + out = resize_and_crop(img, jt.array(boxes), interpolation='bilinear', out_size=out_size) + with jt.profile_scope() as rep: + our_out = out.data + t = 0 + fused_op_num = 0 + for i in range(1, len(rep)): + t += float(rep[i][3]) / 1e9 + name = rep[i][0] + if name.startswith('«') and (not '«graph:«' in name): + fused_op_num += 1 + assert fused_op_num == 1, fused_op_num + assert t <= time_limit, t + +def check_equal(arr, j_layer, p_layer): + jittor_arr = jt.array(arr) + pytorch_arr = torch.Tensor(arr) + jittor_result = j_layer(jittor_arr) + pytorch_result = p_layer(pytorch_arr) + np.testing.assert_allclose(pytorch_result.detach().numpy(), jittor_result.numpy(), rtol=1e-6) + +class TestResizeAndCrop(unittest.TestCase): + def test(self): + test_case(100, [224, 224], 0.45) + test_case(100, [180, 224], 0.3) + test_case(20, [1024, 1024], [1.2, 1.8][mid]) + test_case(20, [1024, 666], [0.8,1.0][mid]) + + @unittest.skipIf(torch is None, "no torch found") + def test_resize(self): + import torch.nn.functional as F + x = np.array(range(2*3*25)).reshape(2,3,5,5).astype("float32") + for r_size in [3,4,5,6]: + for align_corners in [True,False]: + check_equal(x, + jnn.Resize((r_size, r_size), 'bilinear', align_corners), + lambda x: F.interpolate(x, size=(r_size, r_size), mode='bilinear',align_corners=align_corners)) + + @unittest.skipIf(torch is None, "no torch found") + def test_upsample(self): + arr = np.random.randn(2,3,224,224) + check_equal(arr, jnn.Upsample(scale_factor=2), tnn.Upsample(scale_factor=2)) + check_equal(arr, jnn.Upsample(scale_factor=0.5), tnn.Upsample(scale_factor=0.5)) + # pytorch change behav when scale_factor changed + # this test cannot pass + # check_equal(arr, jnn.Upsample(scale_factor=0.2), tnn.Upsample(scale_factor=0.2)) + + @unittest.skipIf(torch is None, "no torch found") + def test_pixelshuffle(self): + arr = np.random.randn(2,4,224,224) + check_equal(arr, jnn.PixelShuffle(upscale_factor=2), tnn.PixelShuffle(upscale_factor=2)) + arr = np.random.randn(1,3*3,224,224) + check_equal(arr, jnn.PixelShuffle(upscale_factor=3), tnn.PixelShuffle(upscale_factor=3)) + + def test_resize(self): + arr = np.random.randn(1,1,2,2) + check_equal(arr, jnn.Resize((4,4)), tnn.Upsample(scale_factor=2)) + # check_equal(arr, jnn.Upsample(scale_factor=0.5), tnn.Upsample(scale_factor=0.5)) + + def test_interpolate(self): + a = jt.rand(1,3,64,64) + b = jt.nn.interpolate(a, scale_factor=0.5) + b.sync() + assert b.shape == (1,3,32,32) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_resnet.py b/python/jittor/test/test_resnet.py new file mode 100644 index 00000000..45c75ae2 --- /dev/null +++ b/python/jittor/test/test_resnet.py @@ -0,0 +1,151 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Meng-Hao Guo +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn, Module +from jittor.models import resnet +import numpy as np +import sys, os +import random +import math +import unittest +from jittor.test.test_reorder_tuner import simple_parser +from jittor.test.test_log import find_log_with_re +from jittor.dataset.mnist import MNIST +import jittor.transform as trans +import time + +skip_this_test = False +if os.name == 'nt': + skip_this_test = True + +class MnistNet(Module): + def __init__(self): + self.model = resnet.Resnet18() + self.layer = nn.Linear(1000,10) + def execute(self, x): + x = self.model(x) + x = self.layer(x) + return x + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestResnetFp32(unittest.TestCase): + # setup random seed + def setup_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + jt.seed(seed) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1, use_stat_allocator=1) + def test_resnet(self): + self.setup_seed(1) + + # hyper-parameters + self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100")) + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.1 + if jt.flags.amp_reg: + self.learning_rate = 0.01 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + self.train_loader.num_workers = 4 + + loss_list=[] + acc_list=[] + mnist_net = MnistNet() + global prev + prev = time.time() + SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) + self.train_loader.endless = True + + for data, target in self.train_loader: + batch_id = self.train_loader.batch_id + epoch_id = self.train_loader.epoch_id + data = data.float_auto() + + # train step + # with jt.log_capture_scope( + # log_silent=1, + # log_v=1, log_vprefix="op.cc=100,exe=10", + # ) as logs: + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + SGD.step(loss) + def callback(epoch_id, batch_id, loss, output, target): + # print train info + global prev + pred = np.argmax(output, axis=1) + acc = np.mean(target==pred) + loss_list.append(loss[0]) + acc_list.append(acc) + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' + .format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev)) + # prev = time.time() + # async version + jt.fetch(epoch_id, batch_id, loss, output, target, callback) + # sync version + # callback(epoch_id, batch_id, loss.numpy(), output.numpy(), target.numpy()) + + # log_conv = find_log_with_re(logs, + # "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") + # log_matmul = find_log_with_re(logs, + # "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*") + # if batch_id > 2: + # assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul)) + + mem_used = jt.flags.stat_allocator_total_alloc_byte \ + -jt.flags.stat_allocator_total_free_byte + # assert mem_used < 4e9, mem_used + # TODO: why bigger? + assert mem_used < 5.6e9, mem_used + # example log: + # Train Epoch: 0 [0/100 (0%)] Loss: 2.352903 Acc: 0.110000 + # Train Epoch: 0 [1/100 (1%)] Loss: 2.840830 Acc: 0.080000 + # Train Epoch: 0 [2/100 (2%)] Loss: 3.473594 Acc: 0.100000 + # Train Epoch: 0 [3/100 (3%)] Loss: 3.131615 Acc: 0.200000 + # Train Epoch: 0 [4/100 (4%)] Loss: 2.524094 Acc: 0.230000 + # Train Epoch: 0 [5/100 (5%)] Loss: 7.780025 Acc: 0.080000 + # Train Epoch: 0 [6/100 (6%)] Loss: 3.890721 Acc: 0.160000 + # Train Epoch: 0 [7/100 (7%)] Loss: 6.370137 Acc: 0.140000 + # Train Epoch: 0 [8/100 (8%)] Loss: 11.390827 Acc: 0.150000 + # Train Epoch: 0 [9/100 (9%)] Loss: 21.598564 Acc: 0.080000 + # Train Epoch: 0 [10/100 (10%)] Loss: 23.369165 Acc: 0.130000 + # Train Epoch: 0 [20/100 (20%)] Loss: 4.804510 Acc: 0.100000 + # Train Epoch: 0 [30/100 (30%)] Loss: 3.393924 Acc: 0.110000 + # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000 + # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000 + + if jt.flags.amp_reg: + continue + if jt.in_mpi: + assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars() + else: + assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars() + if self.train_loader.epoch_id >= 2: + break + + jt.sync_all(True) + assert np.mean(loss_list[-50:])<0.5 + assert np.mean(acc_list[-50:])>0.8 + + +@unittest.skipIf(skip_this_test, "skip_this_test") +class TestResnetFp16(TestResnetFp32): + def setup(self): + jt.flags.auto_mixed_precision_level = 5 + + def tearDown(self): + jt.flags.auto_mixed_precision_level = 0 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_ring_buffer.py b/python/jittor/test/test_ring_buffer.py new file mode 100644 index 00000000..8d243bf4 --- /dev/null +++ b/python/jittor/test/test_ring_buffer.py @@ -0,0 +1,76 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from jittor_utils.ring_buffer import * +import unittest + +def test_ring_buffer(): + buffer = mp.Array('c', 8000, lock=False) + buffer = RingBuffer(buffer) + def test_send_recv(data): + print("test send recv", type(data)) + buffer.send(data) + recv = buffer.recv() + if isinstance(recv, np.ndarray): + assert (recv == data).all() + else: + assert data == recv + test_send_recv("float32") + test_send_recv("") + test_send_recv("xxxxxxxxxx") + + test_send_recv(1) + test_send_recv(100000000000) + + test_send_recv(1e-5) + test_send_recv(100000000000.0) + + test_send_recv([1,0.2]) + test_send_recv({'asd':1}) + + test_send_recv(np.random.rand(10,10)) + +def test_ring_buffer_allocator(p=0.7): + print("test_ring_buffer_allocator", p) + n = 1000 + buffer = RingBufferAllocator(n) + m = 10000 + sizes = [0]*m + a = [-1]*n + l = 0 + r = 0 + for i in range(m): + if l==r or random.random()<0.7: + size = random.randint(10, 20) + location = buffer.alloc(size) + if location is not None: + sizes[r] = size + for j in range(location, location+size): + a[j] = r + r += 1 + continue + assert l. +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import unittest +import numpy as np +import random +from .test_core import expect_error +from jittor.dataset.mnist import MNIST +import jittor.transform as trans +from tqdm import tqdm + +class BBox: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return bool((self.x == other.x).all()) + +def test_ring_buffer(): + buffer = jt.RingBuffer(2000) + def test_send_recv(data): + print("test send recv", type(data)) + buffer.push(data) + recv = buffer.pop() + if isinstance(data, (np.ndarray, jt.Var)): + assert (recv == data).all() + else: + assert data == recv + + n_byte = 0 + test_send_recv(1) + n_byte += 1 + 8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + test_send_recv(100000000000) + n_byte += 1 + 8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + + test_send_recv(1e-5) + n_byte += 1 + 8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + test_send_recv(100000000000.0) + n_byte += 1 + 8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + + test_send_recv("float32") + n_byte += 1 + 8 + 7 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + test_send_recv("") + n_byte += 1 + 8 + 0 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + test_send_recv("xxxxxxxxxx") + n_byte += 1 + 8 + 10 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + + test_send_recv([1,0.2]) + n_byte += 1 + 8 + 1 + 8 + 1 + 8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + test_send_recv({'asd':1}) + n_byte += 1 + 8 + 1 + 8 + 3 + 1 + 8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push() + + test_send_recv(np.random.rand(10,10)) + n_byte += 1 + 16 + 4 + 10*10*8 + assert n_byte == buffer.total_pop() and n_byte == buffer.total_push(), \ + (n_byte, buffer.total_pop(), n_byte, buffer.total_push()) + test_send_recv(test_ring_buffer) + + test_send_recv(jt.array(np.random.rand(10,10))) + + bbox = BBox(jt.array(np.random.rand(10,10))) + test_send_recv(bbox) + + expect_error(lambda: test_send_recv(np.random.rand(10,1000))) + + +class TestRingBuffer(unittest.TestCase): + + def test_ring_buffer(self): + test_ring_buffer() + + def test_dataset(self): + return + self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ + .set_attrs(batch_size=300, shuffle=True) + self.train_loader.num_workers = 1 + import time + for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)): + # time.sleep(5) + # print("break") + # break + # self.train_loader.display_worker_status() + if batch_idx > 30: + break + pass + for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)): + # time.sleep(5) + # print("break") + # break + # self.train_loader.display_worker_status() + if batch_idx > 300: + break + pass + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_rnn.py b/python/jittor/test/test_rnn.py new file mode 100644 index 00000000..a604a762 --- /dev/null +++ b/python/jittor/test/test_rnn.py @@ -0,0 +1,645 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Zheng-Ning Liu +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +from unittest.case import skipIf +try: + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +import jittor as jt +import jittor.nn as nn +import numpy as np + +skip_this_test = False + +def check_equal_1(t_rnn, j_rnn, input, h0, dev=None): + j_rnn.load_state_dict(t_rnn.state_dict()) + + if dev: + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + + else: + t_output, th = t_rnn(torch.from_numpy(input), torch.from_numpy(h0)) + t_output = t_output.detach().cpu().numpy() + th = th.detach().cpu().numpy() + + j_output, jh = j_rnn(jt.float32(input), jt.float32(h0)) + j_output, jh = j_output.data, jh.data + + np.testing.assert_allclose(t_output, j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th, jh.data, rtol=1e-03, atol=1e-06) + +def check_equal_2(t_rnn, j_rnn, input, h0, c0, dev=None): + j_rnn.load_state_dict(t_rnn.state_dict()) + + if dev: + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + else: + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0), torch.from_numpy(c0))) + + j_output, (jh, jc) = j_rnn(jt.float32(input), + (jt.float32(h0), jt.float32(c0))) + + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06) + + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestRNN(unittest.TestCase): + def test_lstm_cell(self): + np_h0 = torch.randn(3, 20).numpy() + np_c0 = torch.randn(3, 20).numpy() + + t_rnn = tnn.LSTMCell(10, 20) + input = torch.randn(2, 3, 10) + h0 = torch.from_numpy(np_h0) + c0 = torch.from_numpy(np_c0) + t_output = [] + for i in range(input.size()[0]): + h0, c0 = t_rnn(input[i], (h0, c0)) + t_output.append(h0) + t_output = torch.stack(t_output, dim=0) + + j_rnn = nn.LSTMCell(10, 20) + j_rnn.load_state_dict(t_rnn.state_dict()) + + input = jt.float32(input.numpy()) + h0 = jt.float32(np_h0) + c0 = jt.float32(np_c0) + j_output = [] + for i in range(input.size()[0]): + h0, c0 = j_rnn(input[i], (h0, c0)) + j_output.append(h0) + j_output = jt.stack(j_output, dim=0) + + t_output = t_output.detach().numpy() + j_output = j_output.data + assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06) + + def test_rnn_cell(self): + np_h0 = torch.randn(3, 20).numpy() + + t_rnn = tnn.RNNCell(10, 20) + input = torch.randn(2, 3, 10) + h0 = torch.from_numpy(np_h0) + t_output = [] + for i in range(input.size()[0]): + h0 = t_rnn(input[i], h0) + t_output.append(h0) + t_output = torch.stack(t_output, dim=0) + + j_rnn = nn.RNNCell(10, 20) + j_rnn.load_state_dict(t_rnn.state_dict()) + + input = jt.float32(input.numpy()) + h0 = jt.float32(np_h0) + j_output = [] + for i in range(input.size()[0]): + h0 = j_rnn(input[i], h0) + j_output.append(h0) + j_output = jt.stack(j_output, dim=0) + + t_output = t_output.detach().numpy() + j_output = j_output.data + assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06) + + def test_gru_cell(self): + np_h0 = torch.randn(3, 20).numpy() + + t_rnn = tnn.GRUCell(10, 20) + input = torch.randn(2, 3, 10) + h0 = torch.from_numpy(np_h0) + t_output = [] + for i in range(input.size()[0]): + h0 = t_rnn(input[i], h0) + t_output.append(h0) + t_output = torch.stack(t_output, dim=0) + + j_rnn = nn.GRUCell(10, 20) + j_rnn.load_state_dict(t_rnn.state_dict()) + + input = jt.float32(input.numpy()) + h0 = jt.float32(np_h0) + j_output = [] + for i in range(input.size()[0]): + h0 = j_rnn(input[i], h0) + j_output.append(h0) + j_output = jt.stack(j_output, dim=0) + + t_output = t_output.detach().numpy() + j_output = j_output.data + assert np.allclose(t_output, j_output, rtol=1e-03, atol=1e-06) + + def test_basic_rnn(self): + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200) + j_rnn = nn.RNN(100, 200) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_multilayer_rnn(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=4) + j_rnn = nn.RNN(100, 200, num_layers=4) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_bidirectional_rnn(self): + h0 = np.random.rand(2, 1, 200).astype(np.float32) + input = np.random.rand(5, 1, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, bidirectional=True) + j_rnn = nn.RNN(100, 200, bidirectional=True) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_no_bias_rnn(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) + j_rnn = nn.RNN(100, 200, num_layers=2, bidirectional=True, bias=False) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_dropout_rnn(self): + h0 = np.random.rand(2, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) + j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, bias=False) + t_rnn.eval() + j_rnn.eval() + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_basic_lstm(self): + h0 = np.random.rand(1, 24, 200).astype(np.float32) + c0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + + t_rnn = tnn.LSTM(100, 200) + j_rnn = nn.LSTM(100, 200) + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_projection_lstm(self): + proj_size = 13 + h0 = np.random.rand(1, 24, proj_size).astype(np.float32) + c0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + t_rnn = tnn.LSTM(100, 200, proj_size=proj_size) + j_rnn = nn.LSTM(100, 200, proj_size=proj_size) + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_multilayer_lstm(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + c0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.LSTM(100, 200, num_layers=4) + j_rnn = nn.LSTM(100, 200, num_layers=4) + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_multilayer_projection_lstm(self): + proj_size = 8 + h0 = np.random.rand(2, 4, proj_size).astype(np.float32) + c0 = np.random.rand(2, 4, 20).astype(np.float32) + input = np.random.rand(5, 4, 10).astype(np.float32) + + t_rnn = tnn.LSTM(10, 20, num_layers=2, proj_size=proj_size) + j_rnn = nn.LSTM(10, 20, num_layers=2, proj_size=proj_size) + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_bidirectional_lstm(self): + h0 = np.random.rand(2, 1, 200).astype(np.float32) + c0 = np.random.rand(2, 1, 200).astype(np.float32) + input = np.random.rand(5, 1, 100).astype(np.float32) + + t_rnn = tnn.LSTM(100, 200, bidirectional=True) + j_rnn = nn.LSTM(100, 200, bidirectional=True) + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_bidirectional_projection_lstm(self): + proj_size = 10 + h0 = np.random.rand(2, 4, proj_size).astype(np.float32) + c0 = np.random.rand(2, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.LSTM(100, 200, bidirectional=True, proj_size=proj_size) + j_rnn = nn.LSTM(100, 200, bidirectional=True, proj_size=proj_size) + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_multilayer_bidirectional_projection_lstm(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + c0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False) + j_rnn = nn.LSTM(100, 200, num_layers=2, bidirectional=True, bias=False) + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_dropout_lstm(self): + h0 = np.random.rand(2, 4, 200).astype(np.float32) + c0 = np.random.rand(2, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.LSTM(100, 200, num_layers=2, dropout=0.5, bias=False) + j_rnn = nn.LSTM(100, 200, num_layers=2, dropout=0.5, bias=False) + t_rnn.eval() + j_rnn.eval() + check_equal_2(t_rnn, j_rnn, input, h0, c0) + + def test_basic_gru(self): + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + + t_rnn = tnn.GRU(100, 200) + j_rnn = nn.GRU(100, 200) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_multilayer_gru(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.GRU(100, 200, num_layers=4) + j_rnn = nn.GRU(100, 200, num_layers=4) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_bidirectional_gru(self): + h0 = np.random.rand(2, 1, 200).astype(np.float32) + input = np.random.rand(5, 1, 100).astype(np.float32) + + t_rnn = tnn.GRU(100, 200, bidirectional=True) + j_rnn = nn.GRU(100, 200, bidirectional=True) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_multilayer_bidirectional_gru(self): + h0 = np.random.rand(4, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.GRU(100, 200, num_layers=2, bidirectional=True, bias=False) + j_rnn = nn.GRU(100, 200, num_layers=2, bidirectional=True, bias=False) + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_multilayer_dropout_gru(self): + h0 = np.random.rand(2, 4, 200).astype(np.float32) + input = np.random.rand(5, 4, 100).astype(np.float32) + + t_rnn = tnn.GRU(100, 200, num_layers=2, dropout=0.5, bias=False) + j_rnn = nn.GRU(100, 200, num_layers=2, dropout=0.5, bias=False) + t_rnn.eval() + j_rnn.eval() + check_equal_1(t_rnn, j_rnn, input, h0) + + def test_rnn_default_hx(self): + input = np.random.rand(32, 24, 12).astype(np.float32) + h0 = np.zeros((1, 24, 24)).astype(np.float32) + + t_rnn = tnn.RNN(12, 24) + j_rnn = nn.RNN(12, 24) + j_rnn.load_state_dict(t_rnn.state_dict()) + t_output, th = t_rnn(torch.from_numpy(input)) + j_output, jh = j_rnn(jt.array(input)) + + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + + def test_lstm_default_hx(self): + input = np.random.rand(32, 24, 10).astype(np.float32) + t_rnn = tnn.LSTM(10, 20, num_layers=2, bidirectional=True) + j_rnn = nn.LSTM(10, 20, num_layers=2, bidirectional=True) + j_rnn.load_state_dict(t_rnn.state_dict()) + t_output, (th, tc) = t_rnn(torch.from_numpy(input)) + j_output, (jh, jc) = j_rnn(jt.array(input)) + np.testing.assert_allclose(t_output.detach().cpu().numpy(), j_output.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(th.detach().cpu().numpy(), jh.data, rtol=1e-03, atol=1e-06) + np.testing.assert_allclose(tc.detach().cpu().numpy(), jc.data, rtol=1e-03, atol=1e-06) + + def test_twobilinear_lstm(self): + x = jt.rand(5, 4, 10) + rnn1 = nn.LSTM(10, 20, bidirectional=True) + out1, _ = rnn1(x) + rnn2 = nn.LSTM(40, 20, bidirectional=True) + out2, _ = rnn2(out1) + target = jt.zeros_like(out2) + loss = nn.mse_loss(out2, target) + + from jittor import optim + optimizer = optim.RMSprop(rnn1.parameters()) + optimizer.step(loss) + + @skipIf(not jt.has_cuda, "No Cuda found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev) + + j_rnn = nn.RNN(100, 200, nonlinearity='relu') + j_rnn.train() + j_rnn.load_state_dict(t_rnn.state_dict()) + + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + + t_output, th = t_rnn(torch.from_numpy(input).to(dev), + torch.from_numpy(h0).to(dev)) + + j_output, jh = j_rnn(jt.array(input), jt.array(h0)) + + np.testing.assert_allclose(j_output.data, t_output.detach().cpu().numpy()) + np.testing.assert_allclose(jh.data, th.detach().cpu().numpy()) + + @skipIf(not jt.has_cuda, "No Cuda found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.RNN(32, 64, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-2) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-3, rtol=1e-2) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_basic_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, nonlinearity='relu').to(dev) + j_rnn = nn.RNN(100, 200, nonlinearity='relu') + + h0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_multilayer_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, num_layers=4, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, num_layers=4, nonlinearity='tanh') + + h0 = np.random.rand(4, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_bidirectional_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, bidirectional=True, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, bidirectional=True, nonlinearity='tanh') + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_no_bias_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, bidirectional=True, bias=False, nonlinearity='tanh') + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_dropout_cudnn_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh').to(dev) + j_rnn = nn.RNN(100, 200, num_layers=2, dropout=0.5, nonlinearity='tanh') + t_rnn.eval() + j_rnn.eval() + + h0 = np.random.rand(2, 8, 200).astype(np.float32) + input = np.random.rand(5, 8, 100).astype(np.float32) + check_equal_1(t_rnn, j_rnn, input, h0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_basic_lstm_rnn(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(100, 200).to(dev) + j_rnn = nn.LSTM(100, 200) + + h0 = np.random.rand(1, 24, 200).astype(np.float32) + c0 = np.random.rand(1, 24, 200).astype(np.float32) + input = np.random.rand(32, 24, 100).astype(np.float32) + check_equal_2(t_rnn, j_rnn, input, h0, c0, dev) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(32, 64, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.RNN(32, 64, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_gru_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.GRU(32, 64).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.GRU(32, 64) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, th = t_rnn(torch.from_numpy(input).to(dev), torch.from_numpy(h0).to(dev)) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh = jt.array(input), jt.array(h0) + j_output, jh = j_rnn(j_input, jh) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_lstm_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(32, 64).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.LSTM(32, 64) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(1, 4, 64).astype(np.float32) + c0 = np.random.rand(1, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0) + j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0)) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @unittest.skipIf(not jt.cudnn, "No Cudnn found") + @jt.flag_scope(use_cuda=1) + def test_multilayer_bidirectional_cudnn_lstm_train(self): + dev = torch.device('cuda:0') + t_rnn = tnn.LSTM(32, 64, num_layers=4, bidirectional=True).to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + j_rnn = nn.LSTM(32, 64, num_layers=4, bidirectional=True) + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + + h0 = np.random.rand(8, 4, 64).astype(np.float32) + c0 = np.random.rand(8, 4, 64).astype(np.float32) + input = np.random.rand(12, 4, 32).astype(np.float32) + + for _ in range(10): + t_optim.zero_grad() + t_output, (th, tc) = t_rnn(torch.from_numpy(input).to(dev), + (torch.from_numpy(h0).to(dev), torch.from_numpy(c0).to(dev))) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + (tc ** 2).sum() + t_loss.backward() + t_optim.step() + + j_input, jh0, jc0 = jt.array(input), jt.array(h0), jt.array(c0) + j_output, (jh, jc) = j_rnn(j_input, (jh0, jc0)) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + (jc ** 2).sum() + j_optim.step(j_loss) + + np.testing.assert_allclose(t_loss.item(), j_loss.item(), rtol=1e-4) + np.testing.assert_allclose(t_rnn.bias_hh_l0.detach().cpu().numpy(), j_rnn.bias_hh_l0.data, atol=1e-4, rtol=1e-4) + + @unittest.skipIf(not jt.has_cuda, "No Cuda found") + @jt.flag_scope(use_cuda=1) + def test_cudnn_rnn_speed(self): + from time import time + iters = 100 + + h0 = np.random.rand(1, 128, 256).astype(np.float32) + input = np.random.rand(128, 128, 128).astype(np.float32) + + dev = torch.device('cuda:0') + t_rnn = tnn.RNN(128, 256, nonlinearity='relu').to(dev) + t_optim = torch.optim.SGD(t_rnn.parameters(), lr=1e-3, momentum=0.9) + + t_input = torch.from_numpy(input).to(dev) + t_h0 = torch.from_numpy(h0).to(dev) + + start_time = time() + for i in range(iters): + t_optim.zero_grad() + t_output, th = t_rnn(t_input, t_h0) + t_loss = (t_output ** 2).sum() + (th ** 2).sum() + t_loss.backward() + t_optim.step() + print('torch time = ', time() - start_time) + + j_rnn = nn.RNN(128, 256, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + j_input, j_h0 = jt.array(input), jt.array(h0) + + start_time = time() + for i in range(iters): + j_output, jh = j_rnn(j_input, j_h0) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + jt.sync_all(True) + print('jittor Cudnn time = ', time() - start_time) + + jt_cudnn, jt.cudnn = jt.cudnn, None + j_rnn = nn.RNN(128, 256, nonlinearity='relu') + j_rnn.load_state_dict(t_rnn.state_dict()) + j_optim = nn.SGD(j_rnn.parameters(), lr=1e-3, momentum=0.9) + start_time = time() + for i in range(iters): + j_output, jh = j_rnn(j_input, j_h0) + j_loss = (j_output ** 2).sum() + (jh ** 2).sum() + j_optim.step(j_loss) + jt.sync_all(True) + print('jittor native time = ', time() - start_time) + jt.cudnn = jt_cudnn + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_rocm.py b/python/jittor/test/test_rocm.py new file mode 100644 index 00000000..bf7fc0a8 --- /dev/null +++ b/python/jittor/test/test_rocm.py @@ -0,0 +1,459 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: Zheng-Ning Liu . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest + +import os +import random +import math +import time + +import numpy as np +import tqdm + +import jittor as jt +from jittor import init, Module, nn, Function +from jittor.models import vgg +from jittor.dataset.mnist import MNIST +import jittor.transform as trans + +from .test_core import expect_error +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re + + +def test_rocm(use_rocm=1): + @unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") + class TestCudaBase(unittest.TestCase): + def setUp(self): + jt.flags.use_rocm = use_rocm + def tearDown(self): + jt.flags.use_rocm = 0 + return TestCudaBase + + +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestROCm(unittest.TestCase): + + @jt.flag_scope(use_rocm=1) + def test_array(self): + a = jt.array([1,2,3]) + np.testing.assert_allclose(a.numpy(), [1,2,3]) + + @jt.flag_scope(use_rocm=1) + def test_add(self): + a = jt.array([1,2,3]) + b = a+a + np.testing.assert_allclose(b.numpy(), [2,4,6]) + + @jt.flag_scope(use_rocm=1) + def test_add_float(self): + a = jt.array([1.0,2.0,3.0]) + b = a+a + np.testing.assert_allclose(b.numpy(), [2,4,6]) + + @jt.flag_scope(use_rocm=1) + def test_array_cast(self): + # this test cannot pass because cast error + x = np.random.rand(10) + y = jt.float32(x) + np.testing.assert_allclose(x, y.numpy()) + + def test_meminfo(self): + jt.display_memory_info() + + @jt.flag_scope(use_rocm=1) + def test_cuda_flags(self): + a = jt.random((10, 10)) + a.sync() + + @jt.flag_scope(use_rocm=1) + def test_rocm_custom_op_from_cuda(self): + my_op = jt.compile_custom_op(""" + struct MyCudaOp : Op { + Var* output; + MyCudaOp(NanoVector shape, string dtype="float"); + + const char* name() const override { return "my_cuda"; } + DECLARE_jit_run; + }; + """, """ + #ifndef JIT + MyCudaOp::MyCudaOp(NanoVector shape, string dtype) { + flags.set(NodeFlags::_cuda); + output = create_output(shape, dtype); + } + + void MyCudaOp::jit_prepare(JK& jk) { + add_jit_define(jk, "T", output->dtype()); + } + + #else // JIT + #ifdef JIT_cuda + + __global__ void kernel(index_t n, T *x) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = index; i < n; i += stride) + x[i] = (T)-i; + } + + void MyCudaOp::jit_run() { + index_t num = output->num; + auto* __restrict__ x = output->ptr(); + int blockSize = 256; + int numBlocks = (num + blockSize - 1) / blockSize; + kernel<<>>(num, x); + } + #endif // JIT_cuda + #endif // JIT + """, + "my_cuda") + a = my_op([3,4,5], 'float') + na = a.data + assert a.shape == [3,4,5] and a.dtype == 'float' + assert (-na.flatten() == range(3*4*5)).all(), na + + def test_rocm_fused_op(self): + a = jt.array([1,2,3]) + a.sync() + with jt.flag_scope(use_rocm=1): + ((a+a)*2).data + + +class Model(Module): + def __init__(self, input_size): + self.linear1 = nn.Linear(input_size, 10) + self.relu1 = nn.Relu() + self.linear2 = nn.Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + + +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestExample(unittest.TestCase): + @jt.flag_scope(use_rocm=1) + def test1(self): + np.random.seed(0) + jt.set_seed(3) + n = 1000 + batch_size = 50 + lr = 0.05 + + def get_data(n): + for i in range(n): + x = np.random.rand(batch_size, 1).astype("float32") + y = x*x + yield jt.float32(x), jt.float32(y) + + model = Model(input_size=1) + ps = model.parameters() + + for i,(x,y) in enumerate(get_data(n)): + jt.sync_all(True) + pred_y = model(x).name("pred_y") + loss = ((pred_y - y).sqr()).name("loss") + loss_mean = loss.mean() + + gs = jt.grad(loss_mean, ps) + for p, g in zip(ps, gs): + p -= g * lr + + if i>2: + assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}" + prev = jt.liveness_info() + + possible_results = [ + 0.0009948202641680837, + 0.001381353591568768, + 0.00110957445576787, + 0.001124994712881744 + ] + loss_mean = loss_mean.data + assert any(abs(loss_mean - r) < 1e-6 for r in possible_results) + + jt.clean() + + +from .test_unary_op import TestUnaryOp +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestROCmUnaryOp(TestUnaryOp, test_rocm(1)): + pass + + +from .test_binary_op import TestBinaryOp +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestROCmBinaryOp(TestBinaryOp, test_rocm(1)): + pass + + +from .test_reduce_op import TestReduceOp +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestROCmReduceOp(TestReduceOp, test_rocm(1)): + pass + + +from .test_reindex_op import TestReindexOp +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestROCmReindexOp(TestReindexOp, test_rocm(1)): + pass + + +from .test_where_op import TestWhereOp +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestROCmWhereOp(TestWhereOp, test_rocm(1)): + pass + + +# from .test_reindex_reduce_op import TestReindexReduceOp +# @unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +# class TestROCmReindexReduceOp(TestReindexReduceOp, test_rocm(1)): +# pass + + +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestROCmCodeOp(unittest.TestCase): + @jt.flag_scope(use_rocm=1) + def test_cuda(self): + a = jt.random([100000]) + b = jt.random([100000]) + c = jt.code(a.shape, a.dtype, [a,b], + cuda_src=''' + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + ''', + cuda_grad_src = [''' + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + ''', ''' + __global__ static void kernel3(@ARGS_DEF) { + @PRECALC + int i = threadIdx.x + blockIdx.x * blockDim.x; + int stride = blockDim.x * gridDim.x; + for (; i>>(@ARGS); + ''']) + da, db = jt.grad(c, [a, b]) + assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data) + assert np.allclose(da.data, b.data) + assert np.allclose(db.data, a.data) + + @jt.flag_scope(use_rocm=1) + def test_cuda2(self): + a = jt.random((100,100)) + b = jt.random((100,100)) + c = jt.code(a.shape, a.dtype, [a,b], + cuda_src=''' + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''', + cuda_grad_src = [''' + __global__ static void kernel(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''', ''' + __global__ static void kernel(@ARGS_DEF) { + @PRECALC + @pout(0,0); + for (int i=blockIdx.x; i>>(@ARGS); + ''']) + da, db = jt.grad(c, [a, b]) + assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data) + assert np.allclose(da.data, b.data) + assert np.allclose(db.data, a.data) + + @jt.flag_scope(use_rocm=1) + def test_cuda2_use_func(self): + class Func(Function): + def execute(self, a, b): + self.save_vars = a, b + return jt.code(a.shape, a.dtype, [a,b], + cuda_src=''' + __global__ static void kernel1(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''') + + def grad(self, grad): + a, b = self.save_vars + return jt.code([a.shape, b.shape], [a.dtype, b.dtype], [a, b, grad], + cuda_src=''' + __global__ static void kernel2(@ARGS_DEF) { + @PRECALC + for (int i=blockIdx.x; i>>(@ARGS); + ''') + + a = jt.random((100,100)) + b = jt.random((100,100)) + + func = Func() + c = func(a,b) + da, db = jt.grad(c, [a, b]) + assert np.allclose(c.data, a.data*b.data), (c.data, a.data*b.data) + assert np.allclose(da.data, b.data) + assert np.allclose(db.data, a.data) + + +@unittest.skipIf(not jt.compiler.has_rocm, "No ROCm found") +class TestBMM(unittest.TestCase): + def test_bmm_rocm(self): + def check(batch, n, m, k): + def calc(use_rocm, a, b, mask): + jt.flags.use_rocm = use_rocm + a = jt.array(a) + b = jt.array(b) + mask = jt.array(mask) + c = nn.bmm(a, b) + da, db = jt.grad(c*mask, [a, b]) + return c.data, da.data, db.data + mask = np.random.rand(batch, n, k).astype("float32") + a = np.random.rand(batch, n, m).astype("float32") + b = np.random.rand(batch, m, k).astype("float32") + a1,a2,a3 = calc(0, a, b, mask) + b1,b2,b3 = calc(1, a, b, mask) + assert np.allclose(a1, b1) + assert np.allclose(a2, b2) + assert np.allclose(a3, b3) + check(10,3,4,5) + check(10,8,8,8) + check(10,8,1,8) + check(10,8,8,1) + check(10,1,8,8) + check(1,7,8,8) + +class Model(Module): + def __init__(self, input_size): + self.linear1 = nn.Linear(input_size, 10) + self.relu1 = nn.Relu() + self.linear2 = nn.Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + +from jittor.models import resnet + +class MnistNet(Module): + def __init__(self): + self.model = resnet.Resnet18() + self.layer = nn.Linear(1000,10) + def execute(self, x): + x = self.model(x) + x = self.layer(x) + return x + +@unittest.skipIf(not jt.compiler.has_rocm, "skip_this_test") +class TestResnetFp32(unittest.TestCase): + # setup random seed + def setup_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + jt.seed(seed) + + @jt.flag_scope(use_cuda=1) + def test_resnet(self): + self.setup_seed(1) + + # hyper-parameters + self.batch_size = int(os.environ.get("TEST_BATCH_SIZE", "100")) + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.1 + if jt.flags.amp_reg: + self.learning_rate = 0.01 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + self.train_loader.num_workers = 4 + + loss_list=[] + acc_list=[] + mnist_net = MnistNet() + global prev + SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) + self.train_loader.endless = True + + for data, target in self.train_loader: + batch_id = self.train_loader.batch_id + epoch_id = self.train_loader.epoch_id + data = data.float_auto() + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + + break + jt.sync_all(True) + + for _ in range(10): + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + SGD.step(loss) + def callback(epoch_id, batch_id, loss, output, target): + pred = np.argmax(output, axis=1) + acc = np.mean(target==pred) + jt.fetch(epoch_id, _, loss, output, target, callback) + jt.sync_all(True) + + all_time = time.time() + prev = time.time() + print('starting') + for _ in range(100): + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + SGD.step(loss) + def callback(epoch_id, batch_id, loss, output, target): + global prev + pred = np.argmax(output, axis=1) + acc = np.mean(target==pred) + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}' + .format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev)) + prev = time.time() + jt.fetch(epoch_id, _, loss, output, target, callback) + jt.sync_all(True) + print(f'all = {time.time() - all_time}') + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_sampler.py b/python/jittor/test/test_sampler.py new file mode 100644 index 00000000..2cb2d864 --- /dev/null +++ b/python/jittor/test/test_sampler.py @@ -0,0 +1,59 @@ +import jittor as jt +from jittor.dataset import * +from PIL import Image +import numpy as np +import unittest + + + +class TestSamplerDataset(Dataset): + def __init__(self): + super().__init__() + self.set_attrs(total_len=40, batch_size=1) + + def __getitem__(self, idx): + return idx**2 + + +class TestSampler(unittest.TestCase): + def test_sequential_sampler(self): + testdataset = TestSamplerDataset() + seqsampler = SequentialSampler(testdataset) + assert len(seqsampler) == 40 + for idx, batch in enumerate(seqsampler): + assert idx == batch + for i, data in enumerate(testdataset): + assert data.item() == i**2 + + def test_random_sampler(self): + testdataset = TestSamplerDataset() + randomsampler = RandomSampler(testdataset) + assert len(randomsampler) == 40 + diff = 0 + for i, data in enumerate(testdataset): + diff += data.item() == i**2 + assert diff < 10 + + def test_subset_random_sampler(self): + testdataset = TestSamplerDataset() + subsetsampler = SubsetRandomSampler(testdataset, (20, 30)) + assert len(subsetsampler) == 10 + s = 0 + for i, data in enumerate(testdataset): + s += data.item() + s2 = 0 + for i in range(20,30): + s2 += i**2 + assert s == s2, (s, s2) + + def test_batch_sampler(self): + testdataset = TestSamplerDataset() + seqforbatch = SequentialSampler(testdataset) + batchsampler = BatchSampler(seqforbatch, 4, drop_last=False) + assert len(batchsampler) == 10 + for batch in batchsampler: + assert len(batch) == 4 + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_search_sorted.py b/python/jittor/test/test_search_sorted.py new file mode 100644 index 00000000..b9d72b03 --- /dev/null +++ b/python/jittor/test/test_search_sorted.py @@ -0,0 +1,72 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +# TODO: more test +@unittest.skipIf(skip_this_test, "No Torch found") +class TestSearchSorted(unittest.TestCase): + def test_origin(self): + sorted = jt.array([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]]) + values = jt.array([[3, 6, 9], [3, 6, 9]]) + ret = jt.searchsorted(sorted, values) + assert (ret == [[1, 3, 4], [1, 2, 4]]).all(), ret + + ret = jt.searchsorted(sorted, values, right=True) + assert (ret == [[2, 3, 5], [1, 3, 4]]).all(), ret + + sorted_1d = jt.array([1, 3, 5, 7, 9]) + ret = jt.searchsorted(sorted_1d, values) + assert (ret == [[1, 3, 4], [1, 3, 4]]).all(), ret + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cuda(self): + self.test_origin() + + + def test_searchsorted_cpu(self): + for i in range(1,3): + s = np.sort(np.random.rand(*((10,)*i)),-1) + v = np.random.rand(*((10,)*i)) + s_jt = jt.array(s) + v_jt = jt.array(v) + s_tc = torch.from_numpy(s) + v_tc = torch.from_numpy(v) + + y_tc = torch.searchsorted(s_tc, v_tc, right=True) + y_jt = jt.searchsorted(s_jt, v_jt, right=True) + assert np.allclose(y_jt.numpy(), y_tc.data) + y_jt = jt.searchsorted(s_jt, v_jt, right=False) + y_tc = torch.searchsorted(s_tc, v_tc, right=False) + assert np.allclose(y_jt.numpy(), y_tc.data) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_searchsorted_gpu(self): + self.test_searchsorted_cpu() + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_searchsorted_op.py b/python/jittor/test/test_searchsorted_op.py new file mode 100644 index 00000000..b96a195b --- /dev/null +++ b/python/jittor/test/test_searchsorted_op.py @@ -0,0 +1,45 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +skip_this_test = False +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + skip_this_test = True + + +@unittest.skipIf(skip_this_test, "No Torch Found") +class TestSearchsorted(unittest.TestCase): + def test_searchsorted_cpu(self): + for i in range(1,3): + s = np.sort(np.random.rand(*((10,)*i)),-1) + v = np.random.rand(*((10,)*i)) + s_jt = jt.array(s) + v_jt = jt.array(v) + s_tc = torch.from_numpy(s) + v_tc = torch.from_numpy(v) + + y_tc = torch.searchsorted(s_tc, v_tc, right=True) + y_jt = jt.searchsorted(s_jt, v_jt, right=True) + assert np.allclose(y_jt.numpy(), y_tc.data) + y_jt = jt.searchsorted(s_jt, v_jt, right=False) + y_tc = torch.searchsorted(s_tc, v_tc, right=False) + assert np.allclose(y_jt.numpy(), y_tc.data) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_searchsorted_gpu(self): + self.test_searchsorted_cpu() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_setitem.py b/python/jittor/test/test_setitem.py new file mode 100644 index 00000000..c2a5453b --- /dev/null +++ b/python/jittor/test/test_setitem.py @@ -0,0 +1,449 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com>. +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +skip_this_test = False + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestSetitem(unittest.TestCase): + def test_setitem_(self): + arr0 = jt.random((4,2,2)) + data0 = jt.ones((2,2)) + arr0[1] = data0 + arr0.sync() + data0.data[0,0] = 0 + assert arr0[1,0,0] == 0 + + arr00 = jt.random((4,2,2)) + data00 = jt.ones((2,2)) + # share memory will fail if d has an edge to other nodes. + tmp = data00 + 1 + arr00[1] = data00 + arr00.sync() + data00.data[0,0] = 0 + assert arr00[1,0,0] == 0 + + arr1 = jt.random((4,2,2)) + data1 = jt.zeros((2,2)) + arr1[3,:,:] = data1 + arr1.sync() + data1.data[0,0] = 1 + assert arr1[3,0,0] == 1 + + arr21 = jt.ones((2,2)) + arr22 = jt.ones((2,2)) * 2 + arr2 = jt.concat([arr21, arr22], dim=0) + arr2.sync() + arr21.data[0,0] = 3 + arr22.data[0,0] = 4 + assert arr2[0,0] == 3 + assert arr2[2,0] == 4 + + def test_getitem(self): + # test for different slice type + arr0 = jt.random((4,3)) + arr0_res = arr0[2,:] + arr0_res.data[1] = 1 + assert arr0[2,1] == 1 + + arr1 = jt.array([1,2,3,4]) + arr1_res = arr1[None] + arr1_res.data[0,2] = -1 + assert arr1[2] == -1 + + arr2 = jt.array([1,2,3,4]) + arr2_res = arr2[...] + arr2_res.data[2] = -1 + assert arr2[2] == -1 + + arr3 = jt.array([1,2,3,4]) + arr3_res = arr3[3] + arr3_res.data[0] = -1 + assert arr3[3] == -1 + + arr4 = jt.random((4,2,3,3)) + arr4_res = arr4[...,:,:] + arr4_res.data[0,0,1,1] = 1 + assert arr4[0,0,1,1] == 1 + + arr4 = jt.random((4,2,3,3)) + arr4_res = arr4[...,:,:2] + arr4_res.data[0,0,1,1] = 1 + assert arr4[0,0,1,1] != 1 + + arr4 = jt.random((3,3)) + arr4_res = arr4[...,:,:2] + arr4_res.data[1,1] = 1 + assert arr4[1,1] != 1 + + arr5 = jt.random((4,2,3,3)) + arr5_res = arr5[1:3,:,:,:] + arr5_res.data[1,0,1,1] = 1 + assert arr5[2,0,1,1] == 1 + + arr6 = jt.random((4,2,3,3)) + arr6_res = arr6[1] + arr6_res.data[0,1,1] = 1 + assert arr6[1,0,1,1] == 1 + + # test for different data type (float32/float64/bool/int8/int32) + arr_float32 = jt.random((4,2,3)) + arr_float32_res = arr_float32[1:3,:,:] + arr_float32_res.data[0,0,0] = 1 + assert arr_float32[1,0,0] == 1 + arr_float32_res.data[1,1,2] = 1 + assert arr_float32[2,1,2] == 1 + arr_float32[1,0,0] = 0 + # getitem and setitem do not conflict + assert arr_float32_res[0,0,0] == 1 + + arr_bool = jt.bool(np.ones((4,2,3))) + arr_bool_res = arr_bool[1:3,:,:] + arr_bool_res.data[0,0,0] = False + assert arr_bool[1,0,0] == False + arr_bool_res.data[0,0,1] = False + assert arr_bool[1,0,1] == False + + arr_float64 = jt.random((4,2,3), dtype='float64') + arr_float64_res = arr_float64[1:3,:,:] + arr_float64_res.data[0,0,0] = 1 + assert arr_float64[1,0,0] == 1 + arr_float64_res.data[1,1,2] = 1 + assert arr_float64[2,1,2] == 1 + + arr_int32 = jt.ones((4,2,3), dtype='int32') + arr_int32_res = arr_int32[1:3,:,:] + arr_int32_res.data[0,0,0] = 0 + assert arr_int32[1,0,0] == 0 + arr_int32_res.data[1,1,2] = 0 + assert arr_int32[2,1,2] == 0 + + def test_setitem_inplace_case1(self): + # test type case + a = jt.zeros((3,)) + a[1] = 123 + assert a.data[1] == 123 + + def test_setitem_inplace_case2(self): + # test un-continuous first dim + a = jt.zeros((3,)) + a[0::2] = jt.ones((2,)) + assert a.data[2] == 1 + + def test_setitem_inplace_case3(self): + # test broadcast + a = jt.zeros((3,)) + a[0:] = 1.0 + assert a.data[2] == 1 + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_getitem_inplace_array(self): + a = jt.array([[1,2],[3,4]]) + assert (a[0].numpy() == [1,2]).all(), a[0].numpy() + assert (a[1].numpy() == [3,4]).all(), a[1].numpy() + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_setitem_inplace_array(self): + a = jt.array([[1,2],[3,4]]) + a[0,0] = -1 + a[1,1] = -2 + assert (a[0].numpy() == [-1,2]).all(), a[0].numpy() + assert (a[1].numpy() == [3,-2]).all(), a[1].numpy() + + def test_scatter(self): + src = jt.arange(1, 11).reshape((2, 5)) + index = jt.array([[0, 1, 2, 0]]) + x = jt.zeros((3, 5), dtype=src.dtype).scatter_(0, index, src) + assert (x.data == + [[1, 0, 0, 4, 0], + [0, 2, 0, 0, 0], + [0, 0, 3, 0, 0]]).all() + index = jt.array([[0, 1, 2], [0, 1, 4]]) + x = jt.zeros((3, 5), dtype=src.dtype).scatter_(1, index, src) + assert (x.data == + [[1, 2, 3, 0, 0], + [6, 7, 0, 0, 8], + [0, 0, 0, 0, 0]]).all() + x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]), + jt.array(1.23), reduce='multiply') + assert np.allclose(x.data, + [[2.0000, 2.0000, 2.4600, 2.0000], + [2.0000, 2.0000, 2.0000, 2.4600]]), x + x = jt.full((2, 4), 2.).scatter_(1, jt.array([[2], [3]]), + jt.array(1.23), reduce='add') + assert np.allclose(x.data, + [[2.0000, 2.0000, 3.2300, 2.0000], + [2.0000, 2.0000, 2.0000, 3.2300]]) + + def test_gather(self): + t = jt.array([[1, 2], [3, 4]]) + data = t.gather(1, jt.array([[0, 0], [1, 0]])).data + assert (data == [[ 1, 1], [ 4, 3]]).all() + data = t.gather(0, jt.array([[0, 0], [1, 0]])).data + assert (data == [[ 1, 2], [ 3, 2]]).all() + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_scatter_cuda(self): + self.test_scatter() + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_gather_cuda(self): + self.test_gather() + + def test_setitem_bool(self): + a = jt.array([1,2,3,4]) + b = jt.array([True,False,True,False]) + a[b] = jt.array([-1,-2]) + assert (a.data == [-1,2,-2,4]).all() + + def test_setitem_bool2(self): + a = jt.array([1,2,3,4]) + b = jt.array([True,False,True,False]) + a[b] = jt.array([-1]) + assert (a.data == [-1,2,-1,4]).all(), a + a = jt.array([1,2,3,4]) + b = jt.array([True,False,True,False]) + a[b] = -1 + assert (a.data == [-1,2,-1,4]).all(), a + + def test_slice_none(self): + a = jt.array([1,2]) + assert a[None,:,None,None,...,None].shape == (1,2,1,1,1) + + def test_roll(self): + x = jt.array([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) + y = x.roll(1, 0) + assert (y.numpy() == [[7,8],[1,2],[3,4],[5,6]]).all(), y + y = x.roll(-1, 0) + assert (y.numpy() == [[3,4],[5,6],[7,8],[1,2]]).all() + y = x.roll(shifts=(2, 1), dims=(0, 1)) + assert (y.numpy() == [[6,5],[8,7],[2,1],[4,3]]).all() + + def test_ellipsis_with_none(self): + a = jt.arange(2*4*4).reshape(2,4,4) + b = a[...,:,None,:2] + assert b.shape == [2,4,1,2] + np.testing.assert_allclose(b.data, a.data[...,:,None,:2]) + + def test_flip_grad(self): + a = jt.rand(10) + b = a[::-1] + c = b[::-1] + d = c.sum() + jt.grad(d, [a]) + + def test_concat2(self): + a = jt.rand(10) + b = jt.rand(11) + c = jt.rand(12) + def cc(): + x = jt.concat([b.copy(), c.copy()]) + d = jt.concat([a.copy(), x]) + return d.copy().copy().copy().copy().copy().copy()\ + .copy().copy() + x.sum()*0.0 + d = cc() + np.testing.assert_allclose(d.data, + np.concatenate([a.data,b.data,c.data])) + + def test_concat3(self): + # a = jt.rand(10) + b = jt.rand(11) + c = jt.rand(12) + def cc(): + x = jt.concat([b.copy(), c.copy()]) + d = jt.concat([x]) + return d.copy().copy().copy().copy().copy().copy()\ + .copy().copy() + x.sum()*0.0 + d = cc() + np.testing.assert_allclose(d.data, + np.concatenate([b.data,c.data])) + + + def test_concat4(self): + # a = jt.rand(10) + b = jt.rand(11) + c = jt.rand(12) + def cc(): + x = jt.concat([b.copy(), c.copy()]) + d = jt.concat([x]) + return d + d = cc() + np.testing.assert_allclose(d.data, + np.concatenate([b.data,c.data])) + + def test_concat_random(self): + def check(backward=False): + n1, n2, n3 = 1000, 20, 10 + # n1, n2, n3 = 3, 2, 3 + import random + data = [] + back = [] + for i in range(n1): + if len(data) > n2: + v = random.randint(0,len(data)-1) + # print("del", v) + del data[v] + x1 = random.randint(0,9) + # print(i, x1) + if len(data) == 0: + # a = jt.random((random.randint(10,20),)) + a = jt.array(np.random.rand(random.randint(n3,n3*2))) + data.append(a) + if x1 == 0: + a = data[random.randint(0,len(data)-1)] + a = a.copy() + data.append(a) + elif x1 == 1: + a = data[random.randint(0,len(data)-1)] + a = a.clone() + data.append(a) + elif x1 == 2: + a = data[random.randint(0,len(data)-1)] + b = np.random.permutation(np.arange(a.numel())) + # print("permutation", b) + a = a[b] + data.append(a) + elif x1 == 3: + a = data[random.randint(0,len(data)-1)] + a = a[:100] + # print(a.shape) + data.append(a) + elif x1 == 4: + # a = jt.random((random.randint(10,20),)) + a = jt.array(np.random.rand(random.randint(n3,n3*2))) + if backward and random.randint(0,1): + back.append(a) + data.append(a) + elif x1 == 5: + v = random.randint(0,len(data)-1) + a = data[v] + # print("split", v, a.shape) + arr = a.split(n3-1) + data += arr + else: + if not len(data): continue + n = random.randint(1,3) + a = [ data[random.randint(0,len(data)-1)] for i in range(n) ] + a = jt.concat(a) + if a.numel() > 1000: + b = np.random.permutation(np.arange(a.numel())) + a = a[b][:100] + data.append(a) + ret = jt.concat(data) + if backward and len(back): + grads = jt.grad(jt.rand_like(ret)*ret, back) + return jt.concat(grads).numpy() + return ret.numpy() + + for s in range(100): + print("check", s) + for check_grad in [True, False]: + jt.set_global_seed(s) + data = check(check_grad) + jt.gc() + jt.set_global_seed(s) + with jt.flag_scope(gopt_disable=1): + data2 = check(check_grad) + jt.gc() + np.testing.assert_allclose(data, data2, atol=1e-5, rtol=1e-5) + + def test_concat_grad(self): + n = 30000 + m = 100 + arr = [] + for i in range(n): + arr.append(jt.random((m,))) + x = jt.concat(arr) + y = jt.rand_like(x) + grads = jt.grad(x*y, arr) + for i in range(n): + np.testing.assert_allclose(grads[i].numpy(), y[i*m:(i+1)*m].numpy()) + + def test_split_grad(self): + n = 30000 + m = 100 + x = jt.random((n*m,)) + arr = x.split(m) + yy = [ jt.rand(m) for i in range(n) ] + arr2 = [ y*yy[i] for i,y in enumerate(arr) ] + g = jt.grad(jt.concat(arr2), x) + for i in range(n): + np.testing.assert_allclose(g.data[i*m:(i+1)*m], yy[i].data) + + def test_dfs_memopt(self): + with jt.flag_scope(profile_memory_enable=1): + n = 1024 + b = [] + for i in range(n): + a = jt.rand(n).copy().copy() + a = a.sum() + # a.sync() + b.append(a) + jt.sync_all() + jt.get_max_memory_treemap() + + + def test_setitem_bc(self): + a = jt.random([10,11,12]) + b = a[jt.arange(3)[:,None], + jt.arange(4)[None,:]] + b.sync() + assert (a[:3, :4] == b).all() + + a = jt.random([10,11,12]) + b = a[jt.arange(3)[:,None], + jt.arange(4)[None,:], + jt.arange(4)[None,:]] + nb = a.data[np.arange(3)[:,None], + np.arange(4)[None,:], + np.arange(4)[None,:]] + np.testing.assert_allclose(nb, b.data) + + a = jt.random([10,11,12]) + b = a[jt.arange(3)[::-1,None], + jt.arange(4)[None,:], + jt.arange(4)[None,:]] + nb = a.data[np.arange(3)[::-1,None], + np.arange(4)[None,:], + np.arange(4)[None,:]] + np.testing.assert_allclose(nb, b.data) + + a = jt.random([10,11,12]) + b = a[jt.arange(3)[::-1,None], + jt.arange(4)[None,:], + jt.arange(4)[None,::-1]] + nb = a.data[np.arange(3)[::-1,None], + np.arange(4)[None,:], + np.arange(4)[None,::-1]] + np.testing.assert_allclose(nb, b.data) + + def test_cuda_slice_migrate_bug(self): + a = jt.array([1,2,3,4,5]) + jt.sync_all() + if not jt.has_cuda: return + with jt.flag_scope(use_cuda=1): + b = a[0] + b.sync(True) + assert b.item() == 1 + + def test_cascade_setitem(self): + a = jt.zeros(3,3,3,3) + a[1][2][0][0] = 1 + assert a[1,2,0,0] == 1 + # TODO: convert a[x] = a[x] + b -> a[x] += b + a[1][2][0][0] += 1 + assert a[1,2,0,0] == 2 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_single_process_scope.py b/python/jittor/test/test_single_process_scope.py new file mode 100644 index 00000000..7338b631 --- /dev/null +++ b/python/jittor/test/test_single_process_scope.py @@ -0,0 +1,46 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os, sys +import jittor as jt +import numpy as np +from jittor.test.test_mpi import run_mpi_test +mpi = jt.compile_extern.mpi + +from jittor.dataset.mnist import MNIST + +def val1(): + dataloader = MNIST(train=False).set_attrs(batch_size=16) + for i, (imgs, labels) in enumerate(dataloader): + assert(imgs.shape[0]==8) + if i == 5: + break + +@jt.single_process_scope(rank=0) +def val2(): + dataloader = MNIST(train=False).set_attrs(batch_size=16) + for i, (imgs, labels) in enumerate(dataloader): + assert(imgs.shape[0]==16) + if i == 5: + break + +@unittest.skipIf(not jt.in_mpi, "no inside mpirun") +class TestSingleProcessScope(unittest.TestCase): + def test_single_process_scope(self): + val1() + val2() + +@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found") +class TestSingleProcessScopeEntry(unittest.TestCase): + def test_entry(self): + run_mpi_test(2, "test_single_process_scope") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_slice.py b/python/jittor/test/test_slice.py new file mode 100644 index 00000000..1d73c529 --- /dev/null +++ b/python/jittor/test/test_slice.py @@ -0,0 +1,153 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. +# All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor.test.test_grad import ngrad + +class TestSlice(unittest.TestCase): + def test_slice_bool(self): + a = jt.zeros(10, "bool") + a[1] = True + a[2] = 1 + assert a.dtype == "bool" + a.sync() + assert np.equal(a.data, np.array([0,1,1,0,0,0,0,0,0,0])).all() + + def test_var_slices(self): + def check(slices, msg): + with jt.log_capture_scope() as logs: + jt.core._print_var_slice(slices) + s = logs[0]['msg'] + assert s == msg, s + check((1), "[1,]") + check(([[0],[1]],slice(None),[1,2],1), "[int32[2,1,],::,int32[2,],1,]") + check((slice(None),slice(None),slice(None),slice(None)), "[::,::,::,::,]") + check(([0,1],[0,1],[0,1],[0,1]), "[int32[2,],int32[2,],int32[2,],int32[2,],]") + check(([0,1],-2,slice(None),[0,1]), "[int32[2,],-2,::,int32[2,],]") + check(([0,1],slice(1,2,2),[1,2],1), "[int32[2,],1:2:2,int32[2,],1,]") + check(([0,1],slice(None),[1,2],1), "[int32[2,],::,int32[2,],1,]") + check((slice(1,None,2),slice(-1,None,2),[1,2],-4), "[1::2,-1::2,int32[2,],-4,]") + check(0, "[0,]") + check(10, "[10,]") + check(-10, "[-10,]") + check(1, "[1,]") + check((1,slice(None),2), "[1,::,2,]") + check((-2,slice(None),2,slice(1,9,2)), "[-2,::,2,1:9:2,]") + check((None,1,None,2,None), "[-,1,-,2,-,]") + check((...,1,...,2,...), "[...,1,...,2,...,]") + + @unittest.skipIf(not jt.has_cuda, "No cuda") + @jt.flag_scope(use_cuda=1) + def test_getitem(self): + def check(shape, slices, i_to_vs="", i_to_o="", o_shape=""): + # print(slices) + x = jt.random(shape) + + with jt.log_capture_scope(log_vprefix="getitem=999") as logs: + a = x.getitem(slices) + a.sync() + b = x.data[slices] + bshape = b.shape if len(b.shape) else (1,) + assert a.shape == bshape, (a.shape, bshape) + s = logs[-1]['msg'] + assert "i_to_vs: "+i_to_vs in s + assert "i_to_o: "+i_to_o in s + assert "o_shape: "+o_shape in s + aa = a.numpy() + assert (aa==b).all(), (aa, b) + + y = x.numpy() + v = jt.random(a.shape) + z = x.setitem(slices, v) + y[slices] = v.data + assert (z.data==y).all(), (z.data, y, v.data, x.data) + + # test_setitem broadcast + adim = len(a.shape) + for mask in range(1<>i)&1: + new_shape[i] = 1 + y = x.numpy() + v = jt.random(new_shape) + z = x.setitem(slices, v) + y[slices] = v.data + assert (z.data==y).all(), (z.data, y, v.data, x.data) + + + # TODO: when slice same row/col many times and assign value, numpy will retain the last value but we assign their sum. eg: check([3,3,3,3], ([[0,1,1]],slice(None),[[1],[2],[0]],1)) + check([3], (1), "[0,]", "[-1,]", "[]") + check([3,3,3,3], ([[0],[1]],slice(None),[1,2],1), "[0,-1,2,3,]", "[-1,2,-1,-1,]", "[2,2,3,]") + check([3,3,3,3], (slice(None),slice(None),slice(None),slice(None)), "[-1,-2,-2,-2,]", "[0,0,0,0,]", "[81,]") + check([3,3,3,3], ([0,1],[0,1],[0,1],[0,1]), "[0,1,2,3,]", "[-1,-1,-1,-1,]", "[2,]") + check([3,3,3,3], ([0,1],-2,slice(None),[0,1]), "[0,1,-1,3,]", "[-1,-1,1,-1,]", "[2,3,]") + check([3,3,3,3], ([0,1],slice(1,2,2),[1,2],1), "[0,1,2,3,]", "[-1,1,-1,-1,]", "[2,1,]") + check([3,3,3,3], ([0,1],slice(None),[1,2],1), "[0,-1,2,3,]", "[-1,1,-1,-1,]", "[2,3,]") + check([3,3,3,3], (slice(1,10,1),...,slice(2,None,-1)), "[0,-1,-2,2,]", "[0,1,1,2,]", "[2,9,3,]") + check([10,10,10,10], (slice(1,None,2),slice(-1,None,2),[1,2],-4), "[0,1,2,3,]", "[0,1,-1,-1,]", "") + check([20], 0, "[0,]", "[-1,]", "[]") + check([20], 10, "[0,]", "[-1,]", "[]") + check([20], -10, "[0,]", "[-1,]", "[]") + check([10,10,10,10], 1, "[0,-1,-2,-2,]", "[-1,0,0,0,]", "[1000,]") + check([10,10,10,10], (1,slice(None),2), "[0,-1,2,-1,]", "[-1,0,-1,1,]", "") + check([10,10,10,10], (-2,slice(None),2,slice(1,9,2)), "[0,-1,2,3,]", "[-1,0,-1,1,]") + + def test_getitem_grad(self): + shape = (10,) + slices = slice(2,4) + + a = jt.random(shape) + b = a.getitem(slices) + mask = jt.random(b.shape) + loss = b*mask + da = jt.grad(loss, a) + + _, np_grad = ngrad(lambda vars: (vars[0][slices]*mask.data).sum(), [a.numpy()], 1e-3) + assert np.allclose(da.numpy(), np_grad, atol = 1e-3), (da.numpy(), np_grad) + + shape = (10,) + slices = slice(2,4) + + a = jt.random(shape) + b = a.getitem(slices) + b = jt.random(b.shape) + c = a.setitem(slices, b) + mask = jt.random(c.shape) + loss = c*mask + da,db = jt.grad(loss, [a,b]) + + def numpy_grad(vars): + a, b = vars + a = a.copy() + a[slices] = b + return (a*mask.data).sum() + + _, (nda, ndb) = ngrad(numpy_grad, [a.numpy(), b.numpy()], 1e-3) + assert np.allclose(da.numpy(), nda, atol = 1e-3) + assert np.allclose(db.numpy(), ndb, atol = 1e-3) + + def test_vary_shape_setitem(self): + a = jt.array([1,2,3,4,5]) + b = jt.array([1,2,3,4,5]) + c = tuple(jt.where(b>3)) + a[c] = 0 + assert (a.data == [1,2,3,0,0]).all() + + def test_numpy_scalar_slice(self): + a = jt.random((2,2)) + b = np.array([1])[0] + assert a[b].shape == [2] + + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_sparse.py b/python/jittor/test/test_sparse.py new file mode 100644 index 00000000..9f982989 --- /dev/null +++ b/python/jittor/test/test_sparse.py @@ -0,0 +1,40 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Xiangli Li <1905692338@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestSparse(unittest.TestCase): + def test_sparse_var(self): + indices = np.array([[0,1,1],[2,0,2]]) + values = np.array([3,4,5]).astype(np.float32) + shape = [2,3] + jt_array = jt.sparse.sparse_array(jt.array(indices),jt.array(values),jt.NanoVector(shape)) + torch_tensor = torch.sparse.FloatTensor(torch.from_numpy(indices),torch.from_numpy(values),torch.Size(shape)) + jt_numpy = jt_array.to_dense().numpy() + torch_numpy = torch_tensor.to_dense().numpy() + assert np.allclose(jt_numpy,torch_numpy) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_stop_fuse.py b/python/jittor/test/test_stop_fuse.py new file mode 100644 index 00000000..c0cc6920 --- /dev/null +++ b/python/jittor/test/test_stop_fuse.py @@ -0,0 +1,54 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error + +class TestStopFuse(unittest.TestCase): + def test_stop_fuse(self): + with jt.profile_scope() as report: + a = jt.float32(0).stop_fuse() + c = jt.float32(0) + bs = [c] + for i in range(2000): + b = jt.float32(i)*2*c + bs.append(b) + a += b + + a = a*2 + + dbs = jt.grad(a, bs) + jt.sync(dbs+[a]) + + for a in report[1:]: + # origin is 50 + # after update queue, increase to 102 + assert len(a[0].split("opkey")) < 110, len(a[0].split("opkey")) + + def test_stop_fuse2(self): + with jt.profile_scope() as report: + a = jt.float32(0).stop_fuse() + c = jt.float32(0).stop_fuse() + bs = [c] + for i in range(2000): + b = jt.float32(i)*2*c + bs.append(b) + a += b + + a = a*2 + + dbs = jt.grad(a, bs) + jt.sync(dbs+[a]) + + for a in report[1:]: + # origin is 8 + # after update queue, increase to 12 + assert len(a[0].split("opkey")) < 16, len(a[0].split("opkey")) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_superglue.py b/python/jittor/test/test_superglue.py new file mode 100644 index 00000000..abb631d5 --- /dev/null +++ b/python/jittor/test/test_superglue.py @@ -0,0 +1,121 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import os + +from jittor.test.misc import superglue +from jittor.test.misc.superglue import SuperGlue +import time + +@jt.flag_scope(use_cuda=1) +def main(): + global superglue + superglue.split_size = int(os.environ.get("split_size", "12")) + # superglue.split_size = 1000000 + + batch = 30 + num = 2000 + dim = 128 + + # jt.display_memory_info() + # os.system("nvidia-smi") + # breakpoint() + + with jt.no_grad(): + + config = { + 'superglue': { + 'sinkhorn_iterations': 25, + 'match_threshold': 0.01, + 'keypoint_position_dim': 2, + 'descriptor_dim': dim, + 'use_dual_softmax': True, + 'GNN_layers': ['self', 'cross'] * 9, + } + } + + superglue = SuperGlue(config.get('superglue', {})) + + superglue.eval() + + data = { + 'keypoints0': jt.rand((batch, num, 2), dtype=jt.float), + 'keypoints1': jt.rand((batch, num, 2), dtype=jt.float), + 'shape0': jt.rand((batch, 2), dtype=jt.float), + 'shape1': jt.rand((batch, 2), dtype=jt.float), + 'descriptors0': jt.rand((batch, dim, num), dtype=jt.float), + 'descriptors1': jt.rand((batch, dim, num), dtype=jt.float), + 'scores0': jt.rand((batch, num), dtype=jt.float), + 'scores1': jt.rand((batch, num), dtype=jt.float), + 'all_matches': jt.randint(0, num, (batch, num, 2), dtype=jt.int), + 'return_match': False, + # 'match_num': match_num + } + + use_fp16 = int(os.environ.get("use_fp16", "0")) + if use_fp16: + jt.flags.amp_reg = 2 + for k,v in data.items(): + if isinstance(v, jt.Var) and v.dtype == "float32": + v.assign(v.float16()) + for v in superglue.parameters(): + if v.dtype == "float32": + v.assign(v.float16()) + jt.sync_all(True) + + import pickle + jt.sync_all(True) + for x in range(5): + print(x) + jt.gc() + x = superglue(data)['loss'] + x.sync() + jt.display_memory_info() + # os.system("nvidia-smi") + # breakpoint() + # print(data) + # print(x) + + # with open("/tmp/record.pkl", "wb") as f: + # pickle.dump([data, x], f, pickle.HIGHEST_PROTOCOL) + + # with jt.flag_scope(trace_py_var=3, profile_memory_enable=1): + # x = superglue(data)['loss'] + # x.sync() + # jt.get_max_memory_treemap() + # exit(0) + + jt.sync_all(True) + time0 = time.time() + jt.flags.profiler_enable = int(os.environ.get("profiler", "0")) + + for x in range(20): + print(x) + # jt.display_memory_info() + x = superglue(data)['loss'] + x.sync() + # print(x) + + jt.sync_all(True) + time1 = time.time() + print("avg time:", (time1 - time0) / 20) + return (time1 - time0) / 20 + + +class TestSuperglue(unittest.TestCase): + def test(self): + if not jt.has_cuda: return + t1 = main() + os.environ["use_fp16"] = "1" + t2 = main() + os.environ["use_fp16"] = "0" + assert t1*0.55 > t2 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_ternary_op.py b/python/jittor/test/test_ternary_op.py new file mode 100644 index 00000000..61a08c38 --- /dev/null +++ b/python/jittor/test/test_ternary_op.py @@ -0,0 +1,55 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from .test_grad import ngrad +from .test_cuda import test_cuda + +class TestTernaryOp(unittest.TestCase): + def test_with_np(self): + np.random.seed(0) + a = np.random.rand(5,10).astype("float32") + b = np.random.rand(5,10).astype("float32") + ja = jt.array(a) + jb = jt.array(b) + jc = jt.ternary(ja>jb, ja, jb) + assert (jc.data==np.maximum(a,b)).all(), f"\n{jc.data}\n{np.maximum(a,b)}\n{a}\n{b}" + jda, jdb = jt.grad(jc, [ja, jb]) + assert (jda.data==(a>b)*1).all() + assert (jdb.data==1-(a>b)).all() + + def test_where(self): + np.random.seed(0) + a = np.random.rand(5,10).astype("float32") + b = np.random.rand(5,10).astype("float32") + ja = jt.array(a) + jb = jt.array(b) + jc = jt.where(ja>jb, ja, jb) + assert (jc.data==np.maximum(a,b)).all(), f"\n{jc.data}\n{np.maximum(a,b)}\n{a}\n{b}" + jda, jdb = jt.grad(jc, [ja, jb]) + assert (jda.data==(a>b)*1).all() + assert (jdb.data==1-(a>b)).all() + + def test_min(self): + np.random.seed(1) + a = np.random.rand(5,10).astype("float32") + b = np.random.rand(5,10).astype("float32") + ja = jt.array(a) + jb = jt.array(b) + jc = jt.minimum(ja,jb) + assert (jc.data==np.minimum(a,b)).all(), f"\n{jc.data}\n{np.minimum(a,b)}\n{a}\n{b}" + jda, jdb = jt.grad(jc, [ja, jb]) + assert (jda.data==(a. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from jittor import Module +from jittor.models import resnet +import pickle +from PIL import Image +import platform + +f32 = jt.float32 + +def matmul(a, b): + (n, m), k = a.shape, b.shape[-1] + a = a.broadcast([n,m,k], dims=[2]) + b = b.broadcast([n,m,k], dims=[0]) + return (a*b).sum(dim=1) + + +def relu(x): + return jt.maximum(x, 0.0) +Relu = jt.make_module(relu) + +class Model(Module): + def __init__(self, input_size): + self.linear1 = Linear(input_size, 10) + self.relu1 = Relu() + self.linear2 = Linear(10, 1) + def execute(self, x): + x = self.linear1(x) + x = self.relu1(x) + return self.linear2(x) + +def print_stack_tree(data): + tree = {} + for n in data["node_data"].values(): + p = tree + for s in n["stacks"]: + name = s['name'] + if name not in p: + p[name] = {} + p = p[name] + from pprint import pprint + pprint(tree) + +class Linear(Module): + def __init__(self, in_features, out_features, bias=True): + self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5 + self.b = jt.random((out_features,))-0.5 if bias else None + def execute(self, x): + x = matmul(x, self.w) + if self.b is not None: + return x+self.b + return x + + +class TestTraceVar(unittest.TestCase): + def test_simple_model(self): + with jt.flag_scope(trace_py_var=2): + + model = Model(input_size=1) + batch_size = 10 + x = jt.float32(np.random.rand(batch_size, 1)) + y = model(x) + y.sync() + + + data = jt.dump_trace_data() + jt.clear_trace_data() + with open(f"{jt.flags.cache_path}/simple_model.pkl", "wb") as f: + pickle.dump(data, f) + + def test_simple_model_train(self): + with jt.flag_scope(trace_py_var=2): + + model = Model(input_size=1) + opt = jt.optim.SGD(model.parameters(), 0.1) + + batch_size = 10 + x = jt.float32(np.random.rand(batch_size, 1)) + y = model(x) + opt.step(y**2) + jt.sync_all() + + data = jt.dump_trace_data() + jt.clear_trace_data() + # print_stack_tree(data) + for k,v in data["execute_op_info"].items(): + for i in v['fused_ops']: + if i not in data["node_data"]: + assert 0, (i, "not found") + + for k,v in list(data["node_data"].items()): + if v["attrs"]["name"] == "unname": + assert 0 + print(len(data["node_data"])) + with open(f"{jt.flags.cache_path}/simple_model_train.pkl", "wb") as f: + pickle.dump(data, f) + + def test_resnet_infer(self): + with jt.flag_scope(trace_py_var=2): + + resnet18 = resnet.Resnet18() + x = jt.float32(np.random.rand(2, 3, 224, 224)) + y = resnet18(x) + y.sync() + + data = jt.dump_trace_data() + jt.clear_trace_data() + with open(f"{jt.flags.cache_path}/resnet.pkl", "wb") as f: + pickle.dump(data, f) + for k,v in data["execute_op_info"].items(): + for i in v['fused_ops']: + if i not in data["node_data"]: + assert 0, (i, "not found") + + def test_resnet_infer_with_feature(self): + cat_url = "https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=3782485413,1118109468&fm=26&gp=0.jpg" + import jittor_utils + cat_path = f"{jt.flags.cache_path}/cat.jpg" + print("download") + jittor_utils.download(cat_url, cat_path) + with open(cat_path, 'rb') as f: + img = Image.open(f).convert('RGB') + img = jt.array(np.array(img)) + print(img.shape, img.dtype) + img = ((img.float() - 128) / 255).transpose(2,0,1) + + + with jt.flag_scope(trace_py_var=2, trace_var_data=1): + img = img[None,...] + + resnet18 = resnet.Resnet18(pretrained=True) + x = jt.float32(img) + y = resnet18(x) + y.sync() + + data = jt.dump_trace_data() + jt.clear_trace_data() + with open(f"{jt.flags.cache_path}/resnet_with_feature.pkl", "wb") as f: + pickle.dump(data, f) + for k,v in data["execute_op_info"].items(): + for i in v['fused_ops']: + if i not in data["node_data"]: + assert 0, (i, "not found") + + def test_resnet_trainx(self): + with jt.flag_scope(trace_py_var=2): + + resnet18 = resnet.Resnet18() + opt = jt.optim.SGD(resnet18.parameters(), 0.1) + x = jt.float32(np.random.rand(2, 3, 224, 224)) + y = resnet18(x) + + opt.step(y**2) + jt.sync_all() + + data = jt.dump_trace_data() + jt.clear_trace_data() + with open(f"{jt.flags.cache_path}/resnet_train.pkl", "wb") as f: + pickle.dump(data, f) + for k,v in data["execute_op_info"].items(): + for i in v['fused_ops']: + if i not in data["node_data"]: + assert 0, (i, "not found") + for k,v in data["node_data"].items(): + if 'name' not in v["attrs"]: + print(v) + # assert 'name' in v["attrs"], v + # for s in v["stacks"]: + # if "_opt" in s["name"] or "_model" in s["name"]: + # assert 0, v + + def test_resnet_train_profile(self): + with jt.profile_scope(trace_py_var=1): + + resnet18 = resnet.Resnet18() + opt = jt.optim.SGD(resnet18.parameters(), 0.1) + x = jt.float32(np.random.rand(2, 3, 224, 224)) + y = resnet18(x) + + opt.step(y**2) + jt.sync_all() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_tracer.py b/python/jittor/test/test_tracer.py new file mode 100644 index 00000000..447ecbfe --- /dev/null +++ b/python/jittor/test/test_tracer.py @@ -0,0 +1,47 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import os +import subprocess as sp +import sys + +class TestTracer(unittest.TestCase): + def test_print_trace(self): + jt.print_trace() + + if os.name != 'nt': + # force use addr2line + with jt.flag_scope(gdb_path=""): + jt.print_trace() + + def test_breakpoint(self): + fname = os.path.join(jt.flags.cache_path, "test_breakpoint.py") + with open(fname, 'w') as f: + f.write(""" +import jittor as jt +with jt.flag_scope(extra_gdb_cmd="c;q"): + jt.flags.gdb_attach = 1 +""") + out = sp.getoutput(sys.executable+' '+fname) + print(out) + assert "Attaching to" in out + + def test_segfault(self): + if os.name == 'nt': + a = jt.array([1,2,3]) + b = jt.array([1,2,300000000]) + c = a[b] + try: + c.sync() + except Exception as e: + assert "access violation reading" in str(e) + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_transform.py b/python/jittor/test/test_transform.py new file mode 100644 index 00000000..53a81bce --- /dev/null +++ b/python/jittor/test/test_transform.py @@ -0,0 +1,976 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. +# All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# Contributors: +# Xin Yao +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import unittest +import random +from PIL import Image +import numpy as np +from numpy.testing import assert_array_almost_equal +import jittor as jt +import jittor.transform as transform + +try: + from scipy import stats +except ImportError: + stats = None + + +class Tester(unittest.TestCase): + + def test_crop(self): + height = random.randint(10, 32) * 2 + width = random.randint(10, 32) * 2 + oheight = random.randint(5, (height - 2) / 2) * 2 + owidth = random.randint(5, (width - 2) / 2) * 2 + + img = np.ones([height, width, 3]) + oh1 = (height - oheight) // 2 + ow1 = (width - owidth) // 2 + # imgnarrow = img[oh1:oh1 + oheight, ow1:ow1 + owidth, :] + # imgnarrow.fill(0) + img[oh1:oh1 + oheight, ow1:ow1 + owidth, :] = 0 + # img = jt.array(img) + result = transform.Compose([ + transform.ToPILImage(), + transform.CenterCrop((oheight, owidth)), + transform.ToTensor(), + ])(img) + self.assertEqual(result.sum(), 0, + f"height: {height} width: {width} oheight: {oheight} owdith: {owidth}") + oheight += 1 + owidth += 1 + result = transform.Compose([ + transform.ToPILImage(), + transform.CenterCrop((oheight, owidth)), + transform.ToTensor(), + ])(img) + sum1 = result.sum() + # TODO: not pass + # self.assertGreater(sum1, 1, + # f"height: {height} width: {width} oheight: {oheight} owdith: {owidth}") + oheight += 1 + owidth += 1 + result = transform.Compose([ + transform.ToPILImage(), + transform.CenterCrop((oheight, owidth)), + transform.ToTensor(), + ])(img) + sum2 = result.sum() + self.assertGreater(sum2, 0, + f"height: {height} width: {width} oheight: {oheight} owdith: {owidth}") + self.assertGreaterEqual(sum2, sum1, + f"height: {height} width: {width} oheight: {oheight} owdith: {owidth}") + + def test_resize(self): + height = random.randint(24, 32) * 2 + width = random.randint(24, 32) * 2 + osize = random.randint(5, 12) * 2 + + img = jt.ones([height, width, 3]) + result = transform.Compose([ + transform.ToPILImage(), + transform.Resize(osize), + transform.ToTensor(), + ])(img) + self.assertIn(osize, result.shape) + if height < width: + self.assertLessEqual(result.shape[1], result.shape[2]) + elif width < height: + self.assertGreaterEqual(result.shape[1], result.shape[2]) + + result = transform.Compose([ + transform.ToPILImage(), + transform.Resize([osize, osize]), + transform.ToTensor(), + ])(img) + self.assertIn(osize, result.shape) + self.assertEqual(result.shape[1], osize) + self.assertEqual(result.shape[2], osize) + + oheight = random.randint(5, 12) * 2 + owidth = random.randint(5, 12) * 2 + result = transform.Compose([ + transform.ToPILImage(), + transform.Resize((oheight, owidth)), + transform.ToTensor(), + ])(img) + self.assertEqual(result.shape[1], oheight) + self.assertEqual(result.shape[2], owidth) + + result = transform.Compose([ + transform.ToPILImage(), + transform.Resize([oheight, owidth]), + transform.ToTensor(), + ])(img) + self.assertEqual(result.shape[1], oheight) + self.assertEqual(result.shape[2], owidth) + + def test_random_crop(self): + height = random.randint(10, 32) * 2 + width = random.randint(10, 32) * 2 + oheight = random.randint(5, (height - 2) / 2) * 2 + owidth = random.randint(5, (width - 2) / 2) * 2 + img = np.ones((height, width, 3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomCrop((oheight, owidth)), + transform.ToTensor(), + ])(img) + self.assertEqual(result.shape[1], oheight) + self.assertEqual(result.shape[2], owidth) + + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomCrop((oheight, owidth)), + transform.ToTensor(), + ])(img) + self.assertEqual(result.shape[1], oheight) + self.assertEqual(result.shape[2], owidth) + + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomCrop((height, width)), + transform.ToTensor() + ])(img) + self.assertEqual(result.shape[1], height) + self.assertEqual(result.shape[2], width) + self.assertTrue(np.allclose(img, result.transpose(1,2,0))) + + with self.assertRaises(AssertionError): + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomCrop((height + 1, width + 1)), + transform.ToTensor(), + ])(img) + + def test_lambda(self): + trans = transform.Lambda(lambda x: x.add(10)) + x = jt.random([10]) + y = trans(x) + self.assertTrue(np.allclose(y.data, jt.add(x, 10).data)) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_apply(self): + random_state = random.getstate() + random.seed(42) + random_apply_transform = transform.RandomApply( + [ + transform.RandomHorizontalFlip(), + transform.RandomVerticalFlip(), + ], p=0.4 + ) + img = transform.ToPILImage()(jt.random((3, 10, 10))) + num_samples = 250 + num_applies = 0 + for _ in range(num_samples): + out = random_apply_transform(img) + if out != img: + num_applies += 1 + + p_value = stats.binom_test(num_applies, num_samples, p=0.3) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_choice(self): + random_state = random.getstate() + random.seed(42) + random_choice_transform = transform.RandomChoice( + [ + transform.Resize(15), + transform.Resize(20), + transform.CenterCrop(10) + ] + ) + img = transform.ToPILImage()(jt.random((25, 25, 3))) + num_samples = 250 + num_resize_15 = 0 + num_resize_20 = 0 + num_crop_10 = 0 + for _ in range(num_samples): + out = random_choice_transform(img) + if out.size == (15, 15): + num_resize_15 += 1 + elif out.size == (20, 20): + num_resize_20 += 1 + elif out.size == (10, 10): + num_crop_10 += 1 + + p_value = stats.binom_test(num_resize_15, num_samples, p=0.33333) + self.assertGreater(p_value, 0.0001) + p_value = stats.binom_test(num_resize_20, num_samples, p=0.33333) + self.assertGreater(p_value, 0.0001) + p_value = stats.binom_test(num_crop_10, num_samples, p=0.33333) + self.assertGreater(p_value, 0.0001) + + random.setstate(random_state) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_order(self): + random_state = random.getstate() + random.seed(42) + random_order_transform = transform.RandomOrder( + [ + transform.Resize(20), + transform.CenterCrop(10) + ] + ) + img = transform.ToPILImage()(jt.random((3, 25, 25))) + num_samples = 250 + num_normal_order = 0 + resize_crop_out = transform.CenterCrop(10)(transform.Resize(20)(img)) + for _ in range(num_samples): + out = random_order_transform(img) + if out == resize_crop_out: + num_normal_order += 1 + + p_value = stats.binom_test(num_normal_order, num_samples, p=0.5) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + def test_to_tensor(self): + test_channels = [1, 3, 4] + height, width = 4, 4 + trans = transform.ToTensor() + + with self.assertRaises(TypeError): + trans(np.random.rand(1, height, width).tolist()) + + with self.assertRaises(ValueError): + trans(np.random.rand(height)) + trans(np.random.rand(1, 1, height, width)) + + for channels in test_channels: + input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.float32) / np.float32(255.0) + img = transform.ToPILImage()(input_data) + output = trans(img) + expect = input_data.transpose(2,0,1) + self.assertTrue(np.allclose(expect, output), f"{expect.shape}\n{output.shape}") + + ndarray = np.random.randint(low=0, high=255, size=(channels, height, width)).astype(np.uint8) + output = trans(ndarray) + expected_output = ndarray / 255.0 + np.testing.assert_allclose(output, expected_output) + + ndarray = np.random.rand(channels, height, width).astype(np.float32) + output = trans(ndarray) + expected_output = ndarray + self.assertTrue(np.allclose(output, expected_output)) + + # separate test for mode '1' PIL images + input_data = np.random.binomial(1, 0.5, size=(height, width, 1)).astype(np.uint8) + img = transform.ToPILImage()(input_data * 255).convert('1') + output = trans(img) + self.assertTrue(np.allclose(input_data[:,:,0], output[0]), f"{input_data.shape}\n{output.shape}") + + def test_1_channel_tensor_to_pil_image(self): + to_tensor = transform.ToTensor() + shape = (4, 4, 1) + + img_data_float = jt.array(np.random.rand(*shape), dtype='float32') + img_data_byte = jt.array(np.random.randint(0, 255, shape), dtype='uint8') + img_data_short = jt.array(np.random.randint(0, 32767, shape), dtype='int16') + img_data_int = jt.array(np.random.randint(0, 2147483647, shape), dtype='int32') + + inputs = [img_data_float, img_data_byte, img_data_short, img_data_int] + expected_outputs = [img_data_float.multiply(255).int().float().divide(255).numpy(), + img_data_byte.float().divide(255.0).numpy(), + img_data_short.numpy(), + img_data_int.numpy()] + expected_modes = ['F', 'L', 'I;16', 'I'] + + for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes): + for t in [transform.ToPILImage(), transform.ToPILImage(mode=mode)]: + img = t(img_data) + self.assertEqual(img.mode, mode) + np.testing.assert_allclose(expected_output[:,:,0], to_tensor(img)[0], atol=0.01) + # 'F' mode for torch.FloatTensor + img_F_mode = transform.ToPILImage(mode='F')(img_data_float) + self.assertEqual(img_F_mode.mode, 'F') + + def test_1_channel_ndarray_to_pil_image(self): + img_data_float = np.random.rand(4, 4, 1).astype(np.float32) + img_data_byte = np.random.randint(0, 255, (4, 4, 1)).astype(np.uint8) + img_data_short = np.random.randint(0, 32767, (4, 4, 1)).astype(np.int16) + img_data_int = np.random.randint(0, 2147483647, (4, 4, 1)).astype(np.int32) + + inputs = [img_data_float, img_data_byte, img_data_short, img_data_int] + expected_modes = ['F', 'L', 'I;16', 'I'] + for img_data, mode in zip(inputs, expected_modes): + for t in [transform.ToPILImage(), transform.ToPILImage(mode=mode)]: + img = t(img_data) + self.assertEqual(img.mode, mode) + self.assertTrue(np.allclose(img_data[:, :, 0], img)) + + def test_2_channel_ndarray_to_pil_image(self): + def verify_img_data(img_data, mode): + if mode is None: + img = transform.ToPILImage()(img_data) + self.assertEqual(img.mode, 'LA') # default should assume LA + else: + img = transform.ToPILImage(mode=mode)(img_data) + self.assertEqual(img.mode, mode) + split = img.split() + for i in range(2): + self.assertTrue(np.allclose(img_data[:, :, i], split[i])) + + img_data = np.random.randint(0, 255, (4, 4, 2)).astype(np.uint8) + for mode in [None, 'LA']: + verify_img_data(img_data, mode) + + with self.assertRaises(ValueError): + # should raise if we try a mode for 4 or 1 or 3 channel images + transform.ToPILImage(mode='RGBA')(img_data) + transform.ToPILImage(mode='P')(img_data) + transform.ToPILImage(mode='RGB')(img_data) + + def test_2_channel_tensor_to_pil_image(self): + def verify_img_data(img_data, expected_output, mode): + if mode is None: + img = transform.ToPILImage()(img_data) + self.assertEqual(img.mode, 'LA') # default should assume LA + else: + img = transform.ToPILImage(mode=mode)(img_data) + self.assertEqual(img.mode, mode) + split = img.split() + for i in range(2): + self.assertTrue(np.allclose(expected_output[:,:,i], transform.to_tensor(split[i]))) + + img_data = jt.random((4, 4, 2)) + expected_output = img_data.multiply(255).int().float().divide(255) + for mode in [None, 'LA']: + verify_img_data(img_data, expected_output, mode=mode) + + with self.assertRaises(ValueError): + # should raise if we try a mode for 4 or 1 or 3 channel images + transform.ToPILImage(mode='RGBA')(img_data) + transform.ToPILImage(mode='P')(img_data) + transform.ToPILImage(mode='RGB')(img_data) + + def test_3_channel_tensor_to_pil_image(self): + def verify_img_data(img_data, expected_output, mode): + if mode is None: + img = transform.ToPILImage()(img_data) + self.assertEqual(img.mode, 'RGB') # default should assume RGB + else: + img = transform.ToPILImage(mode=mode)(img_data) + self.assertEqual(img.mode, mode) + split = img.split() + for i in range(3): + self.assertTrue(np.allclose(expected_output[:,:,i], transform.to_tensor(split[i]))) + + img_data = jt.random((4, 4, 3)) + expected_output = img_data.multiply(255).int().float().divide(255) + for mode in [None, 'RGB', 'HSV', 'YCbCr']: + verify_img_data(img_data, expected_output, mode=mode) + + with self.assertRaises(ValueError): + # should raise if we try a mode for 4 or 1 or 2 channel images + transform.ToPILImage(mode='RGBA')(img_data) + transform.ToPILImage(mode='P')(img_data) + transform.ToPILImage(mode='LA')(img_data) + + with self.assertRaises(ValueError): + transform.ToPILImage()(jt.random((1, 3, 4, 4))) + + def test_3_channel_ndarray_to_pil_image(self): + def verify_img_data(img_data, mode): + if mode is None: + img = transform.ToPILImage()(img_data) + self.assertEqual(img.mode, 'RGB') # default should assume RGB + else: + img = transform.ToPILImage(mode=mode)(img_data) + self.assertEqual(img.mode, mode) + split = img.split() + for i in range(3): + self.assertTrue(np.allclose(img_data[:, :, i], split[i])) + + img_data = np.random.randint(0, 255, (4, 4, 3)).astype(np.uint8) + for mode in [None, 'RGB', 'HSV', 'YCbCr']: + verify_img_data(img_data, mode) + + with self.assertRaises(ValueError): + # should raise if we try a mode for 4 or 1 or 2 channel images + transform.ToPILImage(mode='RGBA')(img_data) + transform.ToPILImage(mode='P')(img_data) + transform.ToPILImage(mode='LA')(img_data) + + def test_4_channel_tensor_to_pil_image(self): + def verify_img_data(img_data, expected_output, mode): + if mode is None: + img = transform.ToPILImage()(img_data) + self.assertEqual(img.mode, 'RGBA') # default should assume RGBA + else: + img = transform.ToPILImage(mode=mode)(img_data) + self.assertEqual(img.mode, mode) + + split = img.split() + for i in range(4): + self.assertTrue(np.allclose(expected_output[:,:,i], transform.to_tensor(split[i])[0])) + + img_data = jt.random((4, 4, 4)) + expected_output = img_data.multiply(255).int().float().divide(255) + for mode in [None, 'RGBA', 'CMYK', 'RGBX']: + verify_img_data(img_data, expected_output, mode) + + with self.assertRaises(ValueError): + # should raise if we try a mode for 3 or 1 or 2 channel images + transform.ToPILImage(mode='RGB')(img_data) + transform.ToPILImage(mode='P')(img_data) + transform.ToPILImage(mode='LA')(img_data) + + def test_4_channel_ndarray_to_pil_image(self): + def verify_img_data(img_data, mode): + if mode is None: + img = transform.ToPILImage()(img_data) + self.assertEqual(img.mode, 'RGBA') # default should assume RGBA + else: + img = transform.ToPILImage(mode=mode)(img_data) + self.assertEqual(img.mode, mode) + split = img.split() + for i in range(4): + self.assertTrue(np.allclose(img_data[:, :, i], split[i])) + + img_data = np.random.randint(0, 255, (4, 4, 4)).astype(np.uint8) + for mode in [None, 'RGBA', 'CMYK', 'RGBX']: + verify_img_data(img_data, mode) + + with self.assertRaises(ValueError): + # should raise if we try a mode for 3 or 1 or 2 channel images + transform.ToPILImage(mode='RGB')(img_data) + transform.ToPILImage(mode='P')(img_data) + transform.ToPILImage(mode='LA')(img_data) + + def test_2d_tensor_to_pil_image(self): + to_tensor = transform.ToTensor() + + img_data_float = jt.array(np.random.rand(4, 4), dtype='float32') + img_data_byte = jt.array(np.random.randint(0, 255, (4, 4)), dtype='uint8') + img_data_short = jt.array(np.random.randint(0, 32767, (4, 4)), dtype='int16') + img_data_int = jt.array(np.random.randint(0, 2147483647, (4, 4)), dtype='int32') + + inputs = [img_data_float, img_data_byte, img_data_short, img_data_int] + expected_outputs = [img_data_float.multiply(255).int().float().divide(255).numpy(), + img_data_byte.float().divide(255.0).numpy(), + img_data_short.numpy(), + img_data_int.numpy()] + expected_modes = ['F', 'L', 'I;16', 'I'] + + for img_data, expected_output, mode in zip(inputs, expected_outputs, expected_modes): + for t in [transform.ToPILImage(), transform.ToPILImage(mode=mode)]: + img = t(img_data) + self.assertEqual(img.mode, mode) + self.assertTrue(np.allclose(expected_output, to_tensor(img), atol=0.01, rtol=0.01)) + + def test_2d_ndarray_to_pil_image(self): + img_data_float = np.random.rand(4, 4).astype(np.float32) + img_data_byte = np.random.randint(0, 255, (4, 4)).astype(np.uint8) + img_data_short = np.random.randint(0, 32767, (4, 4)).astype(np.int16) + img_data_int = np.random.randint(0, 2147483647, (4, 4)).astype(np.int32) + + inputs = [img_data_float, img_data_byte, img_data_short, img_data_int] + expected_modes = ['F', 'L', 'I;16', 'I'] + for img_data, mode in zip(inputs, expected_modes): + for t in [transform.ToPILImage(), transform.ToPILImage(mode=mode)]: + img = t(img_data) + self.assertEqual(img.mode, mode) + self.assertTrue(np.allclose(img_data, img)) + + def test_tensor_bad_types_to_pil_image(self): + with self.assertRaises(ValueError): + transform.ToPILImage()(jt.ones((1, 3, 4, 4))) + + def test_ndarray_bad_types_to_pil_image(self): + trans = transform.ToPILImage() + with self.assertRaises(TypeError): + trans(np.ones([4, 4, 1], np.int64)) + trans(np.ones([4, 4, 1], np.uint16)) + trans(np.ones([4, 4, 1], np.uint32)) + trans(np.ones([4, 4, 1], np.float64)) + + with self.assertRaises(ValueError): + transform.ToPILImage()(np.ones([1, 4, 4, 3])) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_vertical_flip(self): + random_state = random.getstate() + random.seed(42) + img = transform.ToPILImage()(jt.random((3, 10, 10))) + vimg = img.transpose(Image.FLIP_TOP_BOTTOM) + + num_samples = 250 + num_vertical = 0 + for _ in range(num_samples): + out = transform.RandomVerticalFlip()(img) + if out == vimg: + num_vertical += 1 + + p_value = stats.binom_test(num_vertical, num_samples, p=0.5) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + num_samples = 250 + num_vertical = 0 + for _ in range(num_samples): + out = transform.RandomVerticalFlip(p=0.7)(img) + if out == vimg: + num_vertical += 1 + + p_value = stats.binom_test(num_vertical, num_samples, p=0.7) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_horizontal_flip(self): + random_state = random.getstate() + random.seed(42) + img = transform.ToPILImage()(jt.random((3, 10, 10))) + himg = img.transpose(Image.FLIP_LEFT_RIGHT) + + num_samples = 250 + num_horizontal = 0 + for _ in range(num_samples): + out = transform.RandomHorizontalFlip()(img) + if out == himg: + num_horizontal += 1 + + p_value = stats.binom_test(num_horizontal, num_samples, p=0.5) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + num_samples = 250 + num_horizontal = 0 + for _ in range(num_samples): + out = transform.RandomHorizontalFlip(p=0.7)(img) + if out == himg: + num_horizontal += 1 + + p_value = stats.binom_test(num_horizontal, num_samples, p=0.7) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + @unittest.skipIf(stats is None, 'scipy.stats is not available') + def test_normalize(self): + def samples_from_standard_normal(tensor): + p_value = stats.kstest(list(tensor.reshape(-1).data), 'norm', args=(0, 1)).pvalue + return p_value > 0.0001 + + random_state = random.getstate() + random.seed(42) + for channels in [1, 3]: + img = jt.random((channels, 10, 10)) + mean = [img[c].mean().item() for c in range(channels)] + std = [img[c].std().item() for c in range(channels)] + normalized = transform.ImageNormalize(mean, std)(img) + self.assertTrue(samples_from_standard_normal(normalized)) + random.setstate(random_state) + + def test_normalize_different_dtype(self): + for dtype1 in ['float32', 'float64']: + img = jt.random((3, 10, 10), dtype=dtype1) + for dtype2 in ['int64', 'float32', 'float64']: + mean = jt.array([1, 2, 3], dtype=dtype2) + std = jt.array([1, 2, 1], dtype=dtype2) + # checks that it doesn't crash + transform.image_normalize(img, mean, std) + + def test_normalize_3d_tensor(self): + jt.seed(28) + n_channels = 3 + img_size = 10 + mean = jt.random((n_channels,)).data + std = jt.random((n_channels,)).data + img = jt.random((n_channels, img_size, img_size)).data + target = transform.image_normalize(img, mean, std) + + mean_unsqueezed = mean.reshape(-1, 1, 1) + std_unsqueezed = std.reshape(-1, 1, 1) + result1 = transform.image_normalize(img, mean_unsqueezed, std_unsqueezed) + result2 = transform.image_normalize(img, + mean_unsqueezed, + std_unsqueezed) + assert_array_almost_equal(target, result1) + assert_array_almost_equal(target, result2) + + def test_adjust_brightness(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transform.adjust_brightness(x_pil, 1) + y_np = np.array(y_pil) + self.assertTrue(np.allclose(y_np, x_np)) + + # test 1 + y_pil = transform.adjust_brightness(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = transform.adjust_brightness(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + def test_adjust_contrast(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transform.adjust_contrast(x_pil, 1) + y_np = np.array(y_pil) + self.assertTrue(np.allclose(y_np, x_np)) + + # test 1 + y_pil = transform.adjust_contrast(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = transform.adjust_contrast(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # @unittest.skipIf(Image.__version__ >= '7', "Temporarily disabled") + def test_adjust_saturation(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transform.adjust_saturation(x_pil, 1) + y_np = np.array(y_pil) + self.assertTrue(np.allclose(y_np, x_np)) + + # test 1 + y_pil = transform.adjust_saturation(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 216, 89] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = transform.adjust_saturation(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 3, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + def test_adjust_hue(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + with self.assertRaises(ValueError): + transform.adjust_hue(x_pil, -0.7) + transform.adjust_hue(x_pil, 1) + + # test 0: almost same as x_data but not exact. + # probably because hsv <-> rgb floating point ops + y_pil = transform.adjust_hue(x_pil, 0) + y_np = np.array(y_pil) + y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 1 + y_pil = transform.adjust_hue(x_pil, 0.25) + y_np = np.array(y_pil) + y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = transform.adjust_hue(x_pil, -0.25) + y_np = np.array(y_pil) + y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + def test_adjust_gamma(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + + # test 0 + y_pil = transform.adjust_gamma(x_pil, 1) + y_np = np.array(y_pil) + self.assertTrue(np.allclose(y_np, x_np)) + + # test 1 + y_pil = transform.adjust_gamma(x_pil, 0.5) + y_np = np.array(y_pil) + y_ans = [0, 35, 57, 117, 186, 241, 97, 45, 245, 152, 255, 16] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + # test 2 + y_pil = transform.adjust_gamma(x_pil, 2) + y_np = np.array(y_pil) + y_ans = [0, 0, 0, 11, 71, 201, 5, 0, 215, 31, 255, 0] + y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) + self.assertTrue(np.allclose(y_np, y_ans)) + + def test_adjusts_L_mode(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_rgb = Image.fromarray(x_np, mode='RGB') + + x_l = x_rgb.convert('L') + self.assertEqual(transform.adjust_brightness(x_l, 2).mode, 'L') + self.assertEqual(transform.adjust_saturation(x_l, 2).mode, 'L') + self.assertEqual(transform.adjust_contrast(x_l, 2).mode, 'L') + self.assertEqual(transform.adjust_hue(x_l, 0.4).mode, 'L') + self.assertEqual(transform.adjust_gamma(x_l, 0.5).mode, 'L') + + def test_color_jitter(self): + color_jitter = transform.ColorJitter(2, 2, 2, 0.1) + + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_pil_2 = x_pil.convert('L') + + for i in range(10): + y_pil = color_jitter(x_pil) + self.assertEqual(y_pil.mode, x_pil.mode) + + y_pil_2 = color_jitter(x_pil_2) + self.assertEqual(y_pil_2.mode, x_pil_2.mode) + + def test_gray(self): + """Unit tests for grayscale transform""" + + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_pil_2 = x_pil.convert('L') + gray_np = np.array(x_pil_2) + + # Test Set: Gray an image with desired number of output channels + # Case 1: RGB -> 1 channel grayscale + trans1 = transform.Gray(num_output_channels=1) + gray_pil_1 = trans1(x_pil) + gray_np_1 = np.array(gray_pil_1) + # self.assertEqual(gray_pil_1.mode, 'L', 'mode should be L') + self.assertEqual(gray_np_1.shape[1:], tuple(x_shape[0:2]), 'should be 1 channel') + assert np.allclose(gray_np/255, gray_np_1[0], atol=0.01) + + # Case 2: RGB -> 3 channel grayscale + trans2 = transform.Gray(num_output_channels=3) + gray_pil_2 = trans2(x_pil) + gray_np_2 = np.array(gray_pil_2) + # self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB') + self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel') + np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) + np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) + assert np.allclose(gray_np/255, gray_np_2[:, :, 0], atol=0.01) + + # Case 3: 1 channel grayscale -> 1 channel grayscale + trans3 = transform.Gray(num_output_channels=1) + gray_pil_3 = trans3(x_pil_2) + gray_np_3 = np.array(gray_pil_3) + # self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L') + self.assertEqual(gray_np_3.shape[1:], tuple(x_shape[0:2]), 'should be 1 channel') + np.testing.assert_allclose(gray_np/255, gray_np_3[0], atol=0.01) + + # Case 4: 1 channel grayscale -> 3 channel grayscale + trans4 = transform.Gray(num_output_channels=3) + gray_pil_4 = trans4(x_pil_2) + gray_np_4 = np.array(gray_pil_4) + # self.assertEqual(gray_pil_4.mode, 'RGB', 'mode should be RGB') + self.assertEqual(gray_np_4.shape, tuple(x_shape), 'should be 3 channel') + np.testing.assert_equal(gray_np_4[:, :, 0], gray_np_4[:, :, 1]) + np.testing.assert_equal(gray_np_4[:, :, 1], gray_np_4[:, :, 2]) + np.testing.assert_allclose(gray_np/255, gray_np_4[:, :, 0], atol=0.01) + + @unittest.skipIf(stats is None, 'scipy.stats not available') + def test_random_gray(self): + """Unit tests for random grayscale transform""" + + # Test Set 1: RGB -> 3 channel grayscale + random_state = random.getstate() + random.seed(42) + x_shape = [2, 2, 3] + x_np = np.random.randint(0, 256, x_shape, np.uint8) + x_pil = Image.fromarray(x_np, mode='RGB') + x_pil_2 = x_pil.convert('L') + gray_np = np.array(x_pil_2) + + num_samples = 250 + num_gray = 0 + for _ in range(num_samples): + gray_pil_2 = transform.RandomGray(p=0.5)(x_pil) + gray_np_2 = np.array(gray_pil_2) + if np.array_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) and \ + np.array_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) and \ + np.array_equal(gray_np, gray_np_2[:, :, 0]): + num_gray = num_gray + 1 + + p_value = stats.binom_test(num_gray, num_samples, p=0.5) + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + # Test Set 2: grayscale -> 1 channel grayscale + random_state = random.getstate() + random.seed(42) + x_shape = [2, 2, 3] + x_np = np.random.randint(0, 256, x_shape, np.uint8) + x_pil = Image.fromarray(x_np, mode='RGB') + x_pil_2 = x_pil.convert('L') + gray_np = np.array(x_pil_2) + + num_samples = 250 + num_gray = 0 + for _ in range(num_samples): + gray_pil_3 = transform.RandomGray(p=0.5)(x_pil_2) + gray_np_3 = np.array(gray_pil_3) + if np.array_equal(gray_np, gray_np_3): + num_gray = num_gray + 1 + + p_value = stats.binom_test(num_gray, num_samples, p=1.0) # Note: grayscale is always unchanged + random.setstate(random_state) + self.assertGreater(p_value, 0.0001) + + # Test set 3: Explicit tests + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + x_pil = Image.fromarray(x_np, mode='RGB') + x_pil_2 = x_pil.convert('L') + gray_np = np.array(x_pil_2) + + # Case 3a: RGB -> 3 channel grayscale (grayscaled) + trans2 = transform.RandomGray(p=1.0) + gray_pil_2 = trans2(x_pil) + gray_np_2 = np.array(gray_pil_2) + self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB') + self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel') + np.testing.assert_equal(gray_np_2[:, :, 0], gray_np_2[:, :, 1]) + np.testing.assert_equal(gray_np_2[:, :, 1], gray_np_2[:, :, 2]) + np.testing.assert_equal(gray_np, gray_np_2[:, :, 0]) + + # Case 3b: RGB -> 3 channel grayscale (unchanged) + trans2 = transform.RandomGray(p=0.0) + gray_pil_2 = trans2(x_pil) + gray_np_2 = np.array(gray_pil_2) + self.assertEqual(gray_pil_2.mode, 'RGB', 'mode should be RGB') + self.assertEqual(gray_np_2.shape, tuple(x_shape), 'should be 3 channel') + np.testing.assert_equal(x_np, gray_np_2) + + # Case 3c: 1 channel grayscale -> 1 channel grayscale (grayscaled) + trans3 = transform.RandomGray(p=1.0) + gray_pil_3 = trans3(x_pil_2) + gray_np_3 = np.array(gray_pil_3) + self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L') + self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel') + np.testing.assert_equal(gray_np, gray_np_3) + + # Case 3d: 1 channel grayscale -> 1 channel grayscale (unchanged) + trans3 = transform.RandomGray(p=0.0) + gray_pil_3 = trans3(x_pil_2) + gray_np_3 = np.array(gray_pil_3) + self.assertEqual(gray_pil_3.mode, 'L', 'mode should be L') + self.assertEqual(gray_np_3.shape, tuple(x_shape[0:2]), 'should be 1 channel') + np.testing.assert_equal(gray_np, gray_np_3) + + def test_RandomPerspective(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomPerspective(p=1), + transform.ToTensor(), + ])(img) + + + def test_RandomResizedCrop(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomResizedCrop(20), + transform.ToTensor(), + ])(img) + + + def test_FiveCrop(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.FiveCrop(20), + transform.ToTensor(), + ])(img) + + + def test_TenCrop(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.TenCrop(20), + transform.ToTensor(), + ])(img) + + + def test_RandomRotation(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomRotation(20), + transform.ToTensor(), + ])(img) + + + def test_RandomAffine(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.RandomAffine(20), + transform.ToTensor(), + ])(img) + + def test_not_pil_image(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.RandomAffine(20), + transform.ToTensor(), + ])(img) + + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.Gray(), + transform.Resize(20), + transform.ToTensor(), + ])(img) + + + + +if __name__ == '__main__': + unittest.main() diff --git a/python/jittor/test/test_transpose_op.py b/python/jittor/test/test_transpose_op.py new file mode 100644 index 00000000..706fe943 --- /dev/null +++ b/python/jittor/test/test_transpose_op.py @@ -0,0 +1,159 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_core import expect_error +from .test_grad import ngrad +from itertools import permutations +from jittor.test.test_cuda import test_cuda + +def gen_data(shape): + num = np.multiply.reduce(shape) + a = np.arange(0, num) + return a.reshape(shape) + +class TestTransposeOp(unittest.TestCase): + def test_with_np(self): + def check(a): + perms = list(permutations(range(a.ndim))) + [None] + for perm in perms: + if perm: + x = np.transpose(a, perm) + y = jt.transpose(a, perm).data + else: + x = np.transpose(a) + y = jt.transpose(a).data + self.assertEqual(x.shape, y.shape) + assert (x==y).all(), f"\n{x}\n{y}" + + # ia = [gen_data([2,3,4,5]), gen_data([5,3])] + ia = [gen_data([2,2,2]), gen_data([2,3,4,5]), gen_data([5,3])] + for a in ia: check(a) + + def test_grad(self): + def check(a): + perms = list(permutations(range(a.ndim))) + [None] + for perm in perms: + x = jt.array(a).float() + if perm: + y = x.transpose(perm) + else: + y = x.transpose() + dx = jt.grad(y*y, x).data + self.assertEqual(dx.shape, a.shape) + assert (dx==a*2).all(), f"\n{dx}\n{a}\n{perm}" + ia = [gen_data([2,2,2]), gen_data([2,3,4,5]), gen_data([5,3])] + for a in ia: check(a) + + def test_matmul_grad(self): + np.random.seed(0) + for i in range(10): + a = np.random.rand(2,3).astype("float32") + b = np.random.rand(3,4).astype("float32") + out, (da, db) = ngrad(lambda vars: np.matmul(vars[0],vars[1]).sum(), [a,b], 1e-1) + ja = jt.array(a) + jb = jt.array(b) + jc = ja.matmul(jb) + jda, jdb = jt.grad(jc, [ja,jb]) + assert ((da-jda.data)<1e-5).all(), (da, jda.data, da-jda.data) + assert ((db-jdb.data)<1e-5).all(), (db-jdb.data) + + def test_permute(self): + a = jt.ones([2,3,4]) + assert a.permute().shape == [4,3,2] + assert a.permute(0,2,1).shape == [2,4,3] + + def test_transpose_3d2i(self): + a = jt.ones([2,3,4]) + assert a.transpose(0,1).shape == (3,2,4) + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cutt(self): + a = jt.rand((10,2)) > 0.5 + b = a.transpose() + assert (a.data.transpose() == b.data).all() + + a = jt.zeros((1,1)) + b = a.transpose((1,0)) + b.sync() + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_cutt_bug(self): + a = jt.rand(640000,4,3) + b = a.transpose(0,2,1) + b.sync(True) + print(a.shape, b.shape) + + +class TestFuseTransposeOp(unittest.TestCase): + + def test_fuse_transpose1(self): + with jt.profile_scope() as rep: + a = jt.rand((10,11,12)) + b = a.fuse_transpose((1,2,0))+1 + np.testing.assert_allclose( + a.data.transpose((1,2,0))+1, + b.data + ) + assert len(rep) == 3 + + def test_fuse_transpose2(self): + with jt.profile_scope() as rep: + a = jt.rand((10,11,12)) + b = (a+1).fuse_transpose((1,2,0)) + np.testing.assert_allclose( + a.data.transpose((1,2,0))+1, + b.data + ) + assert len(rep) == 3 + + def test_fuse_transpose3(self): + with jt.profile_scope() as rep: + a = jt.rand((10,11,12)) + c = jt.rand((11,12,10)) + b = a.fuse_transpose((1,2,0))+c + np.testing.assert_allclose( + a.data.transpose((1,2,0))+c.data, + b.data + ) + assert len(rep) == 3 + + def test_fuse_transpose4(self): + with jt.profile_scope() as rep: + a = jt.rand((10,11,12)) + c = jt.rand((10,11,12)) + b = (a+c).fuse_transpose((1,2,0)) + np.testing.assert_allclose( + (a.data+c.data).transpose((1,2,0)), + b.data + ) + assert len(rep) == 3 + + def test_fuse_transpose5(self): + with jt.profile_scope() as rep: + a = jt.rand((10,11,6,7)) + c = jt.rand((10,11,6,7)) + b = (a+c).fuse_transpose((1,0,2,3)) + np.testing.assert_allclose( + (a.data+c.data).transpose((1,0,2,3)), + b.data + ) + assert len(rep) == 3 + + +@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") +class TestFuseTransposeCudaOp(TestFuseTransposeOp): + def setUp(self): + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_unary_op.py b/python/jittor/test/test_unary_op.py new file mode 100644 index 00000000..a8d25515 --- /dev/null +++ b/python/jittor/test/test_unary_op.py @@ -0,0 +1,120 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +from .test_grad import ngrad +from .test_cuda import test_cuda + +def check(op, *args): + x = eval(f"np.{op}(*args)") + y = eval(f"jt.{op}(*args).data") + convert = lambda x: x.astype("uint8") if x.dtype=="bool" else x + x = convert(x) + y = convert(y) + # str match nan and inf + assert x.dtype == y.dtype and x.shape == y.shape, \ + (x.dtype, y.dtype, x.shape, y.shape) + for a,b in zip(x.flatten(), y.flatten()): + assert str(a)[:5] == str(b)[:5], (a,b) + +class TestUnaryOp(unittest.TestCase): + def test_unary_op(self): + assert jt.float64(1).data.dtype == "float64" + assert (jt.abs(-1) == 1).data.all() + assert (abs(-jt.float64(1)) == 1).data.all() + a = np.array([-1,2,3,0], dtype="int32") + check("abs", a) + check("negative", a) + check("logical_not", a) + check("bitwise_not", a) + b = np.array([1.1, 2.2, 3.3, 4.4, -1, 0]) + type = "float16" if (jt.flags.amp_reg & 2) else "float32" + check("log", a.astype(type)) + check("exp", a.astype(type)) + check("sqrt", a.astype(type)) + + def test_grad(self): + ops = ["abs", "negative", "log", "exp", "sqrt", + "sin", "arcsin", "sinh", "arcsinh", + "tan", "arctan", "tanh", "arctanh", + "cos", "arccos", "cosh", "arccosh", + "sigmoid", + ] + a = np.array([1.1, 2.2, 3.3, 4.4]) + for op in ops: + if op == "abs": + b = np.array(a+[-1,]) + elif op == "arccosh": + b = np.array(a) + elif "sin" in op or "cos" in op or "tan" in op: + b = np.array(a) / 5 + else: + b = np.array(a) + func = lambda x: eval(f"np.{op}(x[0]).sum()") + if op == "sigmoid": + func = lambda x: (1/(1+np.exp(-x[0]))).sum() + x, (da,) = ngrad(func, [b], 1e-8) + ja = jt.array(b) + jb = eval(f"jt.{op}(ja)") + jda = jt.grad(jb, ja) + tol = 1e-2 if jt.flags.amp_reg & 2 else 1e-6 + assert (np.allclose(jda.data, da, atol=tol, rtol=tol)), (jda.data,da,op) + + def test_sigmoid(self): + a = np.arange(-150,150, 10).astype("float32") + # a = np.array([-150.0, -140.0, -130.0]).astype("float32") + b = jt.array(a, dtype='float32') + b1 = b.sigmoid().numpy() + assert np.isnan(b1).any() == False + + def test_safe_clip(self): + a = jt.array([-1.0,0,0.4,1,2,3]) + b = a.safe_clip(0.1, 0.5) + assert np.allclose(b.data, [0.1,0.1,0.4,0.5,0.5,0.5]) + da = jt.grad(b, a) + assert (da.data == 1).all() + + def test_erfinv(self): + from scipy import special + y = np.linspace(-1.0, 1.0, num=10) + x = special.erfinv(y) + y2 = jt.array(y) + x2 = jt.erfinv(y2) + np.testing.assert_allclose(y.data, y2.data) + + + y = np.linspace(-0.9, 0.9, num=10) + x = special.erfinv(y) + y2 = jt.array(y) + x2 = jt.erfinv(y2) + np.testing.assert_allclose(y.data, y2.data) + d = jt.grad(x2, y2) + _, (dn,) = ngrad(lambda y: special.erfinv(y).sum(), [y], 1e-8) + tol = 1e-3 if jt.flags.amp_reg & 2 else 1e-6 + np.testing.assert_allclose(d.data, dn, atol=tol, rtol=tol) + + +class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)): + pass + +class TestUnaryOpCpuFp16(TestUnaryOp, test_cuda(2)): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + def tearDown(self): + jt.flags.amp_reg = 0 + +class TestUnaryOpCudaFp16(TestUnaryOp, test_cuda(2)): + def setUp(self): + jt.flags.amp_reg = 2 | 4 | 8 | 16 + jt.flags.use_cuda = 1 + def tearDown(self): + jt.flags.amp_reg = 0 + jt.flags.use_cuda = 0 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_unique.py b/python/jittor/test/test_unique.py new file mode 100644 index 00000000..31953bb7 --- /dev/null +++ b/python/jittor/test/test_unique.py @@ -0,0 +1,57 @@ + +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# Xiangli Li <1905692338@qq.com> +# Jiapeng Zhang +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +from cgi import test +import unittest +import jittor as jt +import numpy as np + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch +except: + torch = None + skip_this_test = True + +def test_unique_with_torch(input, dim=None): + jt0, jt1, jt2 = jt.unique(jt.array(input), True, True, dim) + torch0, torch1, torch2 = torch.unique(torch.tensor(input), True, True, True, dim) + assert np.allclose(jt0, torch0) and np.allclose(jt1, torch1) and np.allclose(jt2, torch2) + + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestSparse(unittest.TestCase): + + def test_unique(self): + test_unique_with_torch(np.array([1, 3, 2, 3, 3, 3], dtype=np.int32)) + test_unique_with_torch(np.array([[1, 3], [2, 3], [1, 2]], dtype=np.int64)) + + def test_unique_dim(self): + test_unique_with_torch(np.array([[1, 3], [2, 3], [1, 3], [2, 3]]), 0) + test_unique_with_torch(np.array([[1, 3], [2, 3], [1, 3], [2, 3]]), 1) + + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_unique_cuda(self): + self.test_unique() + + @unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found") + @jt.flag_scope(use_cuda=1) + def test_unique_dim_cuda(self): + self.test_unique_dim() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_utils.py b/python/jittor/test/test_utils.py new file mode 100644 index 00000000..2f634e37 --- /dev/null +++ b/python/jittor/test/test_utils.py @@ -0,0 +1,53 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import os +from jittor import LOG + +def find_jittor_path(): + path = os.path.realpath(__file__) + suffix = "test_utils.py" + assert path.endswith(suffix), path + return path[:-len(suffix)] + ".." + +def find_cache_path(): + import jittor_utils as jit_utils + path = jit_utils.home() + dirs = [".cache", "jittor"] + for d in dirs: + path = os.path.join(path, d) + if not os.path.isdir(path): + os.mkdir(path) + assert os.path.isdir(path) + return path + +cache_path = find_cache_path() +jittor_path = find_jittor_path() + +cc_flags = f" -g -O0 -DTEST --std=c++14 -I{jittor_path}/test -I{jittor_path}/src " + +class TestUtils(unittest.TestCase): + def test_cache_compile(self): + cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {jittor_path}/src/utils/str_utils.cc {jittor_path}/src/utils/cache_compile.cc -lpthread {cc_flags} -o cache_compile && cache_path={cache_path} jittor_path={jittor_path} ./cache_compile" + self.assertEqual(os.system(cmd), 0) + + def test_log(self): + return + cc_flags = f" -g -O3 -DTEST_LOG -DLOG_ASYNC --std=c++14 -I{jittor_path}/test -I{jittor_path}/src -lpthread " + cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/log.cc {jittor_path}/src/utils/tracer.cc {cc_flags} -o log && log_v=1000 log_sync=0 ./log" + LOG.v(cmd) + assert os.system(cmd) == 0 + + def test_mwsr_list(self): + cc_flags = f" -g -O3 -DTEST -DLOG_ASYNC --std=c++14 -I{jittor_path}/test -I{jittor_path}/src -lpthread " + cmd = f"cd {cache_path} && g++ {jittor_path}/src/utils/mwsr_list.cc {cc_flags} -o mwsr_list && ./mwsr_list" + LOG.v(cmd) + assert os.system(cmd) == 0 + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/test/test_var.py b/python/jittor/test/test_var.py new file mode 100644 index 00000000..116df6c7 --- /dev/null +++ b/python/jittor/test/test_var.py @@ -0,0 +1,51 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# Zheng-Ning Liu +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestVarFunctions(unittest.TestCase): + def test_var(self): + x = np.random.randn(100, 1000).astype(np.float32) + + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + np.testing.assert_allclose(jt_x.var().numpy(), tc_x.var().numpy(), rtol=1e-3, atol=1e-4) + np.testing.assert_allclose(jt_x.var(dim=1).numpy(), tc_x.var(dim=1).numpy(), rtol=1e-3, atol=1e-4) + np.testing.assert_allclose(jt_x.var(dim=0, unbiased=True).numpy(), tc_x.var(dim=0, unbiased=True).numpy(), rtol=1e-3, atol=1e-4) + + def test_std(self): + x=np.random.randn(100, 1000).astype(np.float32) + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + np.testing.assert_allclose(jt_x.std().numpy(), tc_x.std().numpy(), 1e-4) + + def test_norm(self): + x = np.random.randn(100, 1000).astype(np.float32) + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + np.testing.assert_allclose(jt_x.norm(1,1).numpy(), tc_x.norm(1,1).numpy(), atol=1e-6) + np.testing.assert_allclose(jt_x.norm(1,0).numpy(), tc_x.norm(1,0).numpy(), atol=1e-6) + np.testing.assert_allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy(), atol=1e-6) + np.testing.assert_allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy(), atol=1e-6) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_vgg.py b/python/jittor/test/test_vgg.py new file mode 100644 index 00000000..2125201a --- /dev/null +++ b/python/jittor/test/test_vgg.py @@ -0,0 +1,100 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn, Module +from jittor.models import vgg +import numpy as np +import sys, os +import random +import math +import unittest +from .test_reorder_tuner import simple_parser +from .test_log import find_log_with_re +from jittor.dataset.mnist import MNIST +import jittor.transform as trans + +model_test = os.environ.get("model_test", "") == "1" +skip_model_test = not model_test + +class MnistNet(Module): + def __init__(self): + self.model = vgg.vgg16_bn() + self.layer = nn.Linear(1000,10) + def execute(self, x): + x = self.model(x) + x = self.layer(x) + return x + +@unittest.skipIf(skip_model_test, "skip_this_test, model_test != 1") +class TestVGGClass(unittest.TestCase): + @classmethod + def setUpClass(self): + # hyper-parameters + self.batch_size = 32 + self.weight_decay = 0.0001 + self.momentum = 0.9 + self.learning_rate = 0.01 + # mnist dataset + self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \ + .set_attrs(batch_size=self.batch_size, shuffle=True) + + # setup random seed + def setup_seed(self, seed): + np.random.seed(seed) + random.seed(seed) + jt.seed(seed) + + @unittest.skipIf(not jt.has_cuda, "Cuda not found") + @jt.flag_scope(use_cuda=1, use_stat_allocator=1) + def test_vgg(self): + self.setup_seed(1) + loss_list=[] + acc_list=[] + mnist_net = MnistNet() + SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay) + + for batch_idx, (data, target) in enumerate(self.train_loader): + output = mnist_net(data) + loss = nn.cross_entropy_loss(output, target) + + # train step + with jt.log_capture_scope( + log_silent=1, + log_v=1, log_vprefix="op.cc=100,exe=10", + ) as logs: + SGD.step(loss) + def callback(loss, output, target, batch_idx): + # print train info + pred = np.argmax(output, axis=1) + acc = np.sum(target==pred)/self.batch_size + loss_list.append(loss[0]) + acc_list.append(acc) + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}' + .format(0, batch_idx, 100,1. * batch_idx, loss[0], acc)) + jt.fetch(batch_idx, loss, output, target, callback) + + log_conv = find_log_with_re(logs, + "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*") + log_matmul = find_log_with_re(logs, + "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*") + # if batch_idx: + # assert len(log_conv)==38 and len(log_matmul)==12, (len(log_conv), len(log_matmul)) + + mem_used = jt.flags.stat_allocator_total_alloc_byte \ + -jt.flags.stat_allocator_total_free_byte + assert mem_used < 11e9, mem_used + # assert jt.core.number_of_lived_vars() < 3500 + if (np.mean(loss_list[-50:])<0.2): + break + + assert np.mean(loss_list[-50:])<0.2 + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_weightnorm.py b/python/jittor/test/test_weightnorm.py new file mode 100644 index 00000000..439a7c97 --- /dev/null +++ b/python/jittor/test/test_weightnorm.py @@ -0,0 +1,63 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Haoyang Peng <2247838039@qq.com> +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +from jittor import nn +import numpy as np +import unittest +from jittor.weightnorm import weight_norm + +try: + import torch + from torch.autograd import Variable + import autograd.numpy as anp + from autograd import jacobian + + has_autograd = True +except: + has_autograd = False + +class jt_module(jt.nn.Module): + def __init__(self, weight): + super().__init__() + self.linear = jt.array(weight) + + def execute(self, x): + return jt.matmul(self.linear, x) + +class torch_module(torch.nn.Module): + def __init__(self, weight): + super().__init__() + self.linear = torch.nn.Parameter(torch.from_numpy(weight)) + + def forward(self, x): + return torch.matmul(self.linear, x) + +@unittest.skipIf(not has_autograd, "No autograd found.") +class TestWeightNorm(unittest.TestCase): + def test_weightnorm(self): + for i in range(30): + weight = np.random.uniform(0,1,(i+10,40)) + jm = jt_module(weight) + tm = torch_module(weight) + inp = np.random.uniform(0,1,(40,i+30)) + torch.nn.utils.weight_norm(tm, 'linear', -1) + weight_norm(jm, 'linear', -1) + jinp = jt.array(inp) + tinp = Variable(torch.from_numpy(inp), requires_grad=True) + joup = jm(jinp) + toup = tm(tinp) + np.testing.assert_allclose(joup.data, toup.detach().numpy(), rtol=1e-4, atol=1e-6) + gq = jt.grad(joup, jinp).data + tgq = torch.autograd.grad(toup, tinp, torch.ones_like(toup), retain_graph=True) + np.testing.assert_allclose(gq, tgq[0].numpy(), rtol=1e-4, atol=1e-6) + +if __name__ == "__main__": + unittest.main() diff --git a/python/jittor/test/test_where_op.py b/python/jittor/test/test_where_op.py new file mode 100644 index 00000000..20e54691 --- /dev/null +++ b/python/jittor/test/test_where_op.py @@ -0,0 +1,92 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np + +class TestWhereOp(unittest.TestCase): + def setUp(self): + self.where = jt.where + + def test(self): + assert (self.where([0,1,0,1])[0].data == [1,3]).all() + a, = self.where([0,1,0,1]) + assert a.uncertain_shape==[2] + a.data + assert a.uncertain_shape==[2] + a,b = self.where([[0,0,1],[1,0,0]]) + assert (a.data==[0,1]).all() and (b.data==[2,0]).all() + + def test_reindex_dep(self): + a = jt.random([10]) + b, = self.where(a>1) + assert len(b.data)==0 + b, = self.where(a>0.5) + assert (b.data==np.where(a.data>0.5)).all() + b = a.reindex_var(self.where(a>0.5)) + assert (b.data==a.data[a.data>0.5]).all() + + def test_binary_dep(self): + a = jt.random([10]) + b, = self.where(a>0.5) + b = b+1 + assert (b.data==np.where(a.data>0.5)[0]+1).all() + b, = self.where(a>1) + b = b+1 + assert (b.data==np.where(a.data>1)[0]+1).all() + + def test_self_dep(self): + a = jt.random([100]) + x = a.reindex_var(self.where(a>0.1)) + x = x.reindex_var(self.where(x<0.9)) + na = a.data + assert np.allclose(na[np.logical_and(na>0.1, na<0.9)], x.data) + + def test_reduce_dep(self): + a = jt.random([100,100]) + index = self.where(a>0.5) + assert isinstance(index, tuple) + x = a.reindex_var(index) + xsum =x.sum() + na = a.data + assert np.allclose(np.sum(na[na>0.5]),xsum.data), (x.data, xsum.data, np.sum(na[na>0.5])) + + def test_doc(self): + assert "Where Operator" in jt.where.__doc__ + + +@unittest.skipIf(not jt.has_cuda, "No Torch found") +class TestWhereOpCuda(TestWhereOp): + def setUp(self): + self.where = jt.where + + @classmethod + def setUpClass(self): + jt.flags.use_cuda = 1 + + @classmethod + def tearDownClass(self): + jt.flags.use_cuda = 0 + + + +@unittest.skipIf(not jt.has_cuda, "No Torch found") +class TestWhereOpCub(TestWhereOpCuda): + def setUp(self): + self.where = jt.compile_extern.cub_ops.cub_where + + @classmethod + def setUpClass(self): + jt.flags.use_cuda = 1 + + @classmethod + def tearDownClass(self): + jt.flags.use_cuda = 0 + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/python/jittor/transform/__init__.py b/python/jittor/transform/__init__.py new file mode 100644 index 00000000..1affa96d --- /dev/null +++ b/python/jittor/transform/__init__.py @@ -0,0 +1,1470 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. +# All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# Contributors: +# Xin Yao +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from PIL import Image +import random +import math +import numpy as np +import warnings +from collections.abc import Sequence, Mapping +import numbers +import jittor as jt + +from . import function_pil as F_pil + +def _get_image_size(img): + """ + Return image size as (w, h) + """ + return F_pil._get_image_size(img) + +def _get_image_num_channels(img): + return F_pil._get_image_num_channels(img) + +def _is_numpy(img): + return isinstance(img, np.ndarray) + +def _is_numpy_image(img): + return img.ndim in {2, 3} + +def crop(img, top, left, height, width): + ''' + Function for cropping image. + + Args:: + + [in] img(Image.Image): Input image. + [in] top(int): the top boundary of the cropping box. + [in] left(int): the left boundary of the cropping box. + [in] height(int): height of the cropping box. + [in] width(int): width of the cropping box. + + Example:: + + img = Image.open(...) + img_ = transform.crop(img, 10, 10, 100, 100) + ''' + return img.crop((left, top, left + width, top + height)) + +def resize(img, size, interpolation=Image.BILINEAR): + ''' + Function for resizing image. + + Args:: + + [in] img(Image.Image): Input image. + [in] size: resize size. [h, w] + [in] interpolation(int): type of resize. default: PIL.Image.BILINEAR + + Example:: + + img = Image.open(...) + img_ = transform.resize(img, (100, 100)) + ''' + if isinstance(size, Sequence): + return img.resize(size[::-1], interpolation) + else: + w, h = img.size + if (h > w): + return img.resize((size, int(round(size * h / w))), interpolation) + else: + return img.resize((int(round(size * w / h)), size), interpolation) + + +def gray(img, num_output_channels): + """ + Function for converting PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. + Args:: + [in] img(PIL Image.Image): Input image. + [in] num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1. + Returns:: + [out] PIL Image: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + if num_output_channels = 3 : returned image is 3 channel with r = g = b + """ + return F_pil.gray(img, num_output_channels) + + +def center_crop(img, output_size): + """ + Function for cropping the given image at the center. + Args:: + [in] img(PIL Image.Image): Input image. + [in] output_size (sequence or int): (height, width) of the crop box. + If int or sequence with single int, it is used for both directions. + Returns:: + PIL Image.Image: Cropped image. + """ + + output_size = _setup_size(output_size, error_msg="If size is a sequence, it should have 2 values") + + image_width, image_height = _get_image_size(img) + crop_height, crop_width = output_size + + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, crop_top, crop_left, crop_height, crop_width) + +def crop_and_resize(img, top, left, height, width, size, interpolation=Image.BILINEAR): + ''' + Function for cropping and resizing image. + + Args:: + + [in] img(Image.Image): Input image. + [in] top(int): the top boundary of the cropping box. + [in] left(int): the left boundary of the cropping box. + [in] height(int): height of the cropping box. + [in] width(int): width of the cropping box. + [in] size: resize size. [h, w] + [in] interpolation(int): type of resize. default: PIL.Image.BILINEAR + + Example:: + + img = Image.open(...) + img_ = transform.resize(img, 10,10,200,200,100) + ''' + img = crop(img, top, left, height, width) + img = resize(img, size, interpolation) + return img + +class Crop: + """Crop and the PIL Image to given size. + + Args: + + * top(int): top pixel indexes + * left(int): left pixel indexes + * height(int): image height + * width(int): image width + """ + def __init__(self, top, left, height, width): + self.top = top + self.left = left + self.height = height + self.width = width + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + return crop(img, self.top, self.left, self.height, self.width) + + +class RandomCropAndResize: + """Random crop and resize the given PIL Image to given size. + + Args:: + + [in] size(int or tuple): [height, width] of the output image. + [in] scale(tuple): range of scale ratio of the area. + [in] ratio(tuple): range of aspect ratio. + [in] interpolation: type of resize. default: PIL.Image.BILINEAR. + + Example:: + + transform = transform.RandomCropAndResize(224) + img_ = transform(img) + """ + def __init__(self, size, scale:tuple=(0.08, 1.0), ratio:tuple=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): + self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") + assert scale[0] <= scale[1] and ratio[0] <= ratio[1] + + self.size = size + self.scale = scale + self.ratio = ratio + self.interpolation = interpolation + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + width, height = img.size + scale = self.scale + ratio = self.ratio + area = height * width + + for _ in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + break + else: + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(ratio): + w = width + h = int(round(w / min(ratio))) + elif in_ratio > max(ratio): + h = height + w = int(round(h * max(ratio))) + else: + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return crop_and_resize(img, i, j, h, w, self.size, self.interpolation) + + + +def hflip(img): + """ + Function for horizontally flipping the given image. + Args:: + [in] img(PIL Image.Image): Input image. + Example:: + + img = Image.open(...) + img_ = transform.hflip(img) + """ + return F_pil.hflip(img) + + +def vflip(img): + """ + Function for vertically flipping the given image. + Args:: + [in] img(PIL Image.Image): Input image. + Example:: + + img = Image.open(...) + img_ = transform.vflip(img) + """ + return F_pil.vflip(img) + + + +def adjust_brightness(img, brightness_factor): + """ + Function for adjusting brightness of an RGB image. + Args:: + [in] img (PIL Image.Image): Image to be adjusted. + [in] brightness_factor (float): How much to adjust the brightness. + Can be any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + Returns:: + [out] PIL Image.Image: Brightness adjusted image. + Example:: + + img = Image.open(...) + img_ = transform.adjust_brightness(img, 0.5) + """ + return F_pil.adjust_brightness(img, brightness_factor) + + +def adjust_contrast(img, contrast_factor): + """ + Function for adjusting contrast of an image. + Args:: + [in] img (PIL Image.Image): Image to be adjusted. + [in] contrast_factor (float): How much to adjust the contrast. + Can be any non negative number. 0 gives a solid gray image, + 1 gives the original image while 2 increases the contrast by a factor of 2. + Returns:: + [out] PIL Image.Image: Contrast adjusted image. + Example:: + + img = Image.open(...) + img_ = transform.adjust_contrast(img, 0.5) + """ + return F_pil.adjust_contrast(img, contrast_factor) + + +def adjust_saturation(img, saturation_factor): + """ + Function for adjusting saturation of an image. + Args:: + [in] img (PIL Image.Image): Image to be adjusted. + [in] saturation_factor (float): How much to adjust the saturation. + 0 will give a black and white image, 1 will give the original image + while 2 will enhance the saturation by a factor of 2. + Returns:: + [out] PIL Image.Image: Saturation adjusted image. + Example:: + + img = Image.open(...) + img_ = transform.adjust_saturation(img, 0.5) + """ + return F_pil.adjust_saturation(img, saturation_factor) + + +def adjust_hue(img, hue_factor): + """ + Function for adjusting hue of an image. + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + See `Hue`_ for more details. + .. _Hue: https://en.wikipedia.org/wiki/Hue + Args:: + [in] img (PIL Image.Image): Image to be adjusted. + [in] hue_factor (float): How much to shift the hue channel. + Should be in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of + hue channel in HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + Returns:: + [out] PIL Image.Image: Saturation adjusted image. + Example:: + + img = Image.open(...) + img_ = transform.adjust_hue(img, 0.1) + """ + return F_pil.adjust_hue(img, hue_factor) + + +def adjust_gamma(img, gamma, gain=1): + """ + Function for performing gamma correction on an image. + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + See `Gamma Correction`_ for more details. + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + Args:: + [in] img (PIL Image.Image): Image to be adjusted. + [in] gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + [in] gain (float): The constant multiplier. + Returns:: + [out] PIL Image.Image: Gamma adjusted image. + """ + return F_pil.adjust_gamma(img, gamma, gain) + + + +class RandomHorizontalFlip: + """ + Random flip the image horizontally. + + Args:: + + [in] p(float): The probability of image flip, default: 0.5. + + Example:: + + transform = transform.RandomHorizontalFlip(0.6) + img_ = transform(img) + """ + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + if random.random() < self.p: + return img.transpose(Image.FLIP_LEFT_RIGHT) + return img + +class CenterCrop: + ''' + Class for cropping image centrally. + + Args:: + + [in] size(int or tuple): Size want to crop. + + Example:: + + transform = transform.CenterCrop(224) + img_ = transform(img) + ''' + def __init__(self, size): + self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + width, height = img.size + return crop(img, (height - self.size[0]) / 2, (width - self.size[1]) / 2, self.size[0], self.size[1]) + +def to_tensor(pic): + """ + Function for turning Image.Image to np.array with CHW format. + + Args:: + + [in] img(Image.Image): Input image. + + Example:: + + img = Image.open(...) + img_ = transform.to_tensor(img) + """ + if isinstance(pic, jt.Var): + return pic + + if isinstance(pic, tuple): + # try convert ten crop tuple + pic = ( to_tensor(pic) for p in pic ) + pic = np.array(pic) + return pic + if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): + raise TypeError(f'img should be PIL Image or ndarray. Got {type(pic)}.') + + if _is_numpy(pic) and not _is_numpy_image(pic): + raise ValueError(f'img should be 2/3 dimensional. Got {pic.ndim} dimensions.') + + if _is_numpy(pic): + # handle numpy array + if pic.ndim == 2: + pic = pic[None, :, :] + + # backward compatibility + if pic.dtype == 'uint8': + return np.float32(pic) * np.float32(1/255.0) + else: + return pic + + # handle PIL Image + if pic.mode == 'I': + img = np.array(pic, np.int32, copy=False) + elif pic.mode == 'I;16': + img = np.array(pic, np.int16, copy=False) + elif pic.mode == 'F': + img = np.array(pic, np.float32, copy=False) + elif pic.mode == '1': + img = np.array(pic, np.uint8, copy=False) * 255 + else: + img = np.array(pic, np.uint8, copy=False) + + # put it from HWC to CHW format + img = img.reshape(pic.size[1], pic.size[0], len(pic.getbands())) + img = img.transpose(2, 0, 1) + if img.dtype == 'uint8': + return np.float32(img) * np.float32(1/255.0) + else: + return img +pil_to_tensor = to_tensor + + +def _to_jittor_array(pic): + """ + Function for turning Image.Image or np.ndarray (HWC) to jt.Var (CHW). + Args:: + [in] img(PIL Image.Image or np.ndarray): Input image. + If input type is np.ndarray, the shape should be in HWC. + Return: + [out] jt.Var in shape CHW. + Example:: + + img = Image.open(...) + img_ = transform.to_tensor(img) + """ + if not(F_pil._is_pil_image(pic) or _is_numpy(pic)): + raise TypeError(f'img should be PIL Image or ndarray. Got {type(pic)}.') + + if _is_numpy(pic) and not _is_numpy_image(pic): + raise ValueError(f'img should be 2/3 dimensional. Got {pic.ndim} dimensions.') + + if _is_numpy(pic): + # handle numpy array + if pic.ndim == 2: + pic = pic[:, :, None] + + img = jt.array(pic.transpose((2, 0, 1))) + # backward compatibility + if img.dtype == 'uint8': + return img.float().divide(255) + else: + return img + + # handle PIL Image + if pic.mode == 'I': + img = jt.array(np.array(pic, np.int32, copy=False)) + elif pic.mode == 'I;16': + img = jt.array(np.array(pic, np.int16, copy=False)) + elif pic.mode == 'F': + img = jt.array(np.array(pic, np.float32, copy=False)) + elif pic.mode == '1': + img = jt.array(np.array(pic, np.uint8, copy=False) * 255, dtype='uint8') + else: + img = jt.array(np.array(pic, np.uint8, copy=False)) + + # put it from HWC to CHW format + img = img.reshape(pic.size[1], pic.size[0], len(pic.getbands())) + img = img.permute((2, 0, 1)) + if img.dtype == 'uint8': + return img.float().divide(255) + else: + return img + +def to_pil_image(pic, mode=None): + """Convert a tensor or an ndarray to PIL Image. + Args: + pic (Tensor or numpy.ndarray): Image(HWC format) to be converted to PIL Image. + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes + Returns: + PIL Image: Image converted to PIL Image. + """ + if isinstance(pic, jt.Var): + pic = pic.data + if not isinstance(pic, np.ndarray): + raise TypeError('pic should be Tensor or ndarray. Got {}.'.format(type(pic))) + + else: + if pic.ndim not in {2, 3}: + raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim)) + + elif pic.ndim == 2: + # if 2D image, add channel dimension (HWC) + pic = np.expand_dims(pic, 2) + + npimg = pic + if 'float' in str(pic.dtype) and mode != 'F' and npimg.shape[2] != 1: + npimg = np.uint8(pic * 255) + # npimg = np.transpose(pic, (1, 2, 0)) + + if not isinstance(npimg, np.ndarray): + raise TypeError('Input pic must be a jt.Var or NumPy ndarray, ' + + 'not {}'.format(type(npimg))) + + if npimg.shape[2] == 1: + expected_mode = None + npimg = npimg[:, :, 0] + if npimg.dtype == np.uint8: + expected_mode = 'L' + elif npimg.dtype == np.int16: + expected_mode = 'I;16' + elif npimg.dtype == np.int32: + expected_mode = 'I' + elif npimg.dtype == np.float32: + expected_mode = 'F' + if mode is not None and mode != expected_mode: + raise ValueError("Incorrect mode ({}) supplied for input type {}. Should be {}" + .format(mode, np.dtype, expected_mode)) + mode = expected_mode + + elif npimg.shape[2] == 2: + permitted_2_channel_modes = ['LA'] + if mode is not None and mode not in permitted_2_channel_modes: + raise ValueError("Only modes {} are supported for 2D inputs".format(permitted_2_channel_modes)) + + if mode is None and npimg.dtype == np.uint8: + mode = 'LA' + + elif npimg.shape[2] == 4: + permitted_4_channel_modes = ['RGBA', 'CMYK', 'RGBX'] + if mode is not None and mode not in permitted_4_channel_modes: + raise ValueError("Only modes {} are supported for 4D inputs".format(permitted_4_channel_modes)) + + if mode is None and npimg.dtype == np.uint8: + mode = 'RGBA' + else: + permitted_3_channel_modes = ['RGB', 'YCbCr', 'HSV'] + if mode is not None and mode not in permitted_3_channel_modes: + raise ValueError("Only modes {} are supported for 3D inputs".format(permitted_3_channel_modes)) + if mode is None and npimg.dtype == np.uint8: + mode = 'RGB' + + if mode is None: + raise TypeError('Input type {} is not supported'.format(npimg.dtype)) + + return Image.fromarray(npimg, mode=mode) + + + +def image_normalize(img, mean, std): + """ + Function for normalizing image. + Args:: + [in] image(PIL Image.Image or np.ndarray): input image. + If type of input image is np.ndarray, it should be in shape (C, H, W). + [in] mean(list): the mean value of Normalization. + [in] std(list): the std value of Normalization. + Example:: + img = Image.open(...) + img_ = transform.image_normalize(img, mean=[0.5], std=[0.5]) + """ + if not isinstance(img, (Image.Image, jt.Var, np.ndarray)): + raise TypeError(f'Input type should be in (PIL Image, jt.Var, np.ndarray). Got {type(img)}.') + elif isinstance(img, Image.Image): + assert img.mode == 'RGB', f"input image mode should be 'RGB'. Got {img.mode}." + img = (np.array(img).transpose((2, 0, 1)) \ + - mean * np.float32(255.)) \ + / (std * np.float32(255.)) + else: + if img.ndim < 3: + raise ValueError(f'Expected input to be a array image of size (..., C, H, W). Got {img.shape}.') + if isinstance(img, jt.Var): + mean = jt.array(mean) + std = jt.array(std) + if (std.data == 0).any(): + raise ValueError('std cannot be zero.') + else: + mean = np.asarray(mean) + std = np.asarray(std) + if (std == 0).any(): + raise ValueError('std cannot be zero.') + if mean.ndim == 1: + mean = mean.reshape(-1, 1, 1) + if std.ndim == 1: + std = std.reshape(-1, 1, 1) + img = (img - mean) / std + return img + + + +class ImageNormalize: + ''' + Class for normalizing the input image. + + Args:: + + [in] mean(list): the mean value of Normalization. + [in] std(list): the std value of Normalization. + + Example:: + + transform = transform.ImageNormalize(mean=[0.5], std=[0.5]) + img_ = transform(img) + ''' + + def __init__(self, mean, std): + self.mean = np.float32(mean).reshape(-1,1,1) + self.std = np.float32(std).reshape(-1,1,1) + + def __call__(self, img): + if isinstance(img, Image.Image): + img = (np.array(img).transpose((2,0,1)) \ + - self.mean*np.float32(255.)) \ + * (np.float32(1./255.)/self.std) + else: + img = (img - self.mean) / self.std + return img + +class Compose: + ''' + Base class for combining various transformations. + + Args:: + + [in] transforms(list): a list of transform. + + Example:: + + transform = transform.Compose([ + transform.Resize(opt.img_size), + transform.Gray(), + transform.ImageNormalize(mean=[0.5], std=[0.5]), + ]) + img_ = transform(img) + ''' + def __init__(self, transforms): + self.transforms = transforms + def __call__(self, *data): + if len(data) == 1: + data = data[0] + for t in self.transforms: + data = t(data) + else: + for t in self.transforms: + data = t(*data) + return data + +class Resize: + ''' + Class for resizing image. + + Args:: + + [in] size(int or tuple): Size want to resize. [h, w] + [in] mode(int): type of resize. + + Example:: + + transform = transform.Resize(224) + img_ = transform(img) + ''' + def __init__(self, size, mode=Image.BILINEAR): + self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") + self.mode = mode + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + return resize(img, self.size, self.mode) + +class Gray: + ''' + Convert image to grayscale. + + Example:: + + transform = transform.Gray() + img_ = transform(img) + ''' + def __init__(self, num_output_channels=1): + self.num_output_channels = num_output_channels + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + img = np.float32(img.convert('L')) / np.float32(255.0) + if self.num_output_channels == 1: + return img[np.newaxis, :] + else: + return np.dstack([img, img, img]) + +class RandomGray: + ''' + Randomly convert image to grayscale. + Args:: + [in] p (float): probability that image should be converted to grayscale, default: 0.1 + Returns:: + [out] PIL Image: Grayscale version of the image with probability p and unchanged + with probability (1-p). + - If input image is 1 channel: grayscale version is 1 channel + - If input image is 3 channel: grayscale version is 3 channel with r == g == b + Example:: + transform = transform.Gray() + img_ = transform(img) + ''' + def __init__(self, p=0.1): + self.p = p + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + num_output_channels = _get_image_num_channels(img) + if random.random() < self.p: + return gray(img, num_output_channels=num_output_channels) + return img + +class RandomCrop: + ''' + Class for randomly cropping the input image. + + Args:: + + [in] size(tuple or int): the size want to crop. + + Example:: + + transform = transform.RandomCrop(128) + img_ = transform(img) + ''' + def __init__(self, size): + self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + width, height = img.size + assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop, {(self.size, height, width)}" + top = np.random.randint(0,height-self.size[0]+1) + left = np.random.randint(0,width-self.size[1]+1) + return crop(img, top, left, self.size[0], self.size[1]) + +class Lambda: + """Apply a user-defined lambda as a transform. + Args: + lambd (function): Lambda/function to be used for transform. + """ + + def __init__(self, lambd): + assert callable(lambd), repr(type(lambd).__name__) + " object is not callable" + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + + def __repr__(self): + return self.__class__.__name__ + '()' + + +class RandomApply: + """ + Apply randomly a list of transformations with a given probability + Args:: + [in] transforms (list or tuple): list of transformations + [in] p (float): probability + """ + + def __init__(self, transforms, p=0.5): + assert isinstance(transforms, (list, tuple)) + self.transforms = transforms + self.p = p + + def __call__(self, img): + if self.p < random.random(): + return img + for t in self.transforms: + img = t(img) + return img + + +class RandomOrder: + """ + Apply a list of transformations in a random order. + Args:: + [in] transforms (list or tuple): list of transformations + [in] p (float): probability + """ + + def __init__(self, transforms): + assert isinstance(transforms, (list, tuple)) + self.transforms = transforms + + def __call__(self, img): + order = list(range(len(self.transforms))) + random.shuffle(order) + for i in order: + img = self.transforms[i](img) + return img + + +class RandomChoice: + """ + Apply single transformation randomly picked from a list. + Args:: + [in] transforms (list or tuple): list of transformations + [in] p (float): probability + """ + + def __init__(self, transforms): + assert isinstance(transforms, (list, tuple)) + self.transforms = transforms + + def __call__(self, img): + t = random.choice(self.transforms) + return t(img) + + +class RandomVerticalFlip: + """ + Random flip the image vertically. + Args:: + [in] p(float): The probability of image flip, default: 0.5. + Example:: + transform = transform.RandomVerticalFlip(0.6) + img_ = transform(img) + """ + def __init__(self, p=0.5): + self.p = p + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + if random.random() < self.p: + return vflip(img) + return img + + +class ColorJitter: + """ + Randomly change the brightness, contrast, saturation and hue of an image. + Args:: + [in] brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + [in] contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + [in] saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + [in] hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): + self.brightness = self._check_input(brightness, 'brightness') + self.contrast = self._check_input(contrast, 'contrast') + self.saturation = self._check_input(saturation, 'saturation') + self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), + clip_first_on_zero=False) + + @staticmethod + def _check_input(value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError(f"If {name} is a single number, it must be non negative.") + value = [center - float(value), center + float(value)] + if clip_first_on_zero: + value[0] = max(value[0], 0.0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bound[0] <= value[0] <= value[1] <= bound[1]: + raise ValueError(f"{name} values should be between {bound}") + else: + raise TypeError(f"{name} should be a single number or a list/tuple with length 2.") + + # if value is 0 or (1., 1.) for brightness/contrast/saturation + # or (0., 0.) for hue, do nothing + if value[0] == value[1] == center: + value = None + return value + + @staticmethod + def _get_transform(brightness, contrast, saturation, hue): + """ + Get a randomized transform to be applied on image. + Arguments are same as that of __init__. + Returns:: + Transform which randomly adjusts brightness, contrast, saturation + and hue in a random order. + """ + transforms = [] + + if brightness is not None: + brightness_factor = random.uniform(brightness[0], brightness[1]) + transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor))) + + if contrast is not None: + contrast_factor = random.uniform(contrast[0], contrast[1]) + transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor))) + + if saturation is not None: + saturation_factor = random.uniform(saturation[0], saturation[1]) + transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor))) + + if hue is not None: + hue_factor = random.uniform(hue[0], hue[1]) + transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor))) + + random.shuffle(transforms) + transform = Compose(transforms) + + return transform + + def __call__(self, img:Image.Image): + """ + Args:: + [in] img (PIL Image): Input image. + Returns:: + [out] PIL Image: Color jittered image. + """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) + transform = self._get_transform(self.brightness, self.contrast, self.saturation, self.hue) + + return transform(img) + + +def _setup_size(size, error_msg): + if isinstance(size, numbers.Number): + return int(size), int(size) + + if isinstance(size, Sequence) and len(size) == 1: + return size[0], size[0] + + if len(size) != 2: + raise ValueError(error_msg) + + return size + +class ToTensor: + def __call__(self, pic): + """ + Args: + pic (PIL Image or numpy.ndarray): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + return to_tensor(pic) + + def __repr__(self): + return self.__class__.__name__ + '()' + +class ToPILImage(object): + """Convert a tensor or an ndarray to PIL Image. + Args: + pic (Tensor or numpy.ndarray): Image(HWC format) to be converted to PIL Image. + mode (`PIL.Image mode`_): color space and pixel depth of input data (optional). + .. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes + Returns: + PIL Image: Image converted to PIL Image. + """ + def __init__(self, mode=None): + self.mode = mode + + def __call__(self, pic): + """ + Args: + pic (Tensor or numpy.ndarray): Image to be converted to PIL Image. + Returns: + PIL Image: Image converted to PIL Image. + """ + return to_pil_image(pic, self.mode) + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + if self.mode is not None: + format_string += 'mode={0}'.format(self.mode) + format_string += ')' + return format_string + + + +class RandomPerspective(object): + """Performs Perspective transformation of the given PIL Image randomly with a given probability. + + Args: + interpolation : Default- Image.BICUBIC + + p (float): probability of the image being perspectively transformed. Default value is 0.5 + + distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5. + + """ + + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC): + self.p = p + self.interpolation = interpolation + self.distortion_scale = distortion_scale + + def __call__(self, img:Image.Image): + """ + Args: + img (PIL Image): Image to be Perspectively transformed. + + Returns: + PIL Image: Random perspectivley transformed image. + """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) + + if random.random() < self.p: + width, height = img.size + startpoints, endpoints = self.get_params(width, height, self.distortion_scale) + return F_pil.perspective(img, startpoints, endpoints, self.interpolation) + return img + + @staticmethod + def get_params(width, height, distortion_scale): + """Get parameters for ``perspective`` for a random perspective transform. + + Args: + width : width of the image. + height : height of the image. + + Returns: + List containing [top-left, top-right, bottom-right, bottom-left] of the original image, + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image. + """ + half_height = int(height / 2) + half_width = int(width / 2) + topleft = (random.randint(0, int(distortion_scale * half_width)), + random.randint(0, int(distortion_scale * half_height))) + topright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), + random.randint(0, int(distortion_scale * half_height))) + botright = (random.randint(width - int(distortion_scale * half_width) - 1, width - 1), + random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) + botleft = (random.randint(0, int(distortion_scale * half_width)), + random.randint(height - int(distortion_scale * half_height) - 1, height - 1)) + startpoints = [(0, 0), (width - 1, 0), (width - 1, height - 1), (0, height - 1)] + endpoints = [topleft, topright, botright, botleft] + return startpoints, endpoints + + def __repr__(self): + return self.__class__.__name__ + '(p={})'.format(self.p) + +class RandomResizedCrop(object): + """Crop the given PIL Image to random size and aspect ratio. + + A crop of random size (default: of 0.08 to 1.0) of the original size and a random + aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop + is finally resized to given size. + This is popularly used to train the Inception networks. + + Args: + size: expected output size of each edge + scale: range of size of the origin size cropped + ratio: range of aspect ratio of the origin aspect ratio cropped + interpolation: Default: PIL.Image.BILINEAR + """ + + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): + if isinstance(size, tuple): + self.size = size + else: + self.size = (size, size) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("range should be of kind (min, max)") + + self.interpolation = interpolation + self.scale = scale + self.ratio = ratio + + @staticmethod + def get_params(img, scale, ratio): + """Get parameters for ``crop`` for a random sized crop. + + Args: + img (PIL Image): Image to be cropped. + scale (tuple): range of size of the origin size cropped + ratio (tuple): range of aspect ratio of the origin aspect ratio cropped + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for a random + sized crop. + """ + width, height = _get_image_size(img) + area = height * width + + for attempt in range(10): + target_area = random.uniform(*scale) * area + log_ratio = (math.log(ratio[0]), math.log(ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = random.randint(0, height - h) + j = random.randint(0, width - w) + return i, j, h, w + + # Fallback to central crop + in_ratio = float(width) / float(height) + if (in_ratio < min(ratio)): + w = width + h = int(round(w / min(ratio))) + elif (in_ratio > max(ratio)): + h = height + w = int(round(h * max(ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + return i, j, h, w + + def __call__(self, img:Image.Image): + """ + Args: + img (PIL Image): Image to be cropped and resized. + + Returns: + PIL Image: Randomly cropped and resized image. + """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) + i, j, h, w = self.get_params(img, self.scale, self.ratio) + return F_pil.resized_crop(img, i, j, h, w, self.size, self.interpolation) + + def __repr__(self): + interpolate_str = str(self.interpolation) + format_string = self.__class__.__name__ + '(size={0}'.format(self.size) + format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) + format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) + format_string += ', interpolation={0})'.format(interpolate_str) + return format_string + + +RandomSizedCrop = RandomResizedCrop + + +class FiveCrop(object): + """Crop the given PIL Image into four corners and the central crop + + .. Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an ``int`` + instead of sequence like (h, w), a square crop of size (size, size) is made. + + Example: + >>> transform = Compose([ + >>> FiveCrop(size), # this is a list of PIL Images + >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size): + self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + self.size = size + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + return F_pil.five_crop(img, self.size) + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + + +class TenCrop(object): + """Crop the given PIL Image into four corners and the central crop plus the flipped version of + these (horizontal flipping is used by default) + + .. Note:: + This transform returns a tuple of images and there may be a mismatch in the number of + inputs and targets your Dataset returns. See below for an example of how to deal with + this. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + vertical_flip (bool): Use vertical flipping instead of horizontal + + Example: + >>> transform = Compose([ + >>> TenCrop(size), # this is a list of PIL Images + >>> Lambda(lambda crops: torch.stack([ToTensor()(crop) for crop in crops])) # returns a 4D tensor + >>> ]) + >>> #In your test loop you can do the following: + >>> input, target = batch # input is a 5d tensor, target is 2d + >>> bs, ncrops, c, h, w = input.size() + >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops + >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops + """ + + def __init__(self, size, vertical_flip=False): + self.size = size + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + self.size = size + self.vertical_flip = vertical_flip + + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) + return F_pil.ten_crop(img, self.size, self.vertical_flip) + + def __repr__(self): + return self.__class__.__name__ + '(size={0}, vertical_flip={1})'.format(self.size, self.vertical_flip) + + +class RandomRotation(object): + """Rotate the image by angle. + + Args: + degrees (sequence or float or int): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). + resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + expand (bool, optional): Optional expansion flag. + If true, expands the output to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__(self, degrees, resample=False, expand=False, center=None, fill=None): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + if len(degrees) != 2: + raise ValueError("If degrees is a sequence, it must be of len 2.") + self.degrees = degrees + + self.resample = resample + self.expand = expand + self.center = center + self.fill = fill + + @staticmethod + def get_params(degrees): + """Get parameters for ``rotate`` for a random rotation. + + Returns: + sequence: params to be passed to ``rotate`` for random rotation. + """ + angle = random.uniform(degrees[0], degrees[1]) + + return angle + + def __call__(self, img:Image.Image): + """ + Args: + img (PIL Image): Image to be rotated. + + Returns: + PIL Image: Rotated image. + """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) + angle = self.get_params(self.degrees) + + return F_pil.rotate(img, angle, self.resample, self.expand, self.center, self.fill) + + def __repr__(self): + format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) + format_string += ', resample={0}'.format(self.resample) + format_string += ', expand={0}'.format(self.expand) + if self.center is not None: + format_string += ', center={0}'.format(self.center) + format_string += ')' + return format_string + + +class RandomAffine(object): + """Random affine transformation of the image keeping center invariant + + Args: + degrees (sequence or float or int): Range of degrees to select from. + If degrees is a number instead of sequence like (min, max), the range of degrees + will be (-degrees, +degrees). Set to 0 to deactivate rotations. + translate (tuple, optional): tuple of maximum absolute fraction for horizontal + and vertical translations. For example translate=(a, b), then horizontal shift + is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is + randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default. + scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is + randomly sampled from the range a <= scale <= b. Will keep original scale by default. + shear (sequence or float or int, optional): Range of degrees to select from. + If shear is a number, a shear parallel to the x axis in the range (-shear, +shear) + will be apllied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the + range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, + a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. + Will not apply shear by default + resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. + fillcolor (tuple or int): Optional fill color (Tuple for RGB Image And int for grayscale) for the area + outside the transform in the output image.(Pillow>=5.0.0) + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + + def __init__(self, degrees, translate=None, scale=None, shear=None, resample=False, fillcolor=0): + if isinstance(degrees, numbers.Number): + if degrees < 0: + raise ValueError("If degrees is a single number, it must be positive.") + self.degrees = (-degrees, degrees) + else: + assert isinstance(degrees, (tuple, list)) and len(degrees) == 2, \ + "degrees should be a list or tuple and it must be of length 2." + self.degrees = degrees + + if translate is not None: + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "translate should be a list or tuple and it must be of length 2." + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + + if scale is not None: + assert isinstance(scale, (tuple, list)) and len(scale) == 2, \ + "scale should be a list or tuple and it must be of length 2." + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + if isinstance(shear, numbers.Number): + if shear < 0: + raise ValueError("If shear is a single number, it must be positive.") + self.shear = (-shear, shear) + else: + assert isinstance(shear, (tuple, list)) and \ + (len(shear) == 2 or len(shear) == 4), \ + "shear should be a list or tuple and it must be of length 2 or 4." + # X-Axis shear with [min, max] + if len(shear) == 2: + self.shear = [shear[0], shear[1], 0., 0.] + elif len(shear) == 4: + self.shear = [s for s in shear] + else: + self.shear = shear + + self.resample = resample + self.fillcolor = fillcolor + + @staticmethod + def get_params(degrees, translate, scale_ranges, shears, img_size): + """Get parameters for affine transformation + + Returns: + sequence: params to be passed to the affine transformation + """ + angle = random.uniform(degrees[0], degrees[1]) + if translate is not None: + max_dx = translate[0] * img_size[0] + max_dy = translate[1] * img_size[1] + translations = (np.round(random.uniform(-max_dx, max_dx)), + np.round(random.uniform(-max_dy, max_dy))) + else: + translations = (0, 0) + + if scale_ranges is not None: + scale = random.uniform(scale_ranges[0], scale_ranges[1]) + else: + scale = 1.0 + + if shears is not None: + if len(shears) == 2: + shear = [random.uniform(shears[0], shears[1]), 0.] + elif len(shears) == 4: + shear = [random.uniform(shears[0], shears[1]), + random.uniform(shears[2], shears[3])] + else: + shear = 0.0 + + return angle, translations, scale, shear + + def __call__(self, img:Image.Image): + """ + img (PIL Image): Image to be transformed. + + Returns: + PIL Image: Affine transformed image. + """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) + ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) + return F_pil.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) + + def __repr__(self): + s = '{name}(degrees={degrees}' + if self.translate is not None: + s += ', translate={translate}' + if self.scale is not None: + s += ', scale={scale}' + if self.shear is not None: + s += ', shear={shear}' + if self.resample > 0: + s += ', resample={resample}' + if self.fillcolor != 0: + s += ', fillcolor={fillcolor}' + s += ')' + d = dict(self.__dict__) + d['resample'] = str(d['resample']) + return s.format(name=self.__class__.__name__, **d) diff --git a/python/jittor/transform/function_pil.py b/python/jittor/transform/function_pil.py new file mode 100644 index 00000000..b26f9551 --- /dev/null +++ b/python/jittor/transform/function_pil.py @@ -0,0 +1,649 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. +# All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# Contributors: +# Xin Yao +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from typing import Sequence +from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION +import numpy as np +import numbers +import math +from math import cos, sin, tan + + +def _is_pil_image(img): + return isinstance(img, Image.Image) + + +def _get_image_size(img): + if _is_pil_image(img): + return img.size + raise TypeError(f"Unexpected type {type(img)}") + + +def _get_image_num_channels(img): + if _is_pil_image(img): + return 1 if img.mode == 'L' else 3 + raise TypeError(f"Unexpected type {type(img)}") + + +def hflip(img): + """ + Function for horizontally flipping the given image. + + Args:: + + [in] img(PIL Image.Image): Input image. + + Example:: + + img = Image.open(...) + img_ = transform.hflip(img) + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + return img.transpose(Image.FLIP_LEFT_RIGHT) + + +def vflip(img): + """ + Function for vertically flipping the given image. + + Args:: + + [in] img(PIL Image.Image): Input image. + + Example:: + + img = Image.open(...) + img_ = transform.vflip(img) + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + return img.transpose(Image.FLIP_TOP_BOTTOM) + + +def adjust_brightness(img, brightness_factor): + """ + Function for adjusting brightness of an RGB image. + + Args:: + + [in] img (PIL Image.Image): Image to be adjusted. + [in] brightness_factor (float): How much to adjust the brightness. + Can be any non negative number. 0 gives a black image, 1 gives the + original image while 2 increases the brightness by a factor of 2. + + Returns:: + + [out] PIL Image.Image: Brightness adjusted image. + + Example:: + + img = Image.open(...) + img_ = transform.adjust_brightness(img, 0.5) + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + enhancer = ImageEnhance.Brightness(img) + img = enhancer.enhance(brightness_factor) + return img + + +def adjust_contrast(img, contrast_factor): + """ + Function for adjusting contrast of an image. + + Args:: + + [in] img (PIL Image.Image): Image to be adjusted. + [in] contrast_factor (float): How much to adjust the contrast. + Can be any non negative number. 0 gives a solid gray image, + 1 gives the original image while 2 increases the contrast by a factor of 2. + + Returns:: + + [out] PIL Image.Image: Contrast adjusted image. + + Example:: + + img = Image.open(...) + img_ = transform.adjust_contrast(img, 0.5) + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + enhancer = ImageEnhance.Contrast(img) + img = enhancer.enhance(contrast_factor) + return img + + +def adjust_saturation(img, saturation_factor): + """ + Function for adjusting saturation of an image. + + Args:: + + [in] img (PIL Image.Image): Image to be adjusted. + [in] saturation_factor (float): How much to adjust the saturation. + 0 will give a black and white image, 1 will give the original image + while 2 will enhance the saturation by a factor of 2. + + Returns:: + + [out] PIL Image.Image: Saturation adjusted image. + + Example:: + + img = Image.open(...) + img_ = transform.adjust_saturation(img, 0.5) + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + enhancer = ImageEnhance.Color(img) + img = enhancer.enhance(saturation_factor) + return img + + +def adjust_hue(img, hue_factor): + """ + Function for adjusting hue of an image. + + The image hue is adjusted by converting the image to HSV and + cyclically shifting the intensities in the hue channel (H). + The image is then converted back to original image mode. + + `hue_factor` is the amount of shift in H channel and must be in the + interval `[-0.5, 0.5]`. + + See `Hue`_ for more details. + + .. _Hue: https://en.wikipedia.org/wiki/Hue + + Args:: + + [in] img (PIL Image.Image): Image to be adjusted. + [in] hue_factor (float): How much to shift the hue channel. + Should be in [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of + hue channel in HSV space in positive and negative direction respectively. + 0 means no shift. Therefore, both -0.5 and 0.5 will give an image + with complementary colors while 0 gives the original image. + + Returns:: + + [out] PIL Image.Image: Saturation adjusted image. + + Example:: + + img = Image.open(...) + img_ = transform.adjust_hue(img, 0.1) + """ + if not(-0.5 <= hue_factor <= 0.5): + raise ValueError(f'hue_factor ({hue_factor}) is not in [-0.5, 0.5].') + + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + input_mode = img.mode + if input_mode in {'L', '1', 'I', 'F'}: + return img + + h, s, v = img.convert('HSV').split() + + np_h = np.array(h, dtype=np.uint8) + # uint8 addition take cares of rotation across boundaries + with np.errstate(over='ignore'): + np_h += np.uint8(hue_factor * 255) + h = Image.fromarray(np_h, 'L') + + img = Image.merge('HSV', (h, s, v)).convert(input_mode) + return img + + +def adjust_gamma(img, gamma, gain=1): + """ + Function for performing gamma correction on an image. + + Also known as Power Law Transform. Intensities in RGB mode are adjusted + based on the following equation: + + .. math:: + I_{\text{out}} = 255 \times \text{gain} \times \left(\frac{I_{\text{in}}}{255}\right)^{\gamma} + + See `Gamma Correction`_ for more details. + + .. _Gamma Correction: https://en.wikipedia.org/wiki/Gamma_correction + + Args:: + + [in] img (PIL Image.Image): Image to be adjusted. + [in] gamma (float): Non negative real number, same as :math:`\gamma` in the equation. + gamma larger than 1 make the shadows darker, + while gamma smaller than 1 make dark regions lighter. + [in] gain (float): The constant multiplier. + + Returns:: + + [out] PIL Image.Image: Gamma adjusted image. + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + if gamma < 0: + raise ValueError('Gamma should be a non-negative real number') + + input_mode = img.mode + img = img.convert('RGB') + gamma_map = [int((255 + 1 - 1e-3) * gain * pow(ele / 255., gamma)) for ele in range(256)] * 3 + img = img.point(gamma_map) # use PIL's point-function to accelerate this part + + img = img.convert(input_mode) + return img + + +def crop(img, top, left, height, width): + """ + Function for cropping image. + + Args:: + + [in] img(PIL Image.Image): Input image. + [in] top(int): the top boundary of the cropping box. + [in] left(int): the left boundary of the cropping box. + [in] height(int): height of the cropping box. + [in] width(int): width of the cropping box. + + Returns:: + + [out] PIL Image.Image: Cropped image. + + Example:: + + img = Image.open(...) + img_ = transform.crop(img, 10, 10, 100, 100) + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + return img.crop((left, top, left + width, top + height)) + + +def resize(img, size, interpolation=Image.BILINEAR): + """ + Function for resizing the input image to the given size. + + Args:: + + [in] img(PIL Image.Image): Input image. + [in] size(sequence or int): Desired output size. If size is a sequence like + (h, w), the output size will be matched to this. If size is an int, + the smaller edge of the image will be matched to this number maintaining + the aspect ratio. If a tuple or list of length 1 is provided, it is + interpreted as a single int. + [in] interpolation(int, optional): type of interpolation. default: PIL.Image.BILINEAR + + Returns:: + + [out] PIL Image.Image: Resized image. + + Example:: + + img = Image.open(...) + img_ = transform.resize(img, (100, 100)) + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))): + raise TypeError(f'Got inappropriate size arg: {size}') + + if isinstance(size, int) or len(size) == 1: + if isinstance(size, Sequence): + size = size[0] + w, h = img.size + if (w <= h and w == size) or (h <= w and h == size): + return img + if w < h: + ow = size + oh = int(size * h / w) + return img.resize((ow, oh), interpolation) + else: + oh = size + ow = int(size * w / h) + return img.resize((ow, oh), interpolation) + else: + return img.resize(size[::-1], interpolation) + + +def gray(img, num_output_channels): + """ + Function for converting PIL image of any mode (RGB, HSV, LAB, etc) to grayscale version of image. + + Args:: + + [in] img(PIL Image.Image): Input image. + [in] num_output_channels (int): number of channels of the output image. Value can be 1 or 3. Default, 1. + + Returns:: + + [out] PIL Image: Grayscale version of the image. + if num_output_channels = 1 : returned image is single channel + if num_output_channels = 3 : returned image is 3 channel with r = g = b + """ + if not _is_pil_image(img): + raise TypeError(f'img should be PIL Image. Got {type(img)}') + + if num_output_channels == 1: + img = img.convert('L') + elif num_output_channels == 3: + img = img.convert('L') + np_img = np.array(img, dtype=np.uint8) + np_img = np.dstack([np_img, np_img, np_img]) + img = Image.fromarray(np_img, 'RGB') + else: + raise ValueError('num_output_channels should be either 1 or 3') + + return img + +def _get_perspective_coeffs(startpoints, endpoints): + """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. + + In Perspective Transform each pixel (x, y) in the orignal image gets transformed as, + (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) ) + + Args: + List containing [top-left, top-right, bottom-right, bottom-left] of the orignal image, + List containing [top-left, top-right, bottom-right, bottom-left] of the transformed + image + Returns: + octuple (a, b, c, d, e, f, g, h) for transforming each pixel. + """ + matrix = [] + + for p1, p2 in zip(endpoints, startpoints): + matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]]) + matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]]) + + A = np.array(matrix, dtype="float") + B = np.array(startpoints, dtype="float").reshape(8) + res = np.linalg.lstsq(A, B, rcond=-1)[0] + return res.tolist() + + +def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC): + """Perform perspective transform of the given PIL Image. + + Args: + img (PIL Image): Image to be transformed. + startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the orignal image + endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image + interpolation: Default- Image.BICUBIC + Returns: + PIL Image: Perspectively transformed Image. + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + coeffs = _get_perspective_coeffs(startpoints, endpoints) + return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation) + + +def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR): + """Crop the given PIL Image and resize it to desired size. + + Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. + + Args: + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + size (sequence or int): Desired output size. Same semantics as ``resize``. + interpolation (int, optional): Desired interpolation. Default is + ``PIL.Image.BILINEAR``. + Returns: + PIL Image: Cropped image. + """ + assert _is_pil_image(img), 'img should be PIL Image' + img = crop(img, top, left, height, width) + img = resize(img, size, interpolation) + return img + +def center_crop(img, output_size): + """Crop the given PIL Image and resize it to desired size. + + Args: + img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image. + output_size (sequence or int): (height, width) of the crop box. If int, + it is used for both directions + Returns: + PIL Image: Cropped image. + """ + if isinstance(output_size, numbers.Number): + output_size = (int(output_size), int(output_size)) + image_width, image_height = img.size + crop_height, crop_width = output_size + crop_top = int(round((image_height - crop_height) / 2.)) + crop_left = int(round((image_width - crop_width) / 2.)) + return crop(img, crop_top, crop_left, crop_height, crop_width) + +def five_crop(img, size): + """Crop the given PIL Image into four corners and the central crop. + + .. Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + + Returns: + tuple: tuple (tl, tr, bl, br, center) + Corresponding top left, top right, bottom left, bottom right and center crop. + """ + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + + image_width, image_height = img.size + crop_height, crop_width = size + if crop_width > image_width or crop_height > image_height: + msg = "Requested crop size {} is bigger than input size {}" + raise ValueError(msg.format(size, (image_height, image_width))) + + tl = img.crop((0, 0, crop_width, crop_height)) + tr = img.crop((image_width - crop_width, 0, image_width, crop_height)) + bl = img.crop((0, image_height - crop_height, crop_width, image_height)) + br = img.crop((image_width - crop_width, image_height - crop_height, + image_width, image_height)) + center = center_crop(img, (crop_height, crop_width)) + return (tl, tr, bl, br, center) + +def ten_crop(img, size, vertical_flip=False): + r"""Crop the given PIL Image into four corners and the central crop plus the + flipped version of these (horizontal flipping is used by default). + + .. Note:: + This transform returns a tuple of images and there may be a + mismatch in the number of inputs and targets your ``Dataset`` returns. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + vertical_flip (bool): Use vertical flipping instead of horizontal + + Returns: + tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) + Corresponding top left, top right, bottom left, bottom right and center crop + and same for the flipped image. + """ + if isinstance(size, numbers.Number): + size = (int(size), int(size)) + else: + assert len(size) == 2, "Please provide only two dimensions (h, w) for size." + + first_five = five_crop(img, size) + + if vertical_flip: + img = vflip(img) + else: + img = hflip(img) + + second_five = five_crop(img, size) + return first_five + second_five + + +def rotate(img, angle, resample=False, expand=False, center=None, fill=None): + """Rotate the image by angle. + + + Args: + img (PIL Image): PIL Image to be rotated. + angle (float or int): In degrees degrees counter clockwise order. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + expand (bool, optional): Optional expansion flag. + If true, expands the output image to make it large enough to hold the entire rotated image. + If false or omitted, make the output image the same size as the input image. + Note that the expand flag assumes rotation around the center and no translation. + center (2-tuple, optional): Optional center of rotation. + Origin is the upper left corner. + Default is the center of the image. + fill (n-tuple or int or float): Pixel fill value for area outside the rotated + image. If int or float, the value is used for all bands respectively. + Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. + + .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters + + """ + def parse_fill(fill, num_bands): + if PILLOW_VERSION < "5.2.0": + if fill is None: + return {} + else: + msg = ("The option to fill background area of the rotated image, " + "requires pillow>=5.2.0") + raise RuntimeError(msg) + + if fill is None: + fill = 0 + if isinstance(fill, (int, float)) and num_bands > 1: + fill = tuple([fill] * num_bands) + if not isinstance(fill, (int, float)) and len(fill) != num_bands: + msg = ("The number of elements in 'fill' does not match the number of " + "bands of the image ({} != {})") + raise ValueError(msg.format(len(fill), num_bands)) + + return {"fillcolor": fill} + + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + opts = parse_fill(fill, len(img.getbands())) + + return img.rotate(angle, resample, expand, center, **opts) + + +def _get_inverse_affine_matrix(center, angle, translate, scale, shear): + # Helper method to compute inverse matrix for affine transformation + + # As it is explained in PIL.Image.rotate + # We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1 + # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1] + # C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1] + # RSS is rotation with scale and shear matrix + # RSS(a, s, (sx, sy)) = + # = R(a) * S(s) * SHy(sy) * SHx(sx) + # = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ] + # [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ] + # [ 0 , 0 , 1 ] + # + # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears: + # SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0] + # [0, 1 ] [-tan(s), 1] + # + # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1 + + if isinstance(shear, numbers.Number): + shear = [shear, 0] + + if not isinstance(shear, (tuple, list)) and len(shear) == 2: + raise ValueError( + "Shear should be a single value or a tuple/list containing " + + "two values. Got {}".format(shear)) + + rot = math.radians(angle) + sx, sy = [math.radians(s) for s in shear] + + cx, cy = center + tx, ty = translate + + # RSS without scaling + a = cos(rot - sy) / cos(sy) + b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot) + c = sin(rot - sy) / cos(sy) + d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot) + + # Inverted rotation matrix with scale and shear + # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1 + M = [d, -b, 0, + -c, a, 0] + M = [x / scale for x in M] + + # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1 + M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty) + M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty) + + # Apply center translation: C * RSS^-1 * C^-1 * T^-1 + M[2] += cx + M[5] += cy + return M + + +def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None): + """Apply affine transformation on the image keeping image center invariant + + Args: + img (PIL Image): PIL Image to be rotated. + angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction. + translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation) + scale (float): overall scale + shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction. + If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while + the second value corresponds to a shear parallel to the y axis. + resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + An optional resampling filter. + See `filters`_ for more information. + If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + """ + if not _is_pil_image(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + assert isinstance(translate, (tuple, list)) and len(translate) == 2, \ + "Argument translate should be a list or tuple of length 2" + + assert scale > 0.0, "Argument scale should be positive" + + output_size = img.size + center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5) + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + kwargs = {"fillcolor": fillcolor} if PILLOW_VERSION[0] >= '5' else {} + return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs) + diff --git a/python/jittor/utils/asm_tuner.py b/python/jittor/utils/asm_tuner.py new file mode 100755 index 00000000..dbece06d --- /dev/null +++ b/python/jittor/utils/asm_tuner.py @@ -0,0 +1,186 @@ +#!/usr/bin/python3 +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guowei Yang <471184555@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import sys +import os +os.environ["log_silent"] = "1" +import re +import jittor_utils as jit_utils +from jittor_utils import LOG +jit_utils.try_import_jit_utils_core(silent=True) + +def my_split(str): + res=[] + last=-1 + for i in range(len(str)): + if str[i]==" " or str[i]=="\t": + if last>>(out0_p); + ''').sync() + +avg_ns = float(rep[1][4]) / n +print("kernel launch overhead(ns):", avg_ns) diff --git a/python/jittor/utils/converter_server.py b/python/jittor/utils/converter_server.py new file mode 100644 index 00000000..7bc5b9ba --- /dev/null +++ b/python/jittor/utils/converter_server.py @@ -0,0 +1,27 @@ +from flask import Flask +from flask import request +from flask import jsonify +app = Flask(__name__) +import json + +from jittor.utils.pytorch_converter import convert + +@app.route('/', methods=["GET", "POST"]) +def hello(): + msg = request + data = msg.data.decode("utf-8") + try: + data = json.loads(data) + src = data["src"] + pjmap = json.loads(data["pjmap"]) + jt_src = convert(src, pjmap) + except Exception as e: + jt_src = str(e) + response = jsonify(jt_src=jt_src) + + # Enable Access-Control-Allow-Origin + response.headers.add("Access-Control-Allow-Origin", "*") + return response + +if __name__ == '__main__': + app.run(host="0.0.0.0") \ No newline at end of file diff --git a/python/jittor/utils/data.gz b/python/jittor/utils/data.gz new file mode 100644 index 00000000..d2974a4e Binary files /dev/null and b/python/jittor/utils/data.gz differ diff --git a/python/jittor/utils/dlink_compiler.py b/python/jittor/utils/dlink_compiler.py new file mode 100644 index 00000000..ba1b9ffb --- /dev/null +++ b/python/jittor/utils/dlink_compiler.py @@ -0,0 +1,26 @@ +import sys +import os +import re +cmds = sys.argv[1:] +def replace(cmds, s, t): + return [ c.replace(s,t) for c in cmds ] +def remove(cmds, ss): + rets = [] + for cmd in cmds: + found = True + for s in ss: + if s in cmd: + found = False + break + if found: + rets.append(cmd) + return rets + +cmds1 = remove(cmds, [".o"]) +cmds1 = replace(cmds1, ".so", ".o") +cmds2 = replace(cmds, "-dc", "") +cmds2 = replace(cmds2, ".cu", ".o") +ret = os.system(" ".join(cmds1).replace("-x cu", "")) +if ret: exit(ret) +ret = os.system(" ".join(cmds2).replace("-x cu", "")) +if ret: exit(ret) \ No newline at end of file diff --git a/python/jittor/utils/dumpdef.py b/python/jittor/utils/dumpdef.py new file mode 100644 index 00000000..2d34c9f5 --- /dev/null +++ b/python/jittor/utils/dumpdef.py @@ -0,0 +1,41 @@ +import os +import sys +import subprocess as sp + +def_path = sys.argv[-1] + +# print(sys.argv) +dumpbin_path = os.environ.get("dumpbin_path", "dumpbin") +export_all = os.environ.get("EXPORT_ALL", "0")=="1" + +syms = {} + +for obj in sys.argv[1:-2]: + cmd = f'"{dumpbin_path}" -SYMBOLS "{obj}"' + ret = sp.getoutput(cmd) + # print(ret) + for l in ret.splitlines(): + if '|' in l: + if "UNDEF" in l: continue + if "External" not in l: continue + sym = l.split('|')[1].strip().split()[0] + if sym[0] in '@.': continue + if sym.startswith("??$get_from_env"): syms[sym] = 1 + # if sym.startswith("??"): continue + if sym.startswith("my"): syms[sym] = 1 + # for cutt + if "custom_cuda" in sym: syms[sym] = 1 + if "cutt" in sym: syms[sym] = 1 + if "_cudaGetErrorEnum" in sym: syms[sym] = 1 + if export_all: syms[sym] = 1 + if "jittor" not in sym: continue + syms[sym] = 1 + # print(ret) +libname = os.path.basename(def_path).rsplit(".", 1)[0] +src = f"LIBRARY {libname}\nEXPORTS\n" +for k in syms: + src += f" {k}\n" +# print(src) + +with open(def_path, "w", encoding="utf8") as f: + f.write(src) diff --git a/python/jittor/utils/gen_pyi.py b/python/jittor/utils/gen_pyi.py new file mode 100644 index 00000000..73ff8869 --- /dev/null +++ b/python/jittor/utils/gen_pyi.py @@ -0,0 +1,263 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Zheng-Ning Liu +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +""" This file implements generation of stub files for Jittor C extensions. + +In detail, autocompletion of the following functions are supported. +- functions in __init__.py +- functions in jittor.core.ops +- attributes of jittor.flags +- methods of jittor.Var + +Prerequisite: +- mypy for automatic stub generation, installation: pip install mypy + +Usage: python3 -m jittor.utils.gen_pyi + +""" + +import os +import re +import shutil +import jittor + +def add_indent(s: str, n=1): + for _ in range(n): + s = '\t' + s.replace('\n', '\n\t', s.count('\n')-1) + return s + +def ctype_to_python(type_str): + if type_str == "bool": + return "bool" + if type_str in ["int", "uint", "uint8", "int64", "uint64", "size_t"]: + return "int" + if type_str in ["float32", "float64"]: + return "float" + if type_str in ["string", "string&&", "NanoString", "char*", "const char*"]: + return "str" + if type_str in ["vector"]: + return "List[int]" + if type_str in ["vector&&", "vector&&"]: + return "List[str]" + if type_str == "VarHolder*": + return "Var" + if type_str in ["vector", "vector&&"]: + return "List[Var]" + if type_str in ["vector_to_tuple"]: + return "Tuple[Var]" + if type_str == "NanoVector": + return "Tuple[int]" + if type_str == "vector&&": + return "List[Tuple[int]]" + if type_str in ["FetchFunc", "FetchFunc&&", "NumpyFunc&&"]: + return "Callable" + if type_str == "vector&&": + return "List[Callable]" + if type_str == "PyObject*": + return "float | int | numpy.ndarray | Var" + if type_str == "VarSlices&&": + return "slice" + if type_str in ["ArrayArgs", "ArrayArgs&&", "DataView"]: + return "numpy.ndarray" + if type_str == 'ItemData': + return "float | int | bool" + if type_str == "void": + return "" + print(f"[warning] Unknown ctype: {type_str}, do not write type hinting") + return "" + +def cval_to_python(val_str: str): + if val_str == "false": + return "False" + if val_str == "true": + return "True" + if val_str.startswith("ns_"): + return f'"{val_str[3:]}"' + if val_str == "NanoVector()": + return "()" + return val_str + + +def run_stubgen(jittor_path, cache_path): + + # for __init__.py functions + stubpath = os.path.join(cache_path, 'stubs') + stubfile = os.path.join(stubpath, "jittor", "__init__.pyi") + os.system(f"stubgen -m jittor -o {stubpath} -q") + with open(stubfile) as f: + mypy_content = f.read() + + f = open(stubfile, "w") + # Remove the follow type redirection + unused_content = ["ori_int = int\n", + "ori_float = float\n", + "ori_bool = bool\n", + "int = int32\n", + "float = float32\n", + "double = float64\n", + "\nflags: Any\n"] + for unused in unused_content: + mypy_content = mypy_content.replace(unused, "") + f.write(mypy_content) + + shutil.move(stubfile, os.path.join(jittor_path, "__init__.pyi")) + shutil.rmtree(stubpath) + shutil.rmtree(os.path.expanduser(".mypy_cache")) + +def gen_ops_stub(jittor_path): + f = open(os.path.join(jittor_path, "__init__.pyi"), "a") + f.write("from typing import List, Tuple, Callable, overload\n") + f.write("import numpy\n") + + var_hint = "class Var:\n\t'''Variable that stores multi-dimensional data.'''\n" + var_methods = set() + + def decl_to_param_hints(decl): + param_decl = re.findall(r".+ [a-zA-Z_0-9]+\((.*)\)", decl)[0] + if not param_decl.strip(): + return [] + param_hints = [] + for param_str in param_decl.split(','): + if "=" in param_str: + template = r"\s*(.+)\s+([a-zA-Z_0-9]+)\s*=\s*(.+)" + param_type, param_name, param_val = re.findall(template, param_str)[0] + param_type = ctype_to_python(param_type) + param_val = cval_to_python(param_val) + else: + param_type, param_name = param_str.strip().rsplit(' ', maxsplit=1) + param_type = ctype_to_python(param_type) + param_val = "" + + hint = param_name + if param_type: + hint += ": " + param_type + if param_val: + hint += "=" + param_val + param_hints.append(hint) + return param_hints + + def generate_var_hint(decorators, return_type, param_hints, docstring): + hint = add_indent(decorators) if decorators else "" + hint += f"\tdef {func_name}(" + hint += ", ".join(['self'] + param_hints) + ")" + hint += f"-> {return_type}" if return_type else "" + hint += ":" + if docstring: + hint += add_indent(f"\n'''{docstring}'''\n", 2) + "\t\t...\n" + else: + hint += f" ...\n" + return hint + + for func_name, func in jittor.ops.__dict__.items(): + if func_name.startswith("__"): + continue + # Exclude a function that overrides the builtin bool: + # def bool(x: Var) -> Var: ... + # It will confuse the IDE. So we ignore this function in pyi. + if func_name == "bool": + continue + + docstrings = [] + declarations = [] + for i, doc in enumerate(re.split(r"Declaration:\n(.+)\n", func.__doc__)): + if i % 2 == 0: + if not doc.strip() and docstrings: + # if the current docstring is empty, use the last docstring + docstrings.append(docstrings[-1]) + else: + docstrings.append(doc.replace("'''", '"""').strip()) + else: + declarations.append(doc) + + for i in range(len(declarations)): + decl = declarations[i] + docstring = docstrings[i] + + decorators = "@overload\n" if len(declarations) > 1 else "" + return_type = ctype_to_python(decl.split(' ', maxsplit=1)[0]) + param_hints = decl_to_param_hints(decl) + + func_text = decorators + func_text += f"def {func_name}" + func_text += "(" + ", ".join(param_hints) + ")" + func_text += f"-> {return_type}" if return_type else "" + func_text += ":\n" + if docstring: + func_text += add_indent(f"'''{docstring}'''\n") + "\t...\n" + else: + func_text += f" ...\n" + + f.write(func_text) + + if not "Var" in param_hints[0]: + continue + var_methods.add(func_name) + var_hint += generate_var_hint(decorators, return_type, param_hints[1:], docstring) + + for func_name, func in jittor.Var.__dict__.items(): + if func_name.startswith("__") or func_name in var_methods: + continue + if func_name in ["int", "float", "double", "bool", "long"]: + continue + if func.__doc__ is None: + continue + docstring = func.__doc__[:func.__doc__.find("Declaration:")] + docstring = docstring.replace("'''", '"""').strip() + declarations = re.findall(r"Declaration:\n(.+)\n", func.__doc__) + + for decl in declarations: + decl = decl.replace("inline ", "") + decorators = "@overload\n" if len(declarations) > 1 else "" + return_type = re.findall(r"(.+) [a-zA-Z_0-9]+\(.*\)", decl)[0].split()[-1] + return_type = ctype_to_python(return_type) + param_hints = decl_to_param_hints(decl) + + var_hint += generate_var_hint(decorators, return_type, param_hints, docstring) + + f.write(var_hint) + f.close() + +def gen_flags_stub(jittor_path): + f = open(os.path.join(jittor_path, "__init__.pyi"), "a") + f.write("class Flags:\n") + f.write("\t'''A set of flags to configure jittor running behaviors'''\n") + + for attr_name, attr in jittor.Flags.__dict__.items(): + if attr_name.startswith("__"): + continue + docstring = attr.__doc__ + docstring = attr.__doc__[:attr.__doc__.find("Declaration:")] + docbody = re.findall("\(type.+default.+\):(.+)", docstring)[0].strip() + docbody += "." if not docbody.endswith('.') else "" + attr_type, attr_val = re.findall(r"\(type:(.+), default:(.+)\)", docstring)[0] + attr_type = ctype_to_python(attr_type) + attr_type = attr_type if attr_type else "Any" + f.write(f"\t{attr_name}: {attr_type}\n") + f.write(f"\t'''{docbody} Default: {attr_val}'''\n") + + f.write("flags: Flags\n") + f.write("'''Jittor running time flags instance'''\n") + f.close() + +def get_pyi(jittor_path=None, cache_path=None): + if jittor_path is None: + jittor_path = jittor.flags.jittor_path + if cache_path is None: + import jittor_utils + cache_path = jittor_utils.cache_path + + run_stubgen(jittor_path, cache_path) + gen_ops_stub(jittor_path) + gen_flags_stub(jittor_path) + + print(f"Generated stubfile: {os.path.join(jittor_path, '__init__.pyi')}") + + +if __name__ == "__main__": + get_pyi() \ No newline at end of file diff --git a/python/jittor/utils/jtune.py b/python/jittor/utils/jtune.py new file mode 100755 index 00000000..add09922 --- /dev/null +++ b/python/jittor/utils/jtune.py @@ -0,0 +1,61 @@ +#!/usr/bin/python3 +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from ctypes import cdll +import sys + +lib_path = sys.argv[1] +cmd = sys.argv[2] +if not lib_path.endswith(".so"): + i = -1 + while lib_path[i] != '.': + i -= 1 + if i > -10: lib_path = lib_path[:i] + lib_path += ".so" + +if cmd == "run_so": + lib = cdll.LoadLibrary(lib_path) + lib.fake_main() + exit(0) + +with open(lib_path+".key") as f: + cpcmd = f.read().splitlines()[0] + +def run_cmd(cmd): + print("Run cmd:", cmd) + assert os.system(cmd) == 0, "Run cmd failed: "+cmd + +import os +if cmd == "cc_to_so": + run_cmd(cpcmd) + # remove hash info, force re-compile + with open(lib_path+'.key', 'w') as f: + f.write(cpcmd) +elif cmd == "cc_to_s": + asm_cmd = cpcmd.replace("_op.so", "_op.s") \ + .replace("-g", "") \ + .replace("-lstdc++", "") \ + .replace("-ldl", "") \ + .replace("-shared", "-S") + run_cmd(asm_cmd) +elif cmd == "s_to_so": + asm_cmd = cpcmd.replace("_op.cc", "_op.s") \ + .replace(" -g", "") + run_cmd(asm_cmd) + # remove hash info, force re-compile + with open(lib_path+'.key', 'w') as f: + f.write(cpcmd) +elif cmd == "perf_so": + perf_cmd = "perf record "+__file__+" "+lib_path+" run_so && perf annotate" + run_cmd(perf_cmd) +elif cmd == "vtune_so": + if os.path.isdir("./__res"): + run_cmd("rm -r ./__res") + vtune_cmd = "amplxe-cl -collect uarch-exploration -r ./__res "+__file__+" "+lib_path+" run_so" + run_cmd(vtune_cmd) +else: + assert 0, "unknown cmd: {cmd}".format(cmd) diff --git a/python/jittor/utils/local_doc_builder.py b/python/jittor/utils/local_doc_builder.py new file mode 100644 index 00000000..cbe0049d --- /dev/null +++ b/python/jittor/utils/local_doc_builder.py @@ -0,0 +1,72 @@ +# how to run: +# docker run -v "${HOME}/Documents/jittor-blog":/srv/jittor-blog -v /home/jittor/Documents/site:/mnt/jittor-blog -e LC_ALL=C.UTF-8 --rm jittor-blog-compiler bash -c "jekyll build --baseurl=JITTOR_BASEURL -d /mnt/jittor-blog/ && chmod -R 777 /mnt/jittor-blog" +# python /home/jittor/Documents/jittor-blog/local_doc_builder.py + +import os + +os.chdir("/home/jittor/Documents/site") + +def check(dirname, fname): + with open(os.path.join(dirname, fname), 'r') as f: + src = f.read() + ac = "JITTOR_BASEURL" + rep = ( + ("href=\"//", "href=\"http://"), + ("src=\"//", "src=\"http://"), + ('https://cg.cs.tsinghua.edu.cn/jittor', ac) + ) + found = False + for a,b in rep: + if a in src: + src = src.replace(a, b) + found = True + if ac not in src and not found: return + n = len(dirname.split(os.path.sep))-1 + s = '.' + '/..' * n + new_src = "" + i = -1 + print("="*20) + print(dirname, fname) + while True: + i += 1 + if i >= len(src): + break + if src[i] != 'J': + new_src += src[i] + continue + if src[i:i+len(ac)] != ac: + new_src += src[i] + continue + j = i + while j xx/xx/index.html + if y.endswith('/'): + y += 'index.html' + else: + z = y.split('/')[-1] + # replace xx/xx --> xx/xx/index.html + if '.' not in z: + y += '/index.html' + y += l + print("found", x, '-->', y) + new_src += y + i = j-1 + with open(os.path.join(dirname, fname), 'w') as f: + f.write(new_src) + +for r, _, f in os.walk('.'): + for fname in f: + ext = fname.split('.')[-1] + if ext not in ['html', 'css', 'js']: + continue + # print(r, fname) + check(r, fname) + diff --git a/python/jittor/utils/nvtx.py b/python/jittor/utils/nvtx.py new file mode 100644 index 00000000..77867ceb --- /dev/null +++ b/python/jittor/utils/nvtx.py @@ -0,0 +1,45 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. +# All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import os +import ctypes + +nvtx_lib_path = os.path.join(jt.compiler.cuda_lib, "libnvToolsExt.so") +nvtx_lib = ctypes.CDLL(nvtx_lib_path, jt.compiler.dlopen_flags) + +nvtxRangePushA = nvtx_lib.nvtxRangePushA +nvtxRangePop = nvtx_lib.nvtxRangePop + +class nvtx_scope: + ''' + Add a mark in nvprof timeline + + Example:: + + from jittor.utils.nvtx import nvtx_scope + with nvtx_scope("model"): + ... + + ''' + def __init__(self, name): + self.name = bytes(name, 'utf8') + + def __enter__(self): + nvtxRangePushA(self.name) + + def __exit__(self, *exc): + nvtxRangePop() + + def __call__(self, func): + def inner(*args, **kw): + with self: + ret = func(*args, **kw) + return ret + return inner diff --git a/python/jittor/utils/polish.py b/python/jittor/utils/polish.py new file mode 100644 index 00000000..301eacf5 --- /dev/null +++ b/python/jittor/utils/polish.py @@ -0,0 +1,131 @@ +#!/usr/bin/python3 +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +# Polish steps: +# 1. create jittor-polish repo +# 2. copy jittor src into it +# 3. remove files +# 4. commit jittor-polish(check modify and break) +# 5. compile to build/$git_version/$cc_type/$use_cuda/a.obj +# 6. rsync to build-server +# 7. push to github +# 8. push to pip + +import os +import jittor as jt +from jittor import LOG +from jittor.compiler import run_cmd +from jittor_utils import translator +from jittor.utils.polish_centos import run_in_centos +import sys +import platform + +jittor_path = jt.flags.jittor_path +root_path = os.path.realpath(os.path.join(jt.flags.jittor_path, "..", "..")) +data_path = os.path.join(jittor_path, "src", "__data__") +build_path = os.path.join(data_path, "build") +if not os.path.isdir(build_path): + os.mkdir(build_path) +status = run_cmd("git status", data_path) +print(status) +if "working tree clean" not in status: + LOG.f("__data__ has untracked files") + +git_version = run_cmd("git rev-parse HEAD", data_path) +LOG.i("git_version", git_version) + +run_cmd(f"git rev-parse HEAD > {jittor_path}/version", data_path) + +# remove files +files = jt.compiler.files +data_files = [ name for name in files + if "__data__" in name +] +LOG.i("data_files", data_files) + +# compile data files +import jittor_utils as jit_utils +home = jit_utils.home() +# for cc_type in ["g++", "clang"]: +# for device in ["cpu", "cuda"]: + +os_name_system_dict = { + 'ubuntu': 'Linux', + 'centos': 'Linux', + 'macos': 'Darwin', +} + +if len(sys.argv) > 1 and sys.argv[1] == "native": + os_name_system_dict = {'ubuntu': 'Linux'} + +for os_name, os_type in os_name_system_dict.items(): + if platform.system() != os_type: + continue + os_arch = platform.machine() if os_type == 'Darwin' else '' + + for cc_type in ["g++"]: + for device in ["cpu"]: + key = f"{git_version}-{cc_type}-{device}" + env = f"cache_name=build/{cc_type}/{device} cc_path=" + cname = "g++" if cc_type=="g++" else "clang-8" + env += cname + # use core2 arch, avoid using avx instructions + # TODO: support more archs, such as arm, or use ir(GIMPLE or LLVM) + if platform.machine() in ["x86_64", "AMD64"]: + env += " cc_flags='-march=core2' " + if device == "cpu": + env += " nvcc_path='' " + elif jt.flags.nvcc_path == "": + env = "unset nvcc_path && " + env + cmd = f"{env} {sys.executable} -c 'import jittor'" + if key != 'ubuntu': key += '-' + os_name + if os_arch : key += '-' + os_arch + if platform.machine() == "sw_64": + key += '-sw_64' + if os_name == 'centos': + run_in_centos(env) + obj_path = home + f"/.cache/centos/build/{cc_type}/{device}/{cname}/obj_files" + else: + LOG.i("run cmd:", cmd) + os.system(cmd) + LOG.i("run cmd:", cmd) + os.system(cmd) + obj_path = home + f"/.cache/jittor/build/{cc_type}/{device}/{cname}/obj_files" + + obj_files = [] + for name in data_files: + name = os.path.basename(name) + fname = f"{obj_path}/{name}.o" + assert os.path.isfile(fname), fname + obj_files.append(fname) + ld_cmd = f"ld -r {' '.join(obj_files)} -o {build_path}/{key}.o" + print("RUN CMD:", ld_cmd) + run_cmd(ld_cmd) + +if len(sys.argv) > 1 and sys.argv[1] == "native": + exit(0) + +# compress source +# tar -cvzf build/jittor.tgz . --exclude build --exclude .git --exclude .ipynb_checkpoints --exclude __pycache__ +# mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor +assert os.system(f"cd {root_path} && tar --exclude=build --exclude=.git --exclude=.ipynb_checkpoints --exclude=__pycache__ --exclude=__data__ --exclude=my --exclude=dist --exclude=.vscode --exclude=.github -cvzf {build_path}/jittor.tgz * ")==0 + +# rsync to build-server +jittor_web_base_dir = "Documents/jittor-blog/assets/" +jittor_web_build_dir = jittor_web_base_dir +# copy to jittor-web:Documents/jittor-blog/assets/build/ +assert os.system(f"rsync -avPu {build_path} jittor-web:{jittor_web_build_dir}")==0 +assert os.system(f"ssh jittor-web Documents/jittor-blog.git/hooks/post-update")==0 + + +# sys.exit(0) + +# push to github +# assert os.system(f"cd {polish_path} && git push -f origin master")==0 + +# push to pip \ No newline at end of file diff --git a/python/jittor/utils/polish_centos.py b/python/jittor/utils/polish_centos.py new file mode 100644 index 00000000..db322065 --- /dev/null +++ b/python/jittor/utils/polish_centos.py @@ -0,0 +1,60 @@ +#!/usr/bin/python3 +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt +import os +import jittor_utils as jit_utils +home_path = jit_utils.home() + +def run_cmd(cmd): + print("RUN CMD:", cmd) + assert os.system(cmd) == 0 + +# ubuntu +# debian9 +# centos +# redhat8: docker run -it --name test registry.access.redhat.com/ubi8/ubi:8.1 bash +# fedora: docker run -it fedora +# archlinux: ??? +# pacman fix: https://github.com/qutebrowser/qutebrowser/commit/478e4de7bd1f26bebdcdc166d5369b2b5142c3e2 +# manjaro: ??? + +def run_in_centos(env): + dockerfile_src = r""" + FROM centos:7 + + WORKDIR /root + + # install python + RUN yum install gcc openssl-devel bzip2-devel libffi-devel zlib-devel wget -y + RUN wget https://www.python.org/ftp/python/3.8.3/Python-3.8.3.tgz + RUN tar xzf Python-3.8.3.tgz + RUN yum install make -y + RUN cd Python-3.8.3 && ./configure --enable-optimizations && make altinstall -j8 + + # install g++-7 + # or yum install gcc-g++ + RUN yum install centos-release-scl -y + RUN yum install devtoolset-7-gcc-c++ -y + RUN yum install which -y + RUN scl enable devtoolset-7 'g++ --version' + RUN python3.8 -m pip install numpy tqdm pillow astunparse + """ + + with open("/tmp/centos_build_env", 'w') as f: + f.write(dockerfile_src) + + + centos_path = os.path.join(home_path, ".cache", "centos") + os.makedirs(centos_path+"/src/jittor", exist_ok=True) + os.makedirs(centos_path+"/src/jittor_utils", exist_ok=True) + os.system(f"sudo cp -rL {jt.flags.jittor_path} {centos_path+'/src/'}") + os.system(f"sudo cp -rL {jt.flags.jittor_path}/../jittor_utils {centos_path+'/src/'}") + + run_cmd(f"sudo docker build --tag centos_build_env -f /tmp/centos_build_env .") + run_cmd(f"sudo docker run --rm -v {centos_path}:/root/.cache/jittor centos_build_env scl enable devtoolset-7 'PYTHONPATH=/root/.cache/jittor/src {env} python3.8 -m jittor.test.test_core'") + run_cmd(f"sudo docker run --rm -v {centos_path}:/root/.cache/jittor centos_build_env scl enable devtoolset-7 'PYTHONPATH=/root/.cache/jittor/src {env} python3.8 -m jittor.test.test_core'") \ No newline at end of file diff --git a/python/jittor/utils/publish.py b/python/jittor/utils/publish.py new file mode 100644 index 00000000..6dc24356 --- /dev/null +++ b/python/jittor/utils/publish.py @@ -0,0 +1,52 @@ +#!/usr/bin/python3 +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +# Publish steps: +# 1. build,push,upload docker image[jittor/jittor] +# 2. build,push,upload docker image[jittor/jittor-cuda] +# upload to pip: +# rm -rf dist && python3.7 ./setup.py sdist && python3.7 -m twine upload dist/* +import os + +def run_cmd(cmd): + print("[run cmd]", cmd) + assert os.system(cmd) == 0 + +def upload_file(path): + run_cmd(f"rsync -avPu {path} jittor-web:Documents/jittor-blog/assets/build/") + +def docker_task(name, build_cmd): + run_cmd(build_cmd) + run_cmd(f"sudo docker push {name}") + bname = os.path.basename(name) + run_cmd(f"sudo docker save {name}:latest -o /tmp/{bname}.tgz && sudo chmod 666 /tmp/{bname}.tgz") + upload_file(f"/tmp/{bname}.tgz") + +docker_task( + "jittor/jittor-cuda-11-1", + "sudo docker build --tag jittor/jittor-cuda-11-1:latest -f script/Dockerfile_cuda11 . --network host" +) + +docker_task( + "jittor/jittor", + "sudo docker build --tag jittor/jittor:latest . --network host" +) + +docker_task( + "jittor/jittor-cuda", + "sudo docker build --tag jittor/jittor-cuda:latest --build-arg FROM_IMAGE='nvidia/cuda:10.2-cudnn7-devel-ubuntu18.04' . --network host" +) + +docker_task( + "jittor/jittor-cuda-10-1", + "sudo docker build --tag jittor/jittor-cuda-10-1:latest --build-arg FROM_IMAGE='nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04' . --network host" +) + +run_cmd("ssh jittor-web Documents/jittor-blog.git/hooks/post-update") \ No newline at end of file diff --git a/python/jittor/utils/pytorch_converter.py b/python/jittor/utils/pytorch_converter.py new file mode 100644 index 00000000..2d14d57a --- /dev/null +++ b/python/jittor/utils/pytorch_converter.py @@ -0,0 +1,718 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Wenyang Zhou <576825820@qq.com> +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import ast, astunparse +import numpy as np + +pjmap = { + # *************************************************************** + # Module + # *************************************************************** + 'Conv2d': { + 'pytorch': { + 'args': "in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'" + }, + 'jittor': { + 'module': 'nn', + 'name': 'Conv', + 'args': 'in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True' + }, + 'links': {}, + 'extras': {}, + }, + 'ConvTranspose2d': { + 'pytorch': { + 'args': "in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros'" + }, + 'jittor': { + 'module': 'nn', + 'name': 'ConvTranspose', + 'args': 'in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1' + }, + 'links': {}, + 'extras': {}, + }, + 'MaxPool2d': { + 'pytorch': { + 'args': 'kernel_size, stride=None, padding=0, dilation=1, return_indices=False', + }, + 'jittor': { + 'module': 'nn', + 'name': 'Pool', + 'args': 'kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"' + }, + 'links': {}, + 'extras': { + "op": "'maximum'", + }, + }, + 'AvgPool2d': { + 'pytorch': { + 'args': 'kernel_size, stride=None, padding=0, dilation=1, return_indices=False', + }, + 'jittor': { + 'module': 'nn', + 'name': 'Pool', + 'args': 'kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"' + }, + 'links': {}, + 'extras': { + "op": "'mean'", + }, + }, + 'ReLU': { + 'pytorch': { + 'args': 'inplace=False', + }, + 'jittor': { + 'module': 'nn', + 'name': 'ReLU', + 'args': '' + }, + 'links': {}, + 'extras': {}, + 'delete': ['inplace'], + }, + 'relu': { + 'pytorch': { + 'args': 'input', + }, + 'jittor': { + 'module': 'nn', + 'name': 'relu', + 'args': 'x' + }, + 'links': {'input': 'x'}, + 'extras': {}, + 'delete': [], + }, + 'binary_cross_entropy_with_logits': { + 'pytorch': { + 'args': 'input, target, weight, size_average=True', + }, + 'jittor': { + 'module': 'nn', + 'name': 'binary_cross_entropy_with_logits', + 'args': 'input, target, weight, size_average=True' + }, + 'links': {}, + 'extras': {}, + 'delete': [], + }, + 'ReLU6': { + 'pytorch': { + 'args': 'inplace=False', + }, + 'jittor': { + 'module': 'nn', + 'name': 'ReLU6', + 'args': '' + }, + 'links': {}, + 'extras': {}, + 'delete': ['inplace'], + }, + 'PReLU': { + 'pytorch': { + 'args': 'num_parameters=1, init=0.25', + }, + 'jittor': { + 'module': 'nn', + 'name': 'PReLU', + 'args': 'num_parameters=1, init_=0.25' + }, + 'links': {'init': 'init_'}, + 'extras': {}, + }, + 'LeakyReLU': { + 'pytorch': { + 'args': 'negative_slope=0.01, inplace=False', + }, + 'jittor': { + 'module': 'nn', + 'name': 'LeakyReLU', + 'args': 'scale=0.01' + }, + 'links': {'negative_slope': 'scale'}, + 'extras': {}, + 'delete': ['inplace'], + }, + 'BatchNorm2d': { + 'pytorch': { + 'args': 'num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True', + }, + 'jittor': { + 'module': 'nn', + 'name': 'BatchNorm', + 'args': 'num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True' + }, + 'links': {}, + 'extras': {}, + }, + 'BatchNorm1d': { + 'pytorch': { + 'args': "num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True" + }, + 'jittor': { + 'module': 'nn', + 'name': 'BatchNorm1d', + 'args': 'num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True', + }, + 'links': {}, + 'extras': {'affine': 'None'}, + 'delete': ['track_running_stats'], + }, + 'GroupNorm': { + 'pytorch': { + 'args': "num_groups, num_channels, eps=1e-05, momentum=0.1, affine=True" + }, + 'jittor': { + 'module': 'nn', + 'name': 'GroupNorm', + 'args': 'num_groups, num_channels, eps=1e-05, affine=None, is_train=True', + }, + 'links': {}, + 'extras': {'affine': 'None'}, + }, + 'Parameter':{ + 'pytorch': { + 'args': "data,require_grad=True" + }, + 'jittor': { + 'module': 'jt', + 'name': 'array', + 'args': 'data,dtype=None', + }, + 'links': {}, + 'extras': {}, + }, + 'Dropout2d': { + 'pytorch': { + 'args': 'p=0.5, inplace=False', + }, + 'jittor': { + 'module': 'nn', + 'name': 'Dropout', + 'args': 'p=0.5, is_train=False' + }, + 'links': {}, + 'extras': {}, + 'delete': ['inplace'], + }, + 'Upsample': { + 'pytorch': { + 'args': "size=None, scale_factor=None, mode='nearest', align_corners=None", + }, + 'jittor': { + 'module': 'nn', + 'name': 'Upsample', + 'args': "scale_factor=None, mode='nearest'" + }, + 'links': {}, + 'extras': {}, + }, + 'constant_': { + 'pytorch': { + 'args': "tensor, val", + }, + 'jittor': { + 'module': 'init', + 'name': 'constant_', + 'args': 'var, value=0.0' + }, + 'links': {'tensor': 'var', 'val': 'value'}, + 'extras': {}, + }, + 'normal_': { + 'pytorch': { + 'args': "tensor, mean=0.0, std=1.0", + }, + 'jittor': { + 'module': 'init', + 'name': 'gauss_', + 'args': 'var, mean=0.0, std=1.0' + }, + 'links': {'tensor': 'var'}, + 'extras': {}, + }, + 'uniform_': { + 'pytorch': { + 'args': "tensor, a=0.0, b=1.0", + }, + 'jittor': { + 'module': 'init', + 'name': 'uniform_', + 'args': 'var, low, high' + }, + 'links': {'tensor': 'var', 'a': 'low', 'b': 'high'}, + 'extras': {}, + }, + 'cat': { + 'pytorch': { + 'args': "tensors, dim=0, out=None", + }, + 'jittor': { + 'module': 'jt.contrib', + 'name': 'concat', + 'args': 'vars, dim=0' + }, + 'links': {'tensors': 'vars'}, + 'extras': {}, + }, + # *************************************************************** + # Convert format for function which can be writen as either torch.Tensor.xxx(...) or torch.xxx(torch.Tensor, ...) + # Example: x.reshape([2,3]) and torch.reshape(x, [2,3]) + # *************************************************************** + 'flatten': { + 'pytorch': { + 'prefix': ['torch'], + 'args_prefix': 'input, start_dim=0, end_dim=-1', + 'args': 'start_dim=0, end_dim=-1', + }, + 'jittor': { + 'prefix': 'jt', + 'module': '', + 'name': 'flatten', + 'args_prefix': 'input, start_dim=0, end_dim=-1', + 'args': 'start_dim=0, end_dim=-1' + }, + 'links': {}, + 'extras': {}, + }, + 'reshape': { + 'pytorch': { + 'prefix': ['torch'], + 'args_prefix': 'input, shape', + 'args': 'shape', + }, + 'jittor': { + 'prefix': 'jt', + 'module': '', + 'name': 'reshape', + 'args_prefix': 'input, shape', + 'args': 'shape' + }, + 'links': {}, + 'extras': {}, + }, + 'clamp': { + 'pytorch': { + 'prefix': ['torch'], + 'args_prefix': 'input, min, max, out=None', + 'args': 'min, max, out=None', + }, + 'jittor': { + 'prefix': 'jt', + 'module': '', + 'name': 'clamp', + 'args_prefix': 'x, min_v, max_v', + 'args': 'min_v, max_v' + }, + 'links': {'min': 'min_v', 'max': 'max_v'}, + 'extras': {}, + 'delete': ['out'], + }, + 'permute': { + 'pytorch': { + 'prefix': [], + 'args_prefix': '', + 'args': '*dim', + }, + 'jittor': { + 'prefix': '', + 'module': '', + 'name': 'permute', + 'args_prefix': '', + 'args': '*dim' + }, + 'links': {}, + 'extras': {}, + }, + 'view': { + 'pytorch': { + 'prefix': [], + 'args_prefix': '', + 'args': '*shape', + }, + 'jittor': { + 'prefix': '', + 'module': '', + 'name': 'view', + 'args_prefix': '', + 'args': '*shape' + }, + 'links': {}, + 'extras': {}, + } +} + +unsupport_ops = [ + # *************************************************************** + # torch.nn + # *************************************************************** + 'ModuleDict', 'ParameterList', 'ParameterDict', + 'Conv1d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose3d', 'Unfold', 'Fold', + 'MaxPool1d', 'MaxUnpool1d', 'MaxUnpool2d', 'AvgPool1d', + 'FractionalMaxPool2d', 'LPPool1d', 'LPPool2d', 'AdaptiveMaxPool1d', + 'AdaptiveAvgPool1d', + 'ReflectionPad1d', 'ReplicationPad1d', 'ReplicationPad3d', 'ConstantPad1d', 'ConstantPad3d', + 'ELU', 'Hardshrink', 'Hardtanh', 'LogSigmoid', 'MultiheadAttention', + 'RReLU', 'SELU', 'CELU', 'Softshrink', 'Softsign', 'Tanhshrink', + 'Threshold', 'Softmin', 'Softmax2d', 'LogSoftmax', 'AdaptiveLogSoftmaxWithLoss', + 'BatchNorm3d', 'SyncBatchNorm', 'InstanceNorm1d', 'InstanceNorm3d', 'LocalResponseNorm', + # 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCell', 'LSTMCell', 'GRUCell', 'Transformer', 'TransformerEncoder', + 'TransformerDecoder', 'TransformerEncoderLayer', 'TransformerDecoderLayer', # 'Identity', 'Bilinear', + 'Dropout3d', 'AlphaDropout', 'EmbeddingBag', 'CosineSimilarity', 'PairwiseDistance', 'CTCLoss', 'NLLLoss', 'PoissonNLLLoss', 'KLDivLoss', 'BCEWithLogitsLoss', + 'MarginRankingLoss', 'HingeEmbeddingLoss', 'MultiLabelMarginLoss', 'SmoothL1Loss', 'SoftMarginLoss', + 'MultiLabelSoftMarginLoss', 'CosineEmbeddingLoss', 'MultiMarginLoss', 'TripletMarginLoss', # 'DataParallel', 'DistributedDataParallel', + 'clip_grad_norm_', 'clip_grad_value_', + 'parameters_to_vector', 'vector_to_parameters', 'BasePruningMethod', 'PruningContainer', + 'RandomUnstructured', 'L1Unstructured', 'RandomStructured', 'LnStructured', 'CustomFromMask', + 'random_unstructured', 'l1_unstructured', 'random_structured', 'ln_structured', 'global_unstructured', + 'custom_from_mask', 'remove', 'is_pruned', 'weight_norm', 'remove_weight_norm', 'spectral_norm', + 'remove_spectral_norm', 'PackedSequence', 'pack_padded_sequence', 'pad_packed_sequence', 'pad_sequence', 'pack_sequence' +] + +def pjmap_append(pytorch_func_name, pytorch_args, jittor_func_module, jittor_func_name, jittor_args, extras=None, links=None, delete=None): + ''' adding map to pjmap for converting new function, example: convert AvgPool2d to Pool + args: + * `pytorch_func_name`: Pytorch function name + * `pytorch_args`: Pytorch parameter list + * `jittor_func_module`: to which module the Jittor function belongs + * `jittor_func_name`: Jittor function name + * `jittor_args`: Jittor parameter list + * `extras`: parameter assignment + * `links`: connection parameters + * `delete`: delete parameters + + example: + from jittor.utils.pytorch_converter import pjmap_append + pjmap_append(pytorch_func_name='AvgPool2d', + pytorch_args='kernel_size, stride=None, padding=0, dilation=1, return_indices=False', + jittor_func_module='nn', + jittor_func_name='Pool', + jittor_args='kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False, op="maximum"', + extras={"op": "'mean'"}) + ''' + if links == None: links = {} + if extras == None: extras = {} + if delete == None: delete = [] + assert isinstance(links, dict) + assert isinstance(extras, dict) + assert isinstance(delete, list) + pjmap[pytorch_func_name] = { + 'pytorch': { + 'args': pytorch_args, + }, + 'jittor': { + 'module': jittor_func_module, + 'name': jittor_func_name, + 'args': jittor_args, + }, + 'links': links, + 'extras': extras, + 'delete': delete, + } + + +def raise_unsupport(name, ori_src): + ret = f"raise RuntimeError('''original source: <{ori_src.strip()}>, {name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.''')" + print(ret+'\n') + ret = ast.parse(ret).body[0] + return ret + +class Converter: + def __init__(self, ex_pjmap): + import copy + self.pjmap = copy.deepcopy(pjmap) + if ex_pjmap: + self.pjmap.update(ex_pjmap) + self.unsupport_ops = set(unsupport_ops) + support_ops = {} + for key in self.pjmap.keys(): + module = self.pjmap[key]['jittor']['module'] + name = self.pjmap[key]['jittor']['name'] + if module == 'nn': + support_ops[key] = name + if key in self.unsupport_ops: + self.unsupport_ops.remove(key) + self.support_ops = support_ops + self.import_flag = [] + + def replace(self, a): + if hasattr(a, "attr") and a.attr in self.unsupport_ops: + ori_src = astunparse.unparse(a) + return raise_unsupport(a.attr, ori_src) + + if hasattr(a, "id") and a.id in self.unsupport_ops: + ori_src = astunparse.unparse(a) + return raise_unsupport(a.id, ori_src) + + if hasattr(a, "attr"): + if a.attr in self.support_ops.keys(): a.attr = self.support_ops[a.attr] + + if hasattr(a, "id"): + if a.id in self.support_ops.keys(): a.id = self.support_ops[a.id] + + return None + + def convert_(self, prefix, func_name, ags, kws, ori_src): + info = self.pjmap[func_name] + p_prefix = info['pytorch']['prefix'] if 'prefix' in info['pytorch'].keys() else None + if p_prefix is not None and prefix in p_prefix: + p_ags = info['pytorch']['args_prefix'] + j_ags = info['jittor']['args_prefix'] + else: + p_ags = info['pytorch']['args'] + j_ags = info['jittor']['args'] + if 'delete' in info.keys(): + delete = info['delete'] + else: + delete = None + j_prefix = info['jittor']['prefix'] if 'prefix' in info['jittor'].keys() else None + j_module = info['jittor']['module'] + j_name = info['jittor']['name'] + links = info['links'] + extras = info['extras'] + jj_ags = [] + jj_kws = {} + pp_ags = [] + pp_kws = {} + if j_ags == '' and p_ags == '': + # no args in Pytorch and Jittor. + if p_prefix is None: + return f"{j_module}.{j_name}()" + else: + if prefix in p_prefix: + return f"{j_prefix}.{j_name}()" + else: + return f"{prefix}.{j_name}()" + else: + j_ags = j_ags.replace(' ','').split(',') + for j_ag in j_ags: + if '=' in j_ag: + k,v = j_ag.split('=') + jj_kws[k] = v + else: + jj_ags.append(j_ag) + p_ags = p_ags.replace(' ','').split(',') + for p_ag in p_ags: + if '=' in p_ag: + k,v = p_ag.split('=') + pp_kws[k] = v + else: + pp_ags.append(p_ag) + if len(jj_ags) == 0 and len(pp_ags) != 0: + return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {pp_ags[0]}''')" + # raise AttributeError(f"{func_name} in Jittor has no Attribute {pp_ags[0]}") + if delete is not None: + for d in delete: + if d in pp_ags: + jj_ags.append(d) + if d in pp_kws.keys(): + jj_kws[d] = None + if len(pp_ags) > len(ags) + len(kws): + return f"raise RuntimeError('''origin source: <{ori_src.strip()}>, There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}''')" + # raise RuntimeError(f'There are needed {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you only provide {len(ags) + len(kws)}') + ags_ = [] + for i in range(len(pp_ags)): + if i < len(ags): + if '*' in pp_ags[i]: + ags_.append('(' + ', '.join(ags[i:]) + ')') + ags = ags_ + break + else: + ags_.append(ags[i]) + else: + break + if len(pp_ags) + len(list(pp_kws.keys())) < len(ags) + len(kws): + return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}''')" + # raise RuntimeError(f'There are only {len(pp_ags) + len(list(pp_kws.keys()))} args in Pytorch {func_name} function, but you provide {len(ags) + len(kws)}') + j_ags_flag = np.zeros(len(jj_ags)) + j_ags_values = {} + j_kws_values = {} + for i,ag in enumerate(ags): + if len(pp_ags) == 0: + ag_name = list(pp_kws.keys())[i] + elif i < len(pp_ags): + ag_name = pp_ags[i] + elif i >= len(pp_ags) and (i-len(pp_ags)) <= len(list(pp_kws.keys())): + ag_name = list(pp_kws.keys())[i-len(pp_ags)] + else: + return f"raise RuntimeError('''origin source: <{ori_src.strip()}>,The args number is not matc{func_name} in Jittor has no Attribute {ag_name}''')" + # raise RuntimeError(f'The args number is not matc{func_name} in Jittor has no Attribute {ag_name}') + if ag_name in links.keys(): + ag_name = links[ag_name] + if ag_name in jj_ags: + j_ags_flag[jj_ags.index(ag_name)] = 1 + j_ags_values[str(jj_ags.index(ag_name))] = ag + elif ag_name in jj_kws.keys(): + j_kws_values[ag_name] = ag + else: + return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {ag_name}''')" + # raise AttributeError(f'{func_name} in Jittor has no Attribute {ag_name}') + for i,kw in enumerate(kws): + kw_name, kw_value = kw.split('=') + if kw_name in links.keys(): + kw_name = links[kw_name] + if kw_name in jj_ags: + j_ags_flag[jj_ags.index(kw_name)] = 1 + j_ags_values[str(jj_ags.index(kw_name))] = kw_value + elif kw_name in jj_kws.keys(): + j_kws_values[kw_name] = kw_value + else: + return f"raise AttributeError('''origin source: <{ori_src.strip()}>, {func_name} in Jittor has no Attribute {kw_name}''')" + # raise AttributeError(f'{func_name} in Jittor has no Attribute {kw_name}') + len_jj_ags = len(jj_ags) if len(jj_ags) == 0 or jj_ags[0] != '' else 0 + if j_ags_flag.sum() < len_jj_ags: + missing_args = [] + for i in range(len(jj_ags)): + if j_ags_flag[i] == 0: + missing_args.append(jj_ags[i]) + return f"raise AttributeError('''origin source: <{ori_src.strip()}>, the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.''')" + # raise AttributeError(f"the needed args of {func_name} in Jittor is {', '.join(jj_ags)}, so you need to give value of {', '.join(missing_args)}.") + if extras: + for k in extras.keys(): + if k in jj_ags: + j_ags_values[str(jj_ags.index(k))] = extras[k] + elif k in jj_kws.keys(): + j_kws_values[k] = extras[k] + else: + return f"raise AttributeError('''origin source: <{ori_src.strip()}>, there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.''')" + # raise AttributeError(f"there is not attribute named {k} in Jittor {func_name}, you should delete it in {func_name} extras.") + if delete is not None: + for d in delete: + if d in j_ags_values: + del j_ags_values[d] + if d in j_kws_values.keys(): + j_kws_values.pop(d) + j_ags_ = [j_ags_values[str(i)] for i in range(len(list(j_ags_values.keys())))] + j_kws_ = [key + "=" + j_kws_values[key] for key in j_kws_values.keys()] + j_func = f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})" + if p_prefix is None: + return f"{j_module}.{j_name}({', '.join(j_ags_+j_kws_)})" + else: + if prefix in p_prefix: + return f"{j_prefix}.{j_name}({', '.join(j_ags_+j_kws_)})" + else: + return f"{prefix}.{j_name}({', '.join(j_ags_+j_kws_)})" + return j_func + + def dfs(self, a): + if isinstance(a, ast.Import): + if 'torch' in astunparse.unparse(a) and 'init' in astunparse.unparse(a): + self.import_flag.append('init') + return ast.parse('from jittor import init').body[0] + if 'torch' in astunparse.unparse(a) and a.names[0].asname == 'nn': + self.import_flag.append('nn') + return ast.parse('from jittor import nn').body[0] + if 'torch' in a.names[0].name: + return 'delete' + elif isinstance(a, ast.ImportFrom): + if 'torch' in a.module: + return 'delete' + elif isinstance(a, ast.Call): + for idx, ag in enumerate(a.args): + ret = self.dfs(ag) + if ret is not None: + a.args[idx] = ret + for idx, kw in enumerate(a.keywords): + ret = self.dfs(kw) + if ret is not None: + a.keywords[idx] = ret + ori_src = astunparse.unparse(a) + func = astunparse.unparse(a.func).strip('\n').split('.') + prefix = '.'.join(func[0:-1]) + func_name = func[-1] + if func_name in self.unsupport_ops: + ret = raise_unsupport(func_name, ori_src) + return ret + if func_name in self.pjmap: + ags = [astunparse.unparse(ag).strip('\n') for ag in a.args] + kws = [astunparse.unparse(kw).strip('\n') for kw in a.keywords] + ret = self.convert_(prefix, func_name, ags, kws, ori_src) + ret_tmp = ret + ret = ast.parse(ret).body[0] + if hasattr(ret,'value'): + return ret.value + else: + print(ret_tmp+'\n') + return ret + if ".load_state_dict" in astunparse.unparse(a.func): + a.func.attr = 'load_parameters' + if astunparse.unparse(a.func).strip('\n').endswith(".size"): + ags = [astunparse.unparse(ag).strip('\n') for ag in a.args] + if len(ags) != 0: + con = astunparse.unparse(a.func).split('.size')[0] + '.shape[' + ','.join(ags) + ']' + else: + con = astunparse.unparse(a.func).replace('size', 'shape') + return ast.parse(con).body[0].value + elif isinstance(a, ast.Expr): pass + elif isinstance(a, ast.Attribute) or isinstance(a, ast.Name): + ret = self.replace(a) + if ret is not None: + print(ret) + return ret + elif isinstance(a, ast.FunctionDef): + if a.name == 'forward': a.name = 'execute' + if hasattr(a, '__dict__'): + for k in a.__dict__.keys(): + if isinstance(a.__dict__[k], list): + delete_flag = [] + for i,a_ in enumerate(a.__dict__[k]): + ret = self.dfs(a_) + if ret == 'delete': + delete_flag.append(True) + continue + if ret is not None: + a.__dict__[k][i] = ret + delete_flag.append(False) + tmp = [a_ for i,a_ in enumerate(a.__dict__[k]) if delete_flag[i] == False] + a.__dict__[k] = tmp + else: + ret = self.dfs(a.__dict__[k]) + if ret is not None: + a.__dict__[k] = ret + + +def convert(code, ex_pjmaps=None): + ''' Model code converter, example: + + from jittor.utils.pytorch_converter import convert + pytorch_code = """ + class Model(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 10, 3) + self.conv2 = nn.Conv2d(10, 32, 3) + self.fc = nn.Linear(1200, 100) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + """ + jittor_code = convert(pytorch_code) + print("## Generate Jittor code:", jittor_code) + exec(jittor_code) + model = Model() + print("## Jittor model:", model) + ''' + + a = ast.parse(code) + converter = Converter(ex_pjmaps) + converter.dfs(a) + a.body.insert(0, ast.parse('import jittor as jt').body[0]) + if 'init' not in converter.import_flag: + a.body.insert(1, ast.parse('from jittor import init').body[0]) + if 'nn' not in converter.import_flag: + a.body.insert(2, ast.parse('from jittor import nn').body[0]) + return astunparse.unparse(a) diff --git a/python/jittor/utils/tracer.py b/python/jittor/utils/tracer.py new file mode 100644 index 00000000..0ca4d134 --- /dev/null +++ b/python/jittor/utils/tracer.py @@ -0,0 +1,24 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import jittor as jt + +def fill_module_name(m, name): + ps = [] + stack = [] + def callback(parents, k, v, n): + stack.append(str(k)) + for k2, p in v.__dict__.items(): + if isinstance(p, jt.Var): + ps.append(p) + p.name(".".join(stack[1:]+[str(k2)])) + v._trace_name = str(k) + def callback_leave(parents, k, v, n): + stack.pop() + m.dfs([], name, callback, callback_leave) diff --git a/python/jittor/vcompiler/__init__.py b/python/jittor/vcompiler/__init__.py new file mode 100644 index 00000000..7458533b --- /dev/null +++ b/python/jittor/vcompiler/__init__.py @@ -0,0 +1 @@ +from .vcompiler import * \ No newline at end of file diff --git a/python/jittor/vcompiler/vcompiler.cc b/python/jittor/vcompiler/vcompiler.cc new file mode 100644 index 00000000..0746925a --- /dev/null +++ b/python/jittor/vcompiler/vcompiler.cc @@ -0,0 +1,1048 @@ +// *************************************************************** +// Copyright (c) 2023 Jittor. All Rights Reserved. +// Maintainers: +// Dun Liang . +// Guoye Yang <498731903@qq.com> +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include +#ifdef HAS_CUDA +#include +#include "helper_cuda.h" +#include "mem/allocator/cuda_dual_allocator.h" +#include "event_queue.h" +#endif +#include "misc/cuda_flags.h" +#include "executor.h" +#include "var.h" +#include "op.h" +#include "mem/allocator.h" +#include "graph.h" +#include "fused_op.h" +#include "fuser.h" +#include "profiler/profiler_guard.h" +#include "parallel_compiler.h" +#include "memory_profiler.h" +#include "misc/nan_checker.h" +#include "memory_profiler.h" +#include "utils/seh.h" +#include "utils/cache_compile.h" +#include "var_holder.h" +#include "mem/swap.h" +#include "mem/mem_info.h" + +#include +#include "var_holder.h" +#include "vcompiler.h" + +namespace jittor { + +EXTERN_LIB MemoryProfiler memory_profiler; +DECLARE_FLAG(int, profile_memory_enable); +DECLARE_FLAG(int, gopt_disable); +DECLARE_FLAG(int, use_threading); + +// from cuda_managed_allocator +#ifdef HAS_CUDA +DECLARE_FLAG(int, use_cuda_managed_allocator); +#endif + +void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, int ll, int rr, int64 tt) { + fused_op.ops.clear(); + fused_op.edges.clear(); + auto ntt = ++tflag_count; + for (int i=ll; icustom_data = fid1; + op->tflag = ntt; + fused_op.ops.push_back(op); + } + LOGvvv << "Prepare fused_op" << fused_op.ops; + fused_op.update_ops(); + for (Op* op : fused_op.ops) { + uint fid1 = op->custom_data; + int iid = 0; + for (auto ve : op->_inputs) { + // this is a control dependency edge, dont used + if (ve.back->index<0) continue; + auto v = ve.node->var(); + iid++; + int iop_id; + int iv_id; + if (v->_inputs.size() && v->input()->tflag == ntt) { + auto e = v->_inputs.front(); + iop_id = e.node->custom_data; + iv_id = e.back->index; + } else { + iv_id = v->custom_data >> 2; + // add iv_id, prevent iv_id jit key overflow + iop_id = fused_op.ops.size() + iv_id; + } + fused_op.edges.emplace_back(iop_id, iv_id, fid1, iid-1); + } + // TODO: can we remove this? + // uint oid = 0; + // for (Var* v : op->outputs()) { + // oid++; + // if (v->tflag != tt) { + // // this var node not belong to current execution + // // this will happend in multiple outputs fuseable op + // // v->custom_data = 0 represents this var cannot be fused + // v->custom_data = 0; + // continue; + // } + // // for (auto o : v->outputs_with_index()) { + // // Op* op2 = o.op; + // // uint iid = o.index; + // // if (op2->tflag != ntt) continue; + // // uint fid2 = op2->custom_data; + // // fused_op.edges.emplace_back(fid1, oid-1, fid2, iid); + // // } + // } + } +} + +static inline void propergate_needed_flags(FusedOp& fused_op) { + auto& ops = fused_op.ops; + for (int i=ops.size()-1; i>=0; i--) { + bool has_need = 0; + auto op = ops[i]; + for (auto o : op->outputs()) + if (o->flags.get(NodeFlags::_needed_by_backward) && + !(o->custom_data&1)) { + has_need = 1; + } + if (has_need) + for (auto i : op->inputs()) { + i->flags.set(NodeFlags::_needed_by_backward); + } + } +} + +void check_op_async_error(Op* op, bool is_fused_op, const std::exception& e, jittor::Log& logf) { + vector stack; + if (is_fused_op) { + FusedOp& fused_op = *((FusedOp*)op); + logf >> "[OP TYPE]:" << "fused_op:("; + for (auto& op : fused_op.ops) + logf << op->name_ex() >> ","; + logf >> ")\n"; + logf >> "[Input]:"; + for (auto& vi : fused_op.vars) + if (vi.type == 0) logf << vi.var->dtype() >> vi.var->shape >> vi.var->name >> ","; + logf << "\n[Output]:"; + Var* ov = nullptr; + for (auto& vi : fused_op.vars) + if (vi.type == 2) { + logf << vi.var->dtype() >> vi.var->shape >> vi.var->name >> ","; + ov = vi.var; + } + if (ov) + stack = get_node_trace(ov); + } else { + logf >> "[OP TYPE]:" << op->name_ex(); + logf << "\n[Input]:"; + for (auto v : op->inputs()) + logf << v->dtype() >> v->shape >> v->name >> ","; + logf << "\n[Output]:"; + Var* ov = nullptr; + for (auto v : op->outputs()) { + logf << v->dtype() >> v->shape >> v->name >> ","; + ov = v; + } + if (ov) + stack = get_node_trace(ov); + } + logf << "\n[Async Backtrace]:"; + if (stack.size()) { + logf << "---"; + for (auto& s : stack) { + logf << "\n " << s.file_path >> ":" >> s.lineno; + if (s.module_type.size()) logf << '<' >> s.module_type >> '>'; + if (s.module_name.size() && s.module_name.find(":") == string::npos) + logf << '[' >> s.module_name >> ']'; + } + } else + logf << "not found, please set env JT_SYNC=1, trace_py_var=3"; + logf << "\n[Reason]:" << e.what(); + jittor::LogFatalVoidify() && logf; +} + +static void top_weak_sync(vector& vars) { + auto t = ++tflag_count; + int64 max_id=0; + for (auto v : vars) { + if (v->is_finished()) continue; + max_id = std::max(v->id, max_id); + v->tflag = t; + } + while (true) { + if (sync_ptr == hold_vars.begin()) + break; + auto next_ptr = std::prev(sync_ptr); + auto v = (*next_ptr)->var; + if (v->id > max_id) break; + sync_ptr = next_ptr; + if (v->tflag == t) continue; + if (v->_outputs.size()) continue; + if (v->is_finished()) continue; + vars.push_back(v); + } +} + +extern void free_var_mem(Var* v); + +VarHolder* get_output(Var* x) { + ASSERT(x->mem_ptr) << x; + VarPtr vp(x->shape, x->dtype()); + vp->mem_ptr = x->mem_ptr; + vp->allocation = x->allocation; + vp->allocator = x->allocator; + vp->finish_pending_liveness(); + x->mem_ptr = nullptr; + x->allocator = nullptr; + x->allocation = 0; + return new VarHolder(std::move(vp)); +} + +} // jittor + +#include +#include "common.h" +#include "ops/array_op.h" +#include "ops/code_op.h" +#include "ops/getitem_op.h" + +namespace jittor { + + +inline static bool fast_strcmp(const char* a, const char* b) { + return ((const uint32*)a)[0] == ((const uint32*)b)[0]; +} + +inline static void get_shape_value(vector& nodes, ShapeValue& k) { + auto add_shape = [&](NanoVector shape) { + k.values.push_back(shape.data); + k.values.push_back(shape.offset); + }; + for (auto* node : nodes) { + if (node->is_var()) { + Var* v = (Var*)node; + add_shape(v->shape); + k.values.push_back(v->num); + k.values.push_back(v->size); + continue; + } + auto* op = node->op(); + auto* name = op->name(); + if (fast_strcmp(name, "array")) { + auto* op_ = (ArrayOp*)op; + if (op_->output->flags.get(NodeFlags::_force_fuse)) + k.values.push_back(op_->ptr()[0]); + } else + if (fast_strcmp(name, "code")) { + auto* op_ = (CodeOp*)op; + for (auto& kv : op_->data) { + double v = kv.second; + // bitwise copy + k.values.push_back(*(uint64*)&v); + } + } else + if (fast_strcmp(name, "getitem") || + fast_strcmp(name, "setitem")) { + auto* op_ = (GetitemOp*)op; + for (int i=0; ivs.n; i++) { + auto& vs = op_->vs.slices[i]; + if (vs.is_int() || vs.is_slice()) { + k.values.push_back(vs.slice.start); + k.values.push_back(vs.slice.stop); + k.values.push_back(vs.slice.step); + k.values.push_back(vs.slice.mask); + } + } + add_shape(op_->o_shape); + } + } +} + +inline static void restore_shape_value(vector& nodes, ShapeValue& k) { + int iter = 0; + auto pop_number = [&]() { + ASSERT(iter < k.values.size()); + return k.values[iter++]; + }; + auto pop_shape = [&]() { + ASSERT(iter < k.values.size()); + NanoVector nv; + nv.data = k.values[iter++]; + nv.offset = k.values[iter++]; + return nv; + }; + + for (auto* node : nodes) { + if (node->is_var()) { + Var* v = (Var*)node; + v->shape = pop_shape(); + v->num = pop_number(); + v->size = pop_number(); + continue; + } + auto* op = node->op(); + auto* name = op->name(); + if (fast_strcmp(name, "array")) { + auto* op_ = (ArrayOp*)op; + if (op_->output->flags.get(NodeFlags::_force_fuse)) + op_->ptr()[0] = pop_number(); + } else + if (fast_strcmp(name, "code")) { + auto* op_ = (CodeOp*)op; + for (auto& kv : op_->data) { + double& v = kv.second; + // bitwise copy + *(uint64*)&v = pop_number(); + } + } else + if (fast_strcmp(name, "getitem") || + fast_strcmp(name, "setitem")) { + auto* op_ = (GetitemOp*)op; + for (int i=0; ivs.n; i++) { + auto& vs = op_->vs.slices[i]; + if (vs.is_int() || vs.is_slice()) { + vs.slice.start = pop_number(); + vs.slice.stop = pop_number(); + vs.slice.step = pop_number(); + vs.slice.mask = pop_number(); + } + } + op_->o_shape = pop_shape(); + op->graph_optimize(); + } + } +} + +SGraphPtr build_sgraph(const vector& outputs, const vector& inputs) { + vector vars; + vars.reserve(outputs.size()); + for (auto* vh : outputs) + vars.push_back(vh->var); + bool weak_sync = false; + + if (weak_sync && !use_threading) + top_weak_sync(vars); + auto allocator = get_allocator(); + auto temp_allocator = get_allocator(true); + exe.allocator = allocator; + exe.temp_allocator = temp_allocator; + auto& last_is_cuda = exe.last_is_cuda; + // bfs find all ops need to run + int op_num = 0; + vector bfs_q; + bfs_q.reserve(vars.size()); + int start_var_num = 0; + while (1) { + op_num = 0; + start_var_num = 0; + bfs_q.clear(); + // get all nodes need to be executed + int need_opt = 0; + auto t = ++tflag_count; + int64 max_id = 0; + for (Var* v : vars) + if (!v->is_finished() && v->tflag != t) { + v->tflag = t; + start_var_num++; + bfs_q.push_back(v); + max_id = std::max(max_id, v->id); + } + for (int i=0; iis_var(); + for (auto i : node->_inputs) + if (i.node->tflag != t && !i.node->is_finished()) { + i.node->tflag = t; + need_opt += i.node->flags.get(NodeFlags::_has_gopt); + bfs_q.push_back(i.node); + } + // this var has been fetched + if (weak_sync || node->flags.get(NodeFlags::_fetch)) { + for (auto& n : node->_outputs) { + // if not in queue and is fetch op + if (n.node->tflag != t && + n.node->pending_liveness && + !n.node->is_finished() && + (n.node->id <= max_id || + n.node->flags.get(NodeFlags::_fetch))) { + n.node->tflag = t; + need_opt += n.node->flags.get(NodeFlags::_has_gopt); + bfs_q.push_back(n.node); + } + } + } + } + if (!need_opt || gopt_disable) break; + for (Node* n : bfs_q) { + if (n->flags.get(NodeFlags::_has_gopt)) { + n->op()->graph_optimize(); + n->flags.set(NodeFlags::_has_gopt, 0); + } + } + } + auto tt = tflag_count; + vector ops; + vector all_vars; + ops.reserve(op_num); + all_vars.reserve(bfs_q.size() - op_num); + for (Node* node : bfs_q) + if (!node->is_var()) { + node->custom_data = ops.size(); + ops.push_back(node->op()); + } else { + // set can't fuse flag to false + node->custom_data = all_vars.size(); + all_vars.push_back(node->var()); + } + int var_num = all_vars.size(); + + // father: father of union-find set + vector father(op_num); + for (int i=0; i int { + int j=i; + while (father[j] != j) j = father[j]; + while (i != j) { + int tmp = father[i]; + father[i] = j; + i = tmp; + } + return j; + }; + vector var_fused(var_num); + + if (V_ON(100)) { + for (uint i=0; itype()==OpType::reduce) st="reduce"; + if (op->type()==OpType::broadcast) st="broadcast"; + if (op->type()==OpType::element) st="element"; + + LOGvvv << "id:" << ops[i]->custom_data << " type:" << + st << " addr:" << op; + for (Var* v : op->inputs()) { + Op* next_op = v->input(); + // continue if is boundary + if (!next_op || next_op->tflag != tt) { + LOGvvv << "input:" << v; + continue; + } + LOGvvv << "input:" << next_op->custom_data << " addr:" << next_op; + } + LOGvvv << ""; + } + } + + count_fuse(tt, start_var_num, ops, all_vars, father, var_fused); + // var_fused represents: + // 0: can fused + // 1: cannot fused + // 2: weak shared(may turn into 1 or 3 by shared operator cutting) + // 3: strong shared(force shared) + vector roots, next(op_num, -1); + vector deps(op_num, 0); + roots.reserve(op_num); + for (int i=0; i queue; + queue.reserve(roots.size()); + + // ** toplogical_sort external ** + // output: + // queue: toplogical order of fused op + { + // queue.clear(); + #ifndef JT_bfs_executor + std::priority_queue> p_queue; + #endif + for (int root : roots) { + for (int i=root; i>=0; i=next[i]) { + Op* op = ops[i]; + for (Var* v : op->inputs()) { + if (v->tflag != tt) continue; + Op* opi = v->input(); + // if those two ops are not fused + if (father[opi->custom_data] != root) { + deps[root]++; + } + } + } + #ifdef JT_bfs_executor + if (deps[root] == 0) + queue.push_back(root); + #else + if (deps[root] == 0) + p_queue.emplace(-ops[root]->order(), root); + #endif + } + #ifdef JT_bfs_executor + for (uint s=0; s=0; i=next[i]) { + Op* op = ops[i]; + for (Var* v : op->outputs()) + { + if (v->tflag == tt) + for (Op* op2 : v->outputs()) + { + if (op2->tflag != tt) continue; + int op2_id = father[op2->custom_data]; + // continue if those two ops are fused + if (op2_id == op_id) continue; + deps[op2_id]--; + #ifdef JT_bfs_executor + if (deps[op2_id] == 0) + queue.push_back(op2_id); + #else + if (deps[op2_id] == 0) + p_queue.emplace(-op2->order(), op2_id); + #endif + } + } + } + } + ASSERTop(queue.size(),==,roots.size()); + } + + // ** toplogical_sort internal ** + // output: + // fuse_ops: fused op id [000|1111|22|3333] + // range: split index ^ ^ ^ ^ ^ + vector fuse_ops; + fuse_ops.reserve(op_num*2); + vector range(queue.size()); + { + vector subgraph; + subgraph.reserve(16); + vector sharegraph; + sharegraph.reserve(16); + vector sharegraph_q; + sharegraph_q.reserve(16); + vector shared_id(op_num, -1); + + // for fused op in reversed order + for (uint rid=0; rid=0; i=next[i], total++) { + Op* op = ops[i]; + for (Var* v : op->inputs()) { + if (v->tflag != tt) continue; + Op* opi = v->input(); + // if those two ops are fused + int opid = opi->custom_data; + auto fopid = father[opid]; + if (fopid == root) + deps[i]++; + else if (shared_id[opid] != root) { + auto& vf = var_fused[v->custom_data]; + // var_fused = 1 cannot share input op + // TODO: check this input op's output var all can be shared + if (vf == 1) + continue; + // if weak share, turn into strong share + if (vf == 2) vf = 3; + // new shared op + deps[opid] = 0; + shared_id[opid] = root; + sharegraph.push_back(opid); + } + } + if (deps[i] == 0) + queue.push_back(i); + } + // find all share graph + uint sn = sharegraph.size(); + for (uint i=0; iinputs()) { + if (v->tflag != tt) continue; + int vi = v->custom_data; + if (var_fused[vi] == 1) + continue; + // if weak share, cut off + if (var_fused[vi] == 2) { + if (sharegraph.size() - sn < 32) + var_fused[vi] = 3; + else { + var_fused[vi] = 1; + continue; + } + } + Op* opi = v->input(); + int opid = opi->custom_data; + int& dep = deps[opid]; + if (shared_id[opid] != root) { + shared_id[opid] = root; + dep = 1; + sharegraph.push_back(opid); + } else + dep ++; + } + } + sharegraph_q.clear(); + for (uint i=0; iinputs()) { + if (v->tflag != tt) continue; + int vi = v->custom_data; + if (var_fused[vi] == 1) + continue; + Op* opi = v->input(); + int opid = opi->custom_data; + int& dep = deps[opid]; + dep --; + if (dep == 0) + sharegraph_q.push_back(opid); + } + } + LOGvvvv << "sharegraph_q" << sharegraph_q; + ASSERTop(sharegraph.size(),==,sharegraph_q.size()); + // topsort fused op internal + for (uint s=0; soutputs()) + if (v->tflag == tt) + for (Op* op2 : v->outputs()) { + if (op2->tflag != tt) continue; + int op2_id = op2->custom_data; + // continue if those two ops are not fused + if (father[op2_id] != root) continue; + deps[op2_id]--; + if (deps[op2_id] == 0) + queue.push_back(op2_id); + } + } + ASSERTop(queue.size(),==,(uint)total); + LOGvvvv << "topsort internal" << queue; + for (int i=(int)sharegraph_q.size()-1; i>=0; i--) + fuse_ops.push_back(sharegraph_q[i]); + for (uint i=0; icustom_data = var_fused[i]==1; + } + FusedOp fused_op; + + // compile all ops, prevent compiling during running + parallel_compile_all_ops(queue, range, fused_op, fuse_ops, ops, tt, true); + + // flags + std::sort(bfs_q.begin(), bfs_q.end(), [&](Node* x, Node* y) { return x->idid; }); + unordered_map> share_map; + auto min_id = bfs_q.front()->id; + auto max_id = bfs_q.back()->id; + vector flags(max_id-min_id+1); + constexpr int is_output = 0; + constexpr int is_new_var = 1; + constexpr int is_share = 2; + + auto lived = [&](Node* n) { return n->id>=min_id && n->id<=max_id; }; + auto get_flags = [&](Node* n, int f) -> int { + if (!lived(n)) return 0; + return (flags[n->id-min_id]>>f)&1; + }; + auto set_flags = [&](Node* n, int f) { + if (!lived(n)) return; + flags[n->id-min_id] |= (1<allocator) { + share_map[v] = std::make_pair((Var*)v->allocator, v->allocation); + set_flags(v, is_share); + } + } + + // build fused ops + vector fused_ops(queue.size()); + vector rid_ops(queue.size()); + vector v_last_rid(max_id-min_id+1, -1); + vector jit_entries(queue.size()); + + auto& jkl = get_jk(); + for (uint rid=0; ridtype() != OpType::other) { + auto& fused_op = fused_ops[rid]; + op = &fused_op; + is_fused_op = true; + int ll = (riddo_prepare(jkl); + jit_entries[rid] = (jit_op_entry_t)&FusedOp::do_run; + } else { + op->do_prepare(jkl); + if (!jkl.empty()) { + const char* jit_key = jkl.to_cstring(); + auto iter = jit_ops.find(jit_key); + ASSERT(iter != jit_ops.end()) << jit_key << op << rid; + jit_entries[rid] = iter->second; + } else { + jit_entries[rid] = (jit_op_entry_t)&Op::run; + } + } + rid_ops[rid] = op; + for (auto v : op->inputs()) + if (get_flags(v, is_new_var)) + v_last_rid[v->id-min_id] = rid; + } + + SGraphPtr sgraph_ptr; + sgraph_ptr.ptr = std::make_unique(); + auto& g = *sgraph_ptr.ptr; + + g.outputs.reserve(outputs.size()); + for (auto v : outputs) { + g.outputs.push_back(v->var); + } + + g.inputs.reserve(inputs.size()); + for (auto v : inputs) { + g.inputs.push_back(v->var); + } + + g.bfs_q = std::move(bfs_q); + g.share_map = std::move(share_map); + g.flags = std::move(flags); + g.fused_ops = std::move(fused_ops); + g.rid_ops = std::move(rid_ops); + g.v_last_rid = std::move(v_last_rid); + + ShapeKey key; + key.shapes.reserve(inputs.size()); + for (auto v : inputs) { + key.shapes.push_back(v->var->shape); + } + + ShapeValue& value = g.shape_values[key]; + get_shape_value(g.bfs_q, value); + auto prev_size = value.values.size(); + value.values.resize(value.values.size() + jit_entries.size()); + memcpy(&value.values[prev_size], &jit_entries[0], jit_entries.size()*sizeof(jit_op_entry_t)); + g.shape_value_len = value.values.size(); + + return sgraph_ptr; +} + + +bool prob_sgraph(SGraphPtr* sgraph, const vector& inputs) { + // return true; + ShapeKey key; + key.shapes.reserve(inputs.size()); + for (auto v : inputs) { + key.shapes.push_back(v->var->shape); + } + auto& g = *sgraph->ptr; + auto it = g.shape_values.find(key); + if (it == g.shape_values.end()) return false; + return true; +} + +void merge_sgraph(SGraphPtr* sgraph, SGraphPtr* sgraph2) { + auto& g1 = *sgraph->ptr; + auto& g2 = *sgraph2->ptr; + ASSERT(g1.outputs.size() == g2.outputs.size()); + ASSERT(g1.inputs.size() == g2.inputs.size()); + ASSERTop(g1.bfs_q.size(),==,g2.bfs_q.size()); + ASSERT(g1.share_map.size() == g2.share_map.size()); + ASSERT(g1.flags.size() == g2.flags.size()); + ASSERT(g1.fused_ops.size() == g2.fused_ops.size()); + ASSERT(g1.rid_ops.size() == g2.rid_ops.size()); + ASSERT(g1.v_last_rid.size() == g2.v_last_rid.size()); + ASSERT(g1.shape_value_len == g2.shape_value_len); + + for (int i=0; iis_var() == n2->is_var()); + if (n1->is_var()) { + ASSERT(n1->var()->shape.size() == n2->var()->shape.size()); + ASSERT(n1->var()->dtype() == n2->var()->dtype()); + } else { + ASSERT(fast_strcmp(n1->op()->name(), n2->op()->name()) == 1); + } + } + for (auto& kv : g2.shape_values) { + g1.shape_values[kv.first] = kv.second; + } +} + +vector exec_sgraph(SGraphPtr* sgraph, const vector& inputs) { + ShapeKey key; + key.shapes.reserve(inputs.size()); + for (auto v : inputs) { + key.shapes.push_back(v->var->shape); + } + auto& g = *sgraph->ptr; + auto it = g.shape_values.find(key); + ASSERT(it != g.shape_values.end()); + auto& value = it->second; + restore_shape_value(g.bfs_q, value); + + vector jit_entries(g.rid_ops.size()); + memcpy(&jit_entries[0], &value.values[value.values.size() - jit_entries.size()], jit_entries.size()*sizeof(jit_op_entry_t)); + + ASSERT(inputs.size() == g.inputs.size()); + for (int i=0; ivar; + auto* v = g.inputs[i]; + if (v != v2) { + if (v->mem_ptr) { + free_var_mem(v); + } + ASSERT(v2->mem_ptr); + v->mem_ptr = v2->mem_ptr; + v->allocator = v2->allocator; + v->allocation = v2->allocation; + v->shape = v2->shape; + v->num = v2->num; + v->size = v2->size; + v->allocator->share_with(v->size, v->allocation); + } + } + + + auto allocator = get_allocator(); + auto temp_allocator = get_allocator(true); + exe.allocator = allocator; + exe.temp_allocator = temp_allocator; + auto& last_is_cuda = exe.last_is_cuda; + + vector& vars = g.outputs; + vector& bfs_q = g.bfs_q; + unordered_map>& share_map = g.share_map; + vector& flags = g.flags; + + vector& fused_ops = g.fused_ops; + vector& rid_ops = g.rid_ops; + vector& v_last_rid = g.v_last_rid; + + constexpr int is_output = 0; + constexpr int is_new_var = 1; + constexpr int is_share = 2; + auto min_id = bfs_q.front()->id; + auto max_id = bfs_q.back()->id; + + auto lived = [&](Node* n) { return n->id>=min_id && n->id<=max_id; }; + auto get_flags = [&](Node* n, int f) -> int { + if (!lived(n)) return 0; + return (flags[n->id-min_id]>>f)&1; + }; + auto set_flags = [&](Node* n, int f) { + if (!lived(n)) return; + flags[n->id-min_id] |= (1<type() != OpType::other; + try { + for (auto* var : op->outputs()) + var->alloc(allocator); + if (PREDICT_BRANCH_NOT_TAKEN(profile_memory_enable)) + memory_profiler.check(); + LOGvvv << "Run" << op << "inputs:" << op->inputs() << "outputs:" << op->outputs(); + // op->do_prepare(jkl); + bool is_cuda = op->flags.get(NodeFlags::_cuda); + #ifdef HAS_CUDA + if (!is_cuda) { + if (last_is_cuda) { + // if prev op in gpu and this op in cpu + // cuda sync + checkCudaErrors(cudaDeviceSynchronize()); + sync_times++; + } + for (Var* v : op->inputs()) { + if (v->allocator->is_cuda()) + migrate_to_cpu(v, allocator); + } + if (!use_cuda_managed_allocator) { + for (auto* var : op->outputs()) + if (var->allocator->is_cuda()) + migrate_to_cpu(var, allocator); + } + } else { + for (Var* v : op->inputs()) { + if (!v->allocator->is_cuda()) + migrate_to_gpu(v, allocator); + } + for (Var* v : op->outputs()) { + if (!v->allocator->is_cuda()) + migrate_to_gpu(v, allocator); + } + } + #endif + last_is_cuda = is_cuda; + // _JT_SEH_START2; + if (profiler_enable) + op->do_run(); + else { + jit_op_entry_t& jit_entry = jit_entries[rid]; + jit_entry(op); + } + // _JT_SEH_END2; + #ifdef HAS_CUDA + // migrate to gpu + if (PREDICT_BRANCH_NOT_TAKEN((!is_cuda && use_cuda && !use_cuda_managed_allocator))) { + for (Var* v : op->outputs()) { + migrate_to_gpu(v, allocator); + } + } + #endif + #ifdef JT_CHECK_NAN + for (Var* var : op->outputs()) + check_nan(var, op); + #endif + #ifdef JT_SYNC + #ifdef HAS_CUDA + checkCudaErrors(cudaGetLastError()); + checkCudaErrors(cudaDeviceSynchronize()); + #endif + #endif + LOGvvv << "Finished Op(" >> op->name() << rid >> + "/" >> rid_ops.size() >> ") output:" << op->outputs(); + for (Var* v : op->inputs()) + if (get_flags(v, is_new_var) && !get_flags(v, is_output) && v_last_rid[v->id-min_id] == rid) { + if (v->mem_ptr) + free_var_mem(v); + if (get_flags(v, is_share)) { + // recover share var + auto kv = share_map.find(v)->second; + v->allocator = (Allocator*)kv.first; + v->allocation = kv.second; + } + } + for (Var* v : op->outputs()) { + if (!get_flags(v, is_new_var) && !get_flags(v, is_output) && v->mem_ptr) { + // this output is not used in this graph, so we free it directly + free_var_mem(v); + } + } + } catch (const std::exception& e) { + // log memory info + display_memory_info(__FILELINE__, false, true); + // log jit_key and file location + op->do_prepare(jkl); + string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc"); + jittor::Log logf(__FILELINE__, 'f', 0); + logf << "\nExecute fused operator(" >> rid >> '/' >> rid_ops.size() >> ")" + << "failed."; + if (jit_compiler::file_exist(jit_src_path)) + logf << "\n[JIT Source]:" << jit_src_path << "\n"; + check_op_async_error(op, is_fused_op, e, logf); + } + } + for (Var* v : vars) ASSERT(v->mem_ptr || v->flags.get(NodeFlags::_is_swapped) || !v->backward_liveness) << v; + // clean fetcher free buffer + // fetcher_to_free.clear(); + #ifdef HAS_CUDA + event_queue.flush(); + #endif + vector ret; + ret.reserve(vars.size()); + for (Var* v : vars) { + ASSERT(get_flags(v, is_new_var)); + ret.push_back(get_output(v)); + if (get_flags(v, is_share)) { + // recover share var + auto kv = share_map.find(v)->second; + v->allocator = (Allocator*)kv.first; + v->allocation = kv.second; + } + } + return ret; +} + +vector delay_fetch(const vector& inputs) { + static vector prev_vars; + static cudaEvent_t event; + static bool init = false; + if (!init) { + init = true; + checkCudaErrors(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + } + + sync(inputs); + vector ret; + ret.reserve(prev_vars.size()); + for (auto& v : prev_vars) { + ret.push_back(new VarHolder(move(v))); + } + prev_vars.clear(); + prev_vars.reserve(inputs.size()); + for (auto& v : inputs) { + VarPtr vp(v->var->shape, v->var->dtype()); + vp->alloc(cpu_allocator); + vp->finish_pending_liveness(); + cudaMemcpyAsync(vp->mem_ptr, v->var->mem_ptr, v->var->size, cudaMemcpyDeviceToHost, 0); + prev_vars.emplace_back(move(vp)); + } + cudaEventSynchronize(event); + cudaEventRecord(event, 0); + return ret; +} + +} diff --git a/python/jittor/vcompiler/vcompiler.h b/python/jittor/vcompiler/vcompiler.h new file mode 100644 index 00000000..a637ec51 --- /dev/null +++ b/python/jittor/vcompiler/vcompiler.h @@ -0,0 +1,97 @@ +#include "common.h" +#include "var_holder.h" +#include "pyjt/py_converter.h" +#include "fused_op.h" + +namespace jittor { + +struct ShapeKey { + vector shapes; +}; + +struct ShapeKeyHash { + std::size_t operator()(const ShapeKey& key) const { + std::size_t h = 0; + for (int i=0; i values; +}; + +// unordered_map shape_values; + +struct SGraph { + vector outputs; + vector inputs; + vector bfs_q; + unordered_map> share_map; + vector flags; + + vector fused_ops; + vector rid_ops; + vector v_last_rid; + + std::unordered_map shape_values; + int shape_value_len; +}; + +// @pyjt(SGraphPtr) +struct SGraphPtr { + std::unique_ptr ptr; +}; + +// SGraphPtr +struct SGraphPtr; +EXTERN_LIB PyTypeObject PyjtSGraphPtr; +DEF_IS(SGraphPtr, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtSGraphPtr; +} +DEF_IS(SGraphPtr*, bool) is_type(PyObject* obj) { + return Py_TYPE(obj) == &PyjtSGraphPtr; +} + + +DEF_IS(SGraphPtr, PyObject*) to_py_object(T&& a) { + PyObjHolder obj(_PyObject_New(&PyjtSGraphPtr)); + auto ptr = GET_RAW_PTR(T, obj.obj); + new (ptr) T(); + ptr->ptr = std::move(a.ptr); + return obj.release(); +} + +DEF_IS(SGraphPtr, const T&) from_py_object(PyObject* obj) { + return GET_RAW_PTR(T, obj); +} + +DEF_IS(SGraphPtr*, T) from_py_object(PyObject* obj) { + return GET_RAW_PTR(typename std::remove_pointer::type, obj); +} + +// @pyjt(build_sgraph) +SGraphPtr build_sgraph(const vector& outputs, const vector& inputs); + +// @pyjt(prob_sgraph) +bool prob_sgraph(SGraphPtr* sgraph, const vector& inputs); + +// @pyjt(merge_sgraph) +void merge_sgraph(SGraphPtr* sgraph, SGraphPtr* sgraph2); + +// @pyjt(exec_sgraph) +vector exec_sgraph(SGraphPtr* sgraph, const vector& inputs); + +// @pyjt(delay_fetch) +vector delay_fetch(const vector& inputs); + +} \ No newline at end of file diff --git a/python/jittor/vcompiler/vcompiler.py b/python/jittor/vcompiler/vcompiler.py new file mode 100644 index 00000000..5802bc88 --- /dev/null +++ b/python/jittor/vcompiler/vcompiler.py @@ -0,0 +1,154 @@ +import jittor as jt +import os +import jittor_utils +from jittor_utils import lock +import jittor.compiler as compiler +import numpy as np + +dirname = os.path.dirname(__file__) +cc_files = [ dirname + "/vcompiler.cc"] +with open(dirname + "/vcompiler.h") as f: + h_src = f.read() + +with lock.lock_scope(): + mod = jittor_utils.compile_module(h_src, compiler.cc_flags + "-I" + dirname + " " + " ".join(cc_files)) + +for k, v in mod.__dict__.items(): + if k.startswith("_"): + continue + globals()[k] = v + +def dfs(obj, path=""): + if isinstance(obj, jt.Var): + return [((path, len(obj.shape), str(obj.dtype)), obj)] + if isinstance(obj, (list, tuple)): + ret = [] + for i, v in enumerate(obj): + ret += dfs(v, path + "[%d]" % i) + return ret + if isinstance(obj, dict): + ret = [] + for k, v in obj.items(): + ret += dfs(v, path + "[%r]" % k) + return ret + return [] + +def dfs_config(obj): + if isinstance(obj, jt.Var): + return "Var" + if isinstance(obj, (int, float, bool, str, type(None))): + return obj + if isinstance(obj, (list, tuple)): + return [ dfs_config(v) for v in obj ] + if isinstance(obj, dict): + return { k:dfs_config(v) for k, v in obj.items() } + raise ValueError(f"Unknown type {type(obj)}") + +def dfs_clone_var(obj): + if isinstance(obj, jt.Var): + return obj.clone() + if isinstance(obj, (int, float, bool, str, type(None))): + return obj + if isinstance(obj, (list, tuple)): + return [ dfs_clone_var(v) for v in obj ] + if isinstance(obj, dict): + return { k:dfs_clone_var(v) for k, v in obj.items() } + raise ValueError(f"Unknown type {type(obj)}") + +def dfs_fill(obj, vars): + i = 0 + def dfs_fill_var(obj): + nonlocal i + if isinstance(obj, jt.Var): + v = vars[i] + i += 1 + return v + if isinstance(obj, (int, float, bool, str, type(None))): + return obj + if isinstance(obj, (list, tuple)): + return [ dfs_fill_var(v) for v in obj ] + if isinstance(obj, dict): + ret = { k:dfs_fill_var(v) for k, v in obj.items() } + return obj.__class__(ret) + raise ValueError(f"Unknown type {type(obj)}") + return dfs_fill_var(obj) + + +class CachedGraph: + def __init__(self, func, args, kw): + args = dfs_clone_var(args) + kw = dfs_clone_var(kw) + self.func = func + self.inputs = (args, kw) + jt.sync_all() + exec_called = jt.flags.exec_called + self.outputs = func(*args, **kw) + import gc; gc.collect() + assert exec_called == jt.flags.exec_called, (exec_called, jt.flags.exec_called) + self.outputs_parsed = dfs(self.outputs) + self.outputs_var = [ v for _, v in self.outputs_parsed ] + self.inputs_parsed = dfs(self.inputs) + self.inputs_var = [ v for _, v in self.inputs_parsed ] + self.inputs_key = str([ key for key, _ in self.inputs_parsed ]) + for v in self.outputs_var: + v.release_from_holders() + for v in self.inputs_var: + v.release_from_holders() + self.sgraph = mod.build_sgraph(self.outputs_var, self.inputs_var) + +# a function decorator +# build new graph: +# 1. shape dim changed +# 2. dtype changed +# 3. var path changed +# graph key: +# (args, kw), [ (var_path, shape dim, dtype), var ] +def build(func, debug=False, fallback_func=None): + cache = {} + def func_wrapper(*args, **kw): + if fallback_func and fallback_func(*args, **kw): + return func(*args, **kw) + inputs = (args, kw) + config_key = str(dfs_config(inputs)) + inputs_parsed = dfs(inputs) + inputs_key = str([ key for key, _ in inputs_parsed ]) + inputs_var = [ v for _, v in inputs_parsed ] + jt.sync(inputs_var) + all_key = config_key + inputs_key + if all_key not in cache: + # print(f"create graph with key '{all_key[:30]}'...") + cache[all_key] = CachedGraph(func, args, kw) + graph = cache[all_key] + if not mod.prob_sgraph(graph.sgraph, inputs_var): + # print(f"merge graph with key '{all_key[:30]}'...") + graph2 = CachedGraph(func, args, kw) + mod.merge_sgraph(graph.sgraph, graph2.sgraph) + outputs = mod.exec_sgraph(graph.sgraph, inputs_var) + if debug: + graph2 = CachedGraph(func, args, kw) + outputs2 = mod.exec_sgraph(graph2.sgraph, inputs_var) + for v1, v2 in zip(outputs, outputs2): + np.testing.assert_allclose(v1.data, v2.data, rtol=0.01, atol=0.05) + return dfs_fill(graph.outputs, outputs) + return func_wrapper + +# c interface +# build_sgraph -> sgraph +# merge_sgraph +# exec_sgraph +# prob_sgraph: check sgraph can exec +# if prob_sgraph failed +# build_sgraph +# merge_sgraph +# exec_sgraph + + + +# overall code: +# 1. get input_key from (args, kw) +# 2. if input_key not in cache +# graph.outputs = func(*args, **kw) +# graph.outputs_var = var_parser(graph.outputs) +# graph.sgraph = build_sgraph(outputs_var) +# graph.inputs = (args, kw) +# graph.inputs_var = var_parser(args, kw) diff --git a/python/jittor/version b/python/jittor/version new file mode 100644 index 00000000..98d3c70f --- /dev/null +++ b/python/jittor/version @@ -0,0 +1 @@ +939b29514b2e5cc591053aab614efd569772585d diff --git a/python/jittor/weightnorm.py b/python/jittor/weightnorm.py new file mode 100644 index 00000000..12ee178d --- /dev/null +++ b/python/jittor/weightnorm.py @@ -0,0 +1,86 @@ +import jittor as jt +from jittor import nn + +def _weight_norm(v, g, dim): + return v * (g / jt.norm(v, 2, dim, keepdim=True)) + +class WeightNorm(object): + def __init__(self, name: str, dim: int) -> None: + if dim is None: + dim = -1 + self.name = name + self.dim = dim + + # TODO Make return type more specific + def compute_weight(self, module: nn.Module): + g = getattr(module, self.name + '_g') + v = getattr(module, self.name + '_v') + return _weight_norm(v, g, self.dim) + + @staticmethod + def apply(module, name: str, dim: int): + if hasattr(module, '__fhook2__') and isinstance(module.__fhook2__, WeightNorm): + raise RuntimeError("Cannot register two weight_norm hooks on " + "the same parameter {}".format(name)) + + if dim is None: + dim = -1 + + fn = WeightNorm(name, dim) + + weight = getattr(module, name) + # todo: add check + # remove w from parameter list + # del module._parameters[name] + delattr(module, name) + + # add g and v as new parameters and express w as g/||v|| * v + module.__setattr__(name + '_g', jt.norm(weight, 2, dim, keepdim=True).detach()) + module.__setattr__(name + '_v', weight.detach()) + setattr(module, name, fn.compute_weight(module)) + + # recompute weight before every forward() + # todo: support multiple hook in a module + module.register_pre_forward_hook(fn) + return fn + + def remove(self, module: nn.Module) -> None: + weight = self.compute_weight(module) + delattr(module, self.name) + delattr(module, self.name + '_g') + delattr(module, self.name + '_v') + setattr(module, self.name, weight.detach()) + + def __call__(self, module: nn.Module, inputs) -> None: + setattr(module, self.name, self.compute_weight(module)) + +def weight_norm(module, name, dim): + ''' Add a module weight normalization. + + :param module: input model. + :param name: name of the assigned parameter. + :param dim: which dim to carry out weightnorm. + + Example:: + + class jt_module(jt.nn.Module): + def __init__(self, weight): + super().__init__() + self.linear = jt.array(weight) + + def execute(self, x): + return jt.matmul(self.linear, x) + + jm = jt_module(weight) + weight_norm(jm, 'linear', -1) + + ''' + WeightNorm.apply(module, name, dim) + return module + +def remove_weight_norm(module, name: str = 'weight'): + if hasattr(module, "__fhook2__") and isinstance(module.__fhook2__, WeightNorm): + delattr(module, "__fhook2__") + return module + raise ValueError("weight_norm of '{}' not found in {}" + .format(name, module)) \ No newline at end of file diff --git a/python/jittor_utils/__init__.py b/python/jittor_utils/__init__.py new file mode 100644 index 00000000..9ac774b4 --- /dev/null +++ b/python/jittor_utils/__init__.py @@ -0,0 +1,757 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from multiprocessing import Pool +import multiprocessing as mp +import subprocess as sp +import os +import re +import sys +import inspect +import datetime +import contextlib +import platform +import threading +import time +from ctypes import cdll +import shutil +import urllib.request +import ctypes + +if platform.system() == 'Darwin': + mp.set_start_method('fork') + +from pathlib import Path +import json + + +_jittor_home = None +def home(): + global _jittor_home + if _jittor_home is not None: + return _jittor_home + + src_path = os.path.join(str(Path.home()),".cache","jittor") + os.makedirs(src_path,exist_ok=True) + src_path_file = os.path.join(src_path,"config.json") + data = {} + if os.path.exists(src_path_file): + with open(src_path_file,"r") as f: + data = json.load(f) + + default_path = data.get("JITTOR_HOME", str(Path.home())) + + _home_path = os.environ.get("JITTOR_HOME", default_path) + + if not os.path.exists(_home_path): + os.makedirs(_home_path, exist_ok=True) + _home_path = os.path.abspath(_home_path) + + # LOG.i(f"Use {_home_path} as Jittor Home") + if default_path != _home_path: + with open(src_path_file,"w") as f: + data['JITTOR_HOME'] = _home_path + json.dump(data,f) + + _jittor_home = _home_path + return _home_path + +class Logwrapper: + def __init__(self): + self.log_silent = int(os.environ.get("log_silent", "0")) + self.log_v = int(os.environ.get("log_v", "0")) + + def log_capture_start(self): + cc.log_capture_start() + + def log_capture_stop(self): + cc.log_capture_stop() + + def log_capture_read(self): + return cc.log_capture_read() + + def _log(self, level, verbose, *msg): + if self.log_silent or verbose > self.log_v: + return + ss = "" + for m in msg: + if callable(m): + m = m() + ss += str(m) + msg = ss + f = inspect.currentframe() + fileline = inspect.getframeinfo(f.f_back.f_back) + fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}" + if cc and hasattr(cc, "log"): + cc.log(fileline, level, verbose, msg) + else: + time = datetime.datetime.now().strftime("%m%d %H:%M:%S.%f") + tid = threading.get_ident()%100 + v = f" v{verbose}" if verbose else "" + print(f"[{level} {time} {tid:02}{v} {fileline}] {msg}") + + def V(self, verbose, *msg): self._log('i', verbose, *msg) + def v(self, *msg): self._log('i', 1, *msg) + def vv(self, *msg): self._log('i', 10, *msg) + def vvv(self, *msg): self._log('i', 100, *msg) + def vvvv(self, *msg): self._log('i', 1000, *msg) + def i(self, *msg): self._log('i', 0, *msg) + def w(self, *msg): self._log('w', 0, *msg) + def e(self, *msg): self._log('e', 0, *msg) + def f(self, *msg): self._log('f', 0, *msg) + +class DelayProgress: + def __init__(self, msg, n): + self.msg = msg + self.n = n + self.time = time.time() + + def update(self, i): + if LOG.log_silent: + return + used = time.time() - self.time + if used > 2: + eta = used / (i+1) * (self.n-i-1) + print(f"{self.msg}({i+1}/{self.n}) used: {used:.3f}s eta: {eta:.3f}s", end='\r') + if i==self.n-1: print() + +# check is in jupyter notebook +def in_ipynb(): + try: + cfg = get_ipython().config + if 'IPKernelApp' in cfg: + return True + else: + return False + except: + return False + +@contextlib.contextmanager +def simple_timer(name): + print("Timer start", name) + now = time.time() + yield + print("Time stop", name, time.time()-now) + +@contextlib.contextmanager +def import_scope(flags): + if os.name != 'nt': + prev = sys.getdlopenflags() + sys.setdlopenflags(flags) + yield + if os.name != 'nt': + sys.setdlopenflags(prev) + +def try_import_jit_utils_core(silent=None): + global cc + if cc: return + if not (silent is None): + prev = os.environ.get("log_silent", "0") + os.environ["log_silent"] = str(int(silent)) + try: + # if is in notebook, must log sync, and we redirect the log + if is_in_ipynb: os.environ["log_sync"] = "1" + import jit_utils_core as cc + if is_in_ipynb: + if os.name != 'nt': + # windows jupyter has import error + # disable ostream redirect + # TODO: find a better way + cc.ostream_redirect(True, True) + except Exception as _: + if int(os.environ.get("log_v", "0")) > 0: + print(_) + pass + if not (silent is None): + os.environ["log_silent"] = prev + +def run_cmd(cmd, cwd=None, err_msg=None, print_error=True): + LOG.v(f"Run cmd: {cmd}") + if cwd: + r = sp.run(cmd, cwd=cwd, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT) + else: + r = sp.run(cmd, shell=True, stdout=sp.PIPE, stderr=sp.STDOUT) + try: + s = r.stdout.decode('utf8') + except: + s = r.stdout.decode('gbk') + if r.returncode != 0: + if print_error: + sys.stderr.write(s) + if err_msg is None: + err_msg = f"Run cmd failed: {cmd}" + if not print_error: + err_msg += "\n"+s + raise Exception(err_msg) + if len(s) and s[-1] == '\n': s = s[:-1] + return s + + +def do_compile(args): + cmd, cache_path, jittor_path = args + try_import_jit_utils_core(True) + if cc: + return cc.cache_compile(cmd, cache_path, jittor_path) + else: + run_cmd(cmd) + return True + +pool_size = 0 + +def pool_cleanup(): + global p + p.__exit__(None, None, None) + del p + +def pool_initializer(): + if os.name == 'nt': + os.environ['log_silent'] = '1' + os.environ['gdb_path'] = "" + if cc is None: + try_import_jit_utils_core() + if cc: + cc.init_subprocess() + +def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"): + global pool_size, p + bk = mp.current_process()._config.get('daemon') + mp.current_process()._config['daemon'] = False + if pool_size == 0: + try: + mem_bytes = get_total_mem() + mem_gib = mem_bytes/(1024.**3) + pool_size = min(16,max(int(mem_gib // 3), 1)) + LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.") + except ValueError: + # On macOS, python with version lower than 3.9 do not support SC_PHYS_PAGES. + # Use hard coded pool size instead. + pool_size = 4 + LOG.i(f"using {pool_size} procs for compiling.") + if os.name == 'nt': + # a hack way to by pass windows + # multiprocess spawn init_main_from_path. + # check spawn.py:get_preparation_data + spec_bk = sys.modules['__main__'].__spec__ + tmp = lambda x:x + tmp.name = '__main__' + sys.modules['__main__'].__spec__ = tmp + p = Pool(pool_size, initializer=pool_initializer) + p.__enter__() + if os.name == 'nt': + sys.modules['__main__'].__spec__ = spec_bk + import atexit + atexit.register(pool_cleanup) + cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ] + try: + n = len(cmds) + dp = DelayProgress(msg, n) + for i,_ in enumerate(p.imap_unordered(do_compile, cmds)): + dp.update(i) + finally: + mp.current_process()._config['daemon'] = bk + +if os.name=='nt' and getattr(mp.current_process(), '_inheriting', False): + # when windows spawn multiprocess, disable sub-subprocess + os.environ["DISABLE_MULTIPROCESSING"] = '1' + os.environ["log_silent"] = '1' + +if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1': + os.environ["use_parallel_op_compiler"] = '0' + def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"): + cmds = [ [cmd, cache_path, jittor_path] for cmd in cmds ] + n = len(cmds) + dp = DelayProgress(msg, n) + for i,cmd in enumerate(cmds): + dp.update(i) + do_compile(cmd) + + +def download(url, filename): + if os.path.isfile(filename): + if os.path.getsize(filename) > 100: + return + LOG.v("Downloading", url) + urllib.request.urlretrieve(url, filename) + LOG.v("Download finished") + +def get_jittor_version(): + path = os.path.dirname(__file__) + with open(os.path.join(path, "../jittor/__init__.py"), "r", encoding='utf8') as fh: + for line in fh: + if line.startswith('__version__'): + version = line.split("'")[1] + break + else: + raise RuntimeError("Unable to find version string.") + return version + +def get_str_hash(s): + import hashlib + md5 = hashlib.md5() + md5.update(s.encode()) + return md5.hexdigest() + +def get_cpu_version(): + v = platform.processor() + try: + if os.name == 'nt': + import winreg + key_name = r"Hardware\Description\System\CentralProcessor\0" + field_name = "ProcessorNameString" + key = winreg.OpenKey(winreg.HKEY_LOCAL_MACHINE, key_name) + value = winreg.QueryValueEx(key, field_name)[0] + winreg.CloseKey(key) + v = value + elif platform.system() == "Darwin": + r, s = sp.getstatusoutput("sysctl -a sysctl machdep.cpu.brand_string") + if r==0: + v = s.split(":")[-1].strip() + else: + with open("/proc/cpuinfo", 'r') as f: + for l in f: + if l.startswith("model name"): + v = l.split(':')[-1].strip() + break + except: + pass + return v + +def short(s): + ss = "" + for c in s: + if str.isidentifier(c) or str.isnumeric(c) \ + or str.isalpha(c) or c in '.-+': + ss += c + if len(ss)>14: + return ss[:14]+'x'+get_str_hash(ss)[:2] + return ss + +def find_cache_path(): + path = home() + # jittor version key + jtv = "jt"+get_jittor_version().rsplit('.', 1)[0] + # cc version key + ccv = cc_type+get_version(cc_path)[1:-1] \ + if cc_type != "cl" else cc_type + # os version key + osv = platform.platform() + platform.node() + if len(osv)>14: + osv = osv[:14] + 'x'+get_str_hash(osv)[:2] + # py version + pyv = "py"+platform.python_version() + # cpu version + cpuv = get_cpu_version() + jittor_path_key = get_str_hash(__file__)[:4] + dirs = [".cache", "jittor", jtv, ccv, pyv, osv, cpuv, jittor_path_key] + dirs = list(map(short, dirs)) + cache_name = "default" + try: + if "cache_name" in os.environ: + cache_name = os.environ["cache_name"] + else: + # try to get branch name from git + r = sp.run(["git","branch"], cwd=os.path.dirname(__file__), stdout=sp.PIPE, + stderr=sp.PIPE) + assert r.returncode == 0 + bs = r.stdout.decode().splitlines() + for b in bs: + if b.startswith("* "): break + + cache_name = b[2:] + for c in " (){}": cache_name = cache_name.replace(c, "_") + except: + pass + if os.environ.get("debug")=="1": + dirs[-1] += "_debug" + for name in os.path.normpath(cache_name).split(os.path.sep): + dirs.append(name) + os.environ["cache_name"] = cache_name + LOG.v("cache_name: ", cache_name) + path = os.path.join(path, *dirs) + os.makedirs(path, exist_ok=True) + if path not in sys.path: + sys.path.append(path) + return path + +def get_version(output): + if output.endswith("mpicc"): + version = run_cmd(f"\"{output}\" --showme:version") + elif os.name == 'nt' and ( + output.endswith("cl") or output.endswith("cl.exe")): + version = run_cmd(output) + else: + version = run_cmd(f"\"{output}\" --version") + v = re.findall("[0-9]+\\.[0-9]+\\.[0-9]+", version) + if len(v) == 0: + v = re.findall("[0-9]+\\.[0-9]+", version) + assert len(v) != 0, f"Can not find version number from: {version}" + if 'clang' in version and platform.system() == 'Darwin': + version = "("+v[-3]+")" + else: + version = "("+v[-1]+")" + return version + +def get_int_version(output): + ver = get_version(output) + ver = ver[1:-1].split('.') + ver = tuple(( int(v) for v in ver )) + return ver + +def find_exe(name, check_version=True, silent=False): + output = shutil.which(name) + if not output: + raise RuntimeError(f"{name} not found") + if check_version: + version = get_version(name) + else: + version = "" + if not silent: + LOG.i(f"Found {name}{version} at {output}.") + return output + +def env_or_find(name, bname, silent=False): + if name in os.environ: + path = os.environ[name] + if path != "": + version = get_version(path) + if not silent: + LOG.i(f"Found {bname}{version} at {path}") + return path + return find_exe(bname, silent=silent) + +def env_or_try_find(name, bname): + if name in os.environ: + path = os.environ[name] + if path != "": + version = get_version(path) + LOG.i(f"Found {bname}{version} at {path}") + return path + return try_find_exe(bname) + +def try_find_exe(*args): + try: + return find_exe(*args) + except: + LOG.v(f"{args[0]} not found.") + return "" + +def get_cc_type(cc_path): + bname = os.path.basename(cc_path) + if "clang" in bname: return "clang" + if "icc" in bname or "icpc" in bname: return "icc" + if "g++" in bname: return "g++" + if "cl" in bname: return "cl" + LOG.f(f"Unknown cc type: {bname}") + +def get_py3_link_path(): + py3_link_path = os.path.join( + os.path.dirname(sys.executable), + "libs", + ) + if not os.path.exists(py3_link_path): + candidate = [os.path.dirname(sys.executable)] + sys.path + for p in candidate: + p = os.path.join(p, "libs") + if os.path.exists(p): + py3_link_path = p + break + return py3_link_path + +def get_py3_config_path(): + global _py3_config_path + if _py3_config_path: + return _py3_config_path + + if os.name == 'nt': + return None + else: + # Search python3.x-config + # Note: + # This may be called via c++ console. In that case, sys.executable will + # be a path to the executable file, rather than python. So, we cannot infer + # python-config path only from sys.executable. + # To address this issue, we add predefined paths to search, + # - Linux: /usr/bin/python3.x-config + # - macOS: + # - shiped with macOS 13: /Library/Developer/CommandLineTools/Library/Frameworks/ + # Python3.framework/Versions/3.x/lib/python3.x/config-3.x-darwin/python-config.py + # - installed via homebrew: /usr/local/bin/python3.x-config + # There may be issues under other cases, e.g., installed via conda. + py3_config_paths = [ + os.path.dirname(sys.executable) + f"/python3.{sys.version_info.minor}-config", + sys.executable + "-config", + f"/usr/bin/python3.{sys.version_info.minor}-config", + f"/usr/local/bin/python3.{sys.version_info.minor}-config", + os.path.dirname(sys.executable) + "/python3-config", + ] + if platform.system() == "Darwin": + if "homebrew" in sys.executable: + py3_config_paths.append(f'/opt/homebrew/bin/python3.{sys.version_info.minor}-config') + else: + py3_config_paths.append(f'/Library/Developer/CommandLineTools/Library/Frameworks/Python3.framework/'\ + f'Versions/3.{sys.version_info.minor}/lib/python3.{sys.version_info.minor}/'\ + f'config-3.{sys.version_info.minor}-darwin/python-config.py') + + if "python_config_path" in os.environ: + py3_config_paths.insert(0, os.environ["python_config_path"]) + + for py3_config_path in py3_config_paths: + if os.path.isfile(py3_config_path): + break + else: + raise RuntimeError(f"python3.{sys.version_info.minor}-config " + f"not found in {py3_config_paths}, please specify " + f"enviroment variable 'python_config_path'," + f" or install python3.{sys.version_info.minor}-dev") + _py3_config_path = py3_config_path + return py3_config_path + +def get_py3_include_path(): + global _py3_include_path + if _py3_include_path: + return _py3_include_path + + if os.name == 'nt': + # Windows + sys.executable = sys.executable.lower() + candidate = [os.path.dirname(sys.executable)] + sys.path + for p in candidate: + include_path = os.path.join(p, "include") + if os.path.exists(include_path): + break + else: + raise RuntimeError("Python include path not found. please report this bug to us.") + _py3_include_path = '-I"' + include_path + '"' + else: + _py3_include_path = run_cmd(get_py3_config_path()+" --includes") + + # macOS (>=13) is shiped with a fake python3-config which outputs wrong include paths + # check the include paths and fix them + if platform.system() == "Darwin": + is_real_path = False + for include_path in _py3_include_path.strip().split(): + if os.path.exists(include_path[2:]): + is_real_path = True + if not is_real_path: + _py3_include_path = f"-I/Library/Developer/CommandLineTools/Library/Frameworks/"\ + f"Python3.framework/Versions/3.{sys.version_info.minor}/Headers" + return _py3_include_path + + +def get_py3_extension_suffix(): + global _py3_extension_suffix + if _py3_extension_suffix: + return _py3_extension_suffix + + if os.name == 'nt': + # Windows + _py3_extension_suffix = f".cp3{sys.version_info.minor}-win_amd64.pyd" + else: + _py3_extension_suffix = run_cmd(get_py3_config_path()+" --extension-suffix") + return _py3_extension_suffix + +def get_total_mem(): + if os.name == 'nt': + from ctypes import Structure, c_int32, c_uint64, sizeof, byref, windll + class MemoryStatusEx(Structure): + _fields_ = [ + ('length', c_int32), + ('memoryLoad', c_int32), + ('totalPhys', c_uint64), + ('availPhys', c_uint64), + ('totalPageFile', c_uint64), + ('availPageFile', c_uint64), + ('totalVirtual', c_uint64), + ('availVirtual', c_uint64), + ('availExtendedVirtual', c_uint64)] + def __init__(self): + self.length = sizeof(self) + m = MemoryStatusEx() + assert windll.kernel32.GlobalMemoryStatusEx(byref(m)) + return m.totalPhys + else: + return os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') + +def dirty_fix_pytorch_runtime_error(): + ''' This funtion should be called before pytorch. + + Example:: + + import jittor as jt + jt.dirty_fix_pytorch_runtime_error() + import torch + ''' + import os, platform + + if platform.system() == 'Linux': + os.RTLD_GLOBAL = os.RTLD_GLOBAL | os.RTLD_DEEPBIND + import jittor_utils + with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW): + import torch + +is_in_ipynb = in_ipynb() +cc = None +LOG = Logwrapper() + +check_msvc_install = False +msvc_path = "" +if os.name == 'nt' and os.environ.get("cc_path", "")=="": + msvc_path = os.path.join(home(), ".cache", "jittor", "msvc") + cc_path = os.path.join(msvc_path, "VC", r"_\_\_\_\_\bin", "cl.exe") + check_msvc_install = True +elif platform.system() == "Darwin": + # macOS has a fake "g++" which is actually clang++, so we search clang. + cc_path = env_or_find('cc_path', 'clang++', silent=True) +else: + cc_path = env_or_find('cc_path', 'g++', silent=True) +os.environ["cc_path"] = cc_path +cc_type = get_cc_type(cc_path) +cache_path = find_cache_path() + +_py3_config_path = None +_py3_include_path = None +_py3_extension_suffix = None +try: + import ssl + ssl._create_default_https_context = ssl._create_unverified_context +except: + pass + +try: + import sys + sys.setrecursionlimit(10**6) + if os.name != 'nt': + import resource + resource.setrlimit(resource.RLIMIT_STACK, (2**29,-1)) +except: + pass + +if os.name == 'nt': + if check_msvc_install: + if not os.path.isfile(cc_path): + from jittor_utils import install_msvc + install_msvc.install(msvc_path) + mpath = os.path.join(home(), ".cache", "jittor", "msvc") + if cc_path.startswith(mpath): + msvc_path = mpath + os.RTLD_NOW = os.RTLD_GLOBAL = os.RTLD_DEEPBIND = 0 + path = os.path.dirname(cc_path).replace('/', '\\') + if path: + sys.path.insert(0, path) + os.environ["PATH"] = path+';'+os.environ["PATH"] + if hasattr(os, "add_dll_directory"): + os.add_dll_directory(path) + +backends = [] +def add_backend(mod): + backends.append(mod) + +from . import lock +@lock.lock_scope() +def compile_module(source, flags): + """ + quick c extension: + Example: + + import jittor as jt + + import jittor_utils + import jittor.compiler as compiler + + + mod = jittor_utils.compile_module(''' + #include "common.h" + namespace jittor { + // @pyjt(hello) + string hello(const string& src) { + LOGir << "hello" << src; + } + }''', compiler.cc_flags) + + mod.hello("aaa") + + """ + tmp_path = os.path.join(cache_path, "tmp") + os.makedirs(tmp_path, exist_ok=True) + hash = "hash_" + get_str_hash(source) + so = get_py3_extension_suffix() + header_name = os.path.join(tmp_path, hash+".h") + source_name = os.path.join(tmp_path, hash+".cc") + lib_name = hash+so + with open(header_name, "w", encoding="utf8") as f: + f.write(source) + from jittor.pyjt_compiler import compile_single + ok = compile_single(header_name, source_name) + assert ok, "no pyjt interface found" + + entry_src = f''' +static void init_module(PyModuleDef* mdef, PyObject* m) {{ + mdef->m_doc = "generated py jittor_utils.compile_module"; + jittor::pyjt_def_{hash}(m); +}} +PYJT_MODULE_INIT({hash}); + ''' + with open(source_name, "r", encoding="utf8") as f: + src = f.read() + with open(source_name, "w", encoding="utf8") as f: + f.write(src + entry_src) + jittor_path = os.path.join(os.path.dirname(__file__), "..", "jittor") + jittor_path = os.path.abspath(jittor_path) + from jittor.compiler import fix_cl_flags + do_compile([fix_cl_flags(f"\"{cc_path}\" \"{source_name}\" \"{jittor_path}/src/pyjt/py_arg_printer.cc\" {flags} -o \"{cache_path+'/'+lib_name}\" "), + cache_path, jittor_path]) + with lock.unlock_scope(): + try: + with import_scope(os.RTLD_GLOBAL | os.RTLD_NOW): + exec(f"import {hash}") + except Exception as e: + with import_scope(os.RTLD_GLOBAL | os.RTLD_LAZY): + exec(f"import {hash}") + + mod = locals()[hash] + return mod + +def process_jittor_source(device_type, callback): + import jittor.compiler as compiler + import shutil + djittor = device_type + "_jittor" + djittor_path = os.path.join(compiler.cache_path, djittor) + os.makedirs(djittor_path, exist_ok=True) + + for root, dir, files in os.walk(compiler.jittor_path): + root2 = root.replace(compiler.jittor_path, djittor_path) + os.makedirs(root2, exist_ok=True) + for name in files: + fname = os.path.join(root, name) + fname2 = os.path.join(root2, name) + if fname.endswith(".h") or fname.endswith(".cc") or fname.endswith(".cu"): + with open(fname, 'r', encoding="utf8") as f: + src = f.read() + src = callback(src, name, {"fname":fname, "fname2":fname2}) + with open(fname2, 'w', encoding="utf8") as f: + f.write(src) + else: + shutil.copy(fname, fname2) + compiler.cc_flags = compiler.cc_flags.replace(compiler.jittor_path, djittor_path) + f" -I\"{djittor_path}/extern/cuda/inc\" " + compiler.jittor_path = djittor_path + +import time +class time_scope: + def __init__(self, name): + self.name = name + def __enter__(self): + self.start_time = time.time() + def __exit__(self, *exc): + self.end_time = time.time() + self.execution_time = self.end_time - self.start_time + print(f"exec[{self.name}] time: {self.execution_time}s") + def __call__(self, func): + def inner(*args, **kw): + with self: + ret = func(*args, **kw) + return ret + return inner + diff --git a/python/jittor_utils/auto_diff.py b/python/jittor_utils/auto_diff.py new file mode 100644 index 00000000..112366a6 --- /dev/null +++ b/python/jittor_utils/auto_diff.py @@ -0,0 +1,397 @@ +import os +from collections import defaultdict +import pickle +import numpy as np +import jittor_utils +import jittor_utils as jit_utils +from jittor_utils import LOG +import sys + +with jittor_utils.import_scope(os.RTLD_GLOBAL | os.RTLD_NOW): + jittor_utils.try_import_jit_utils_core() + +has_error = 0 + +def convert(data): + if hasattr(data, "numpy"): + if "Var" in data.__class__.__name__: + return data.numpy() + else: + return data.detach().cpu().numpy() + if isinstance(data, tuple): + return tuple( convert(v) for v in data ) + if isinstance(data, list): + return [ convert(v) for v in data ] + if isinstance(data, np.ndarray): + return data + if isinstance(data, dict): + return {k:convert(data[k]) for k in data} + return data + +rand_hooked = False + +def hook_pt_rand(*shape, device=None): + import torch + if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], (torch.Size, tuple, list)): + shape = tuple(shape[0]) + np.random.seed(0) + res = torch.from_numpy(np.random.rand(*tuple(shape)).astype("float32")) + if device is not None: + return res.to(device) + return res + +def hook_pt_randn(*shape, device=None): + import torch + if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], (torch.Size, tuple, list)): + shape = tuple(shape[0]) + np.random.seed(0) + print(shape) + res = torch.from_numpy(np.random.randn(*tuple(shape)).astype("float32")) + if device is not None: + return res.to(device) + return res + +def hook_pt_normal(mean, std): + import torch + if hasattr(mean, 'shape'): + shape = tuple(mean.shape) + elif hasattr(std, 'shape'): + shape = tuple(std.shape) + else: + shape = (1,) + + np.random.seed(0) + return torch.from_numpy(np.random.normal(size=shape).astype("float32")).to(std.device) * std + mean + +def hook_jt_rand(shape, dtype="float32", rtype="uniform"): + import jittor + np.random.seed(0) + if rtype == "normal": + return jittor.array(np.random.normal(size=shape).astype(str(dtype))) + return jittor.array(np.random.rand(*shape).astype(str(dtype))) + +def hook_rand(): + global rand_hooked + if rand_hooked: return + rand_hooked = True + np.random.seed(0) + if "torch" in sys.modules: + LOG.i("Hook torch.rand") + torch = sys.modules["torch"] + torch.rand = hook_pt_rand + torch.normal = hook_pt_normal + torch.randn = hook_pt_randn + torch.manual_seed(0) + if "jittor" in sys.modules: + jittor = sys.modules["jittor"] + LOG.i("Hook jittor.random") + jittor.random = hook_jt_rand + jittor.seed(0) + + +class Hook: + def __init__(self, base_name, rtol=5e-2, atol=1e-3): + if os.environ.get("use_auto_diff", '1') == '0': + return + hook_rand() + self.rid = 0 + self.base_name = base_name + self.base_path = os.path.join(jit_utils.home(), ".cache", "jittor", "auto_diff", base_name) + if not os.path.exists(self.base_path): + os.makedirs(self.base_path, exist_ok=True) + self.mode = 'save' + else: + self.mode = 'check' + + self.record_status = defaultdict(int) + self.rtol = rtol + self.atol = atol + self.param_name_map = {} + self.hooked_models = {} + LOG.i(f"Jittor AutoDiff: [{self.mode}] mode") + LOG.i("Use cache path:", self.base_path) + LOG.i(f"rtol:{rtol} atol:{atol}") + + def registe_param_name(self, p, name): + self.param_name_map[id(p)] = name + + def get_param_name(self, p): + if id(p) not in self.param_name_map: + LOG.w("Param name not found", p.shape, id(p)) + return "noname"+str(list(p.shape)) + return self.param_name_map[id(p)] + + def check_array(self, name, a, b): + rtol = self.rtol + atol = self.atol + global has_error + err = np.abs(a-b) + tol = atol + rtol * np.abs(b) + is_error = np.logical_or( err > tol, (a>=-1e-5)!=(b>=-1e-5)) + index = np.where(is_error) + assert len(index)>0 + if len(index[0]) == 0: + return + + has_error += 1 + LOG.w(f"Ndarray <{name}> not match, shape:{a.shape}") + i = tuple( i[0] for i in index ) + err_rate = is_error.mean() + LOG.w(f"error index at [{i}], a({a[i]}) b({b[i]}) err({err[i]}) > tol({tol[i]}), err_rate:{err_rate*100:.3f}% amean({a.mean()}) bmean({b.mean()}) astd({a.std()}) bstd({b.std()}) ") + if err_rate > 0.01: + LOG.e("!"*10+"Very HIGH err rate"+"!"*10) + + def check(self, name, pre_data, data): + global has_error + if pre_data is None and isinstance(data, np.ndarray): + if (data==0).all(): + LOG.i(f"name {name} is None") + else: + LOG.e(f"name {name} is non-zero") + return + if type(pre_data) != type(data): + LOG.e(f"type not match, {pre_data.__class__.__name__}!={data.__class__.__name__}, name: {name}") + has_error += 1 + return + if isinstance(pre_data, (list, tuple)): + if len(pre_data) != len(data): + has_error += 1 + LOG.e(f"Name <{name}> len not match, {len(pre_data)} != {len(data)}") + n = max(len(pre_data), len(data)) + for i in range(n): + a = pre_data[i] if i not match {pre_data.shape} != {data.shape}") + return + self.check_array(name, pre_data, data) + elif isinstance(pre_data, dict): + if len(pre_data) != len(data): + has_error += 1 + LOG.w(f"Dict Name <{name}> len not match, {len(pre_data)} != {len(data)}") + for k in pre_data: + pv = pre_data[k] + if k not in data: + has_error += 1 + msg = f"Key <{k}> not in data, Name <{name}>" + if isinstance(pv, np.ndarray): + LOG.e(msg) + else: + LOG.w(msg) + continue + self.check(name+f".{k}", pre_data[k], data[k]) + else: + if pre_data != data: + has_error += 1 + LOG.e(f"Type: {type(pre_data).__name__} Name <{name}> not match {pre_data} != {data}") + + def record(self, name, data, ex_name=""): + if os.environ.get("use_auto_diff", '1') == '0': + return + self.record_status[name] += 1 + fpath = os.path.join(self.base_path, f"{name}-{self.record_status[name]}.pkl") + data = convert(data) + self.rid += 1 + + if self.mode == 'check': + if os.path.isfile(fpath): + with open(fpath, 'rb') as f: + pre_name, pre_data = pickle.load(f) + LOG.i(f"check {self.rid}:<{ex_name}{name}> ...") + self.check(ex_name+name, pre_data, data) + else: + global has_error + has_error += 1 + LOG.e(f"No previous result found: {name}") + return + else: + with open(fpath, 'wb') as f: + pickle.dump((name, data), f) + LOG.i(f"save {self.rid}:<{name}> ok") + + def record_params(self, parameters_dict, mod_name=""): + if os.environ.get("use_auto_diff", '1') == '0': + return + global has_error + pps = {} + for k, v in parameters_dict.items(): + if k.endswith("num_batches_tracked"): + continue + pps[k] = v + ps = { name:convert(param) for name, param in pps.items() } + rec_name = f"{mod_name}_params" + rec_name = f"{rec_name}-{self.record_status[rec_name]}" + self.record_status[rec_name] += 1 + fpath = os.path.join(self.base_path, rec_name+".pkl") + + if self.mode == 'check': + with open(fpath, 'rb') as f: + prev_ps = pickle.load(f) + if len(prev_ps) != len(ps): + has_error += 1 + LOG.e(f"Params len not match {len(prev_ps)} != {len(ps)}") + for k in ps: + a = ps[k] + if k not in prev_ps: + has_error += 1 + LOG.e(f"prev param <{k}> not found.") + continue + b = prev_ps[k] + if a.shape != b.shape: + has_error += 1 + LOG.e(f"Params <{k}> shape not match {a.shape} != {b.shape}") + continue + std_a, mean_a = a.std(), a.mean() + std_b, mean_b = b.std(), b.mean() + n = a.size + # law of large number + std_mean_a = (std_a+std_b)/2 / np.sqrt(n) + 1e-6 + std_std_a = (std_a+std_b)/2 / np.sqrt((n-1)/2) + 1e-6 + x = 4 + if np.abs(mean_a - mean_b) > x * std_mean_a: + has_error += 1 + LOG.e(f"param mean not match, mean_a:{mean_a}, mean_b:{mean_b}, acceptable range:({mean_a - x * std_mean_a}, {mean_a + x * std_mean_a}) name:{k} shape:{a.shape}") + elif np.abs(std_a - std_b) > x * std_std_a: + has_error += 1 + LOG.e(f"param std not match, std_a:{std_a}, std_b:{std_b}, acceptable range:({std_a - x * std_std_a}, {std_a + x * std_std_a}) name:{k} shape:{a.shape}") + else: + LOG.i(f"check param ok: <{k}> shape:{a.shape}") + var = pps[k] + if hasattr(var, "copy_"): + import torch + var.data.copy_(torch.from_numpy(b)) + else: + var.assign(b) + else: + with open(fpath, 'wb') as f: + pickle.dump(ps, f) + LOG.i(f"save params ok") + + def hook_function(self, func): + name = func.__name__ + def new_func(*args, **kw): + ret = func(*args, **kw) + self.record(name+".args", args) + self.record(name+".kw", kw) + self.record(name+".ret", ret) + return ret + return new_func + + def hook_module(self, mod, mod_name=""): + if os.environ.get("use_auto_diff", '1') == '0': + return + if mod_name != "": + mod_name = "<" + mod_name + ">" + self.hooked_models[mod_name] = mod + def forward_hook(self2, input, output, kw=None): + ex_name = '[' + self2.__class__.__name__ + ']' + if "relu" not in self2.__class__.__name__.lower(): + # not test relu, because input may be inplaced + self.record(self2.__ad_mod_name__+".input", input, ex_name) + self.record(self2.__ad_mod_name__+".output", output, ex_name) + if kw is not None: + self.record(self2.__ad_mod_name__+".kw", kw, ex_name) + + names = [] + for name, module in mod.named_modules(): + ns = name.split('.') + skip = 0 + for n in ns: + if n.startswith('_'): + skip = 1 + if skip: + LOG.i("skip", name) + continue + name = mod_name + name + module.__ad_mod_name__ = name + names.append(name) + module.register_forward_hook(forward_hook) + mod_class_name = module.__class__.__name__.lower() + # make dropout in eval mod and record dropout.p + if "dropout" in mod_class_name: + self.record(name+'.p', module.p, "["+mod_class_name+"]") + module.eval() + ps = { mod_name+k:v for k, v in mod.state_dict().items() } + self.record_params(ps, mod_name) + self.record("module names", names) + + def hook_optimizer(self, opt, opt_name=""): + ''' + net = Model() + opt = optim.SGD(net.parameters(), 0.1) + hook.hook_optimizer(opt) + ''' + if os.environ.get("use_auto_diff", '1') == '0': + return + origin_step = opt.step + ex_name = '['+opt.__class__.__name__+']' + def step_hook(*args, **kw): + origin_step(*args, **kw) + for mname, mod in self.hooked_models.items(): + for pname, p in mod.named_parameters(): + self.registe_param_name(p, pname) + self.record(opt_name+".default", opt.defaults, ex_name) + gid = 0 + n_params = 0 + for pg in opt.param_groups: + for p in pg["params"]: + if hasattr(p, "is_stop_grad"): + if p.is_stop_grad(): + continue + n_params += 1 + else: + n_params += 1 + + self.record(opt_name+".n_params", n_params, ex_name) + + for pg in opt.param_groups: + for i, p in reversed(list(enumerate(pg["params"]))): + if hasattr(p, "is_stop_grad"): + if p.is_stop_grad(): + continue + grad = pg["grads"][i] + else: + grad = p.grad + pname = self.get_param_name(p) + self.record(pname+".grad", grad, f"<{opt_name}.grads[{gid}]>") + self.record(pname, p, f"<{opt_name}.params[{gid}]>") + gid += 1 + opt.step = step_hook + + def save_input(self, *data): + ''' + for input, label in torch_dataloader: + hook.save_input(data) + ''' + if self.mode == "save": + self.record_status["[input]"] += 1 + fpath = os.path.join(self.base_path, f"__input-{self.record_status['[input]']}.pkl") + with open(fpath, 'wb') as f: + pickle.dump(convert(data), f) + LOG.i(f"save input: ok") + else: + raise RuntimeError("save_input is invalid in [check] mode") + + def load_input(self): + ''' + for fake_input, fake_label in jittor_dataset: + input, label = hook.load_input() + input = jt.array(input) + label = jt.array(label) + ''' + if self.mode == "check": + self.record_status["[input]"] += 1 + fpath = os.path.join(self.base_path, f"__input-{self.record_status['[input]']}.pkl") + with open(fpath, 'rb') as f: + data = pickle.load(f) + LOG.i(f"load input: ok") + return data + else: + raise RuntimeError("load_input is invalid in [save] mode") diff --git a/python/jittor_utils/auto_git_tag.py b/python/jittor_utils/auto_git_tag.py new file mode 100644 index 00000000..821287c8 --- /dev/null +++ b/python/jittor_utils/auto_git_tag.py @@ -0,0 +1,40 @@ +import subprocess as sp +import os + +fdir = os.path.dirname(__file__) +logs = sp.getoutput(f"cd {fdir} && git log -p -- ../jittor/__init__.py ") +# print(logs) + +lines = logs.splitlines() + +prev_commit = -1 +for i in range(len(lines)): + line = lines[i] + if line.startswith("+__version__"): + version = line.split('\'')[1] + commit = None + date = None + msg = [] + for j in range(i,prev_commit,-1): + if lines[j].startswith("Date:"): + msg.append(lines[j+2]) + for j in range(i,prev_commit,-1): + if lines[j].startswith("commit "): + commit = lines[j].split()[1] + prev_commit = j + 3 + date = lines[j+2] + break + assert commit, version + print(version, commit) + msg = msg[::-1] + cnt = len(msg) + msg = "\n".join(msg) + msg = f"Version {version}\n"+date+f"\nTotal {cnt} commits:\n"+msg + print(msg) + cmd = f"git tag {version} {commit} -m \"{msg}\"" + print(cmd) + ret = sp.getoutput(f"cd {fdir} && {cmd}") + print(ret) + ret = sp.getoutput(f"cd {fdir} && bash ./github_release.sh {version} \"version {version}\"""") + print(ret) + # break \ No newline at end of file diff --git a/python/jittor_utils/class/motd b/python/jittor_utils/class/motd new file mode 100644 index 00000000..75bb8568 --- /dev/null +++ b/python/jittor_utils/class/motd @@ -0,0 +1,20 @@ +★★★★★★★★★★★★★★★★★★★★★ +Welcome to use Jittor +Please put the file under /root directory +★★★★★★★★★★★★★★★★★★★★★ +欢迎使用Jittor +请将文件放置在/root目录下 +本docker已经安装好cuda环境 +相关链接: +* [Jittor官网](https://cg.cs.tsinghua.edu.cn/jittor/) +* [Jittor教程](https://cg.cs.tsinghua.edu.cn/jittor/tutorial/) +* [Jittor模型库](https://cg.cs.tsinghua.edu.cn/jittor/resources/) +* [Jittor文档](https://cg.cs.tsinghua.edu.cn/jittor/assets/docs/index.html) +* [Github](https://github.com/jittor/jittor), [Gitee](https://gitee.com/jittor/jittor) +* [Jittor 论坛](https://discuss.jittor.org/) +* 即时通信: QQ Group(761222083) + +欢迎大家star,fork并在QQ群或者论坛向我们提出宝贵的意见和建议。 + +注意:请不要开启无密码保护的jupyter notebook或vscode server +★★★★★★★★★★★★★★★★★★★★★ diff --git a/python/jittor_utils/class/setup.py b/python/jittor_utils/class/setup.py new file mode 100644 index 00000000..5c575f12 --- /dev/null +++ b/python/jittor_utils/class/setup.py @@ -0,0 +1,16 @@ +import sys +import os +command = sys.argv[1] +if (command == 'ssh'): + port = sys.argv[2] + data = open("/etc/ssh/sshd_config", "r").readlines() + data[12] = 'Port ' + port + '\nPermitRootLogin yes\n' + f = open("/etc/ssh/sshd_config", "w") + f.writelines(data) + f.close() + os.system("service ssh restart") +elif (command == 'passwd'): + passwd = sys.argv[2] + os.system("echo root:"+passwd+" | chpasswd") +else: + print('command error') diff --git a/python/jittor_utils/class/setup_env.py b/python/jittor_utils/class/setup_env.py new file mode 100644 index 00000000..f3503e2b --- /dev/null +++ b/python/jittor_utils/class/setup_env.py @@ -0,0 +1,152 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Guoye Yang <498731903@qq.com> +# Dun Liang . +# +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +''' +example: + +export class_home=/mnt/disk/cjld/class_nn +mkdir -p $class_home +docker pull jittor/jittor-cuda +python3.7 -m jittor_utils.class.setup_env setup 4 +python3.7 -m jittor_utils.class.setup_env start 4 +python3.7 -m jittor_utils.class.setup_env report +python3.7 -m jittor_utils.class.setup_env restart 4 +python3.7 -m jittor_utils.class.setup_env stop +''' +# export class_home +# setup [n] // setup for n users. including build user paths, user_info.txt and docker imgs. !!!WILL RESET SUDENT_FILES!!! +# start [n_gpu] // run n docker CONTAINERs with n_gpu GPUs. +# stop // stop n docker CONTAINERs +# restart [n_gpu] // restart n docker CONTAINERs with n_gpu GPUs. +import sys +import os +import json as js +import random + +class_home = os.environ["class_home"] +student_files_dir = class_home + "/student_files" +student_files_bk_dir = class_home + "/student_files_bak" +cwd = os.path.dirname(__file__) + +def run_cmd(cmd): + print("[CMD]:", cmd) + ret = os.system(cmd) + if ret: + print("[CMD] return", ret) + return ret + +def generate_random_str(randomlength): + random_str = '' + base_str = 'ABCDEFGHIGKLMNOPQRSTUVWXYZabcdefghigklmnopqrstuvwxyz0123456789' + length = len(base_str) - 1 + for i in range(randomlength): + random_str += base_str[random.randint(0, length)] + return random_str + +def setup(n): + if os.path.exists(student_files_dir): + if os.path.exists(student_files_bk_dir): + run_cmd(f"rm -rf {student_files_bk_dir}") + run_cmd(f"mv {student_files_dir} {student_files_bk_dir}") + os.makedirs(student_files_dir) + user_info = [] + for i in range(n): # 0 for root + port = 20000 + i + passwd = generate_random_str(8) + name = 'stu_'+str(i) + path = os.path.abspath(os.path.join(student_files_dir, name)) + info = {'port': port, + 'passwd': passwd, + 'name': name, + 'path': path} + user_info.append(info) + student_files_src = class_home + "/student_files_src" + if os.path.isdir(student_files_src): + run_cmd(f"cp -r {student_files_src} {path}") + else: + run_cmd('mkdir -p ' + path) + js.dump(user_info, open(student_files_dir+"/user_info.json", "w")) + +def start(n, n_gpu): + assert os.path.exists(student_files_dir+'/user_info.json') + user_info = js.load(open(student_files_dir+'/user_info.json', 'r')) + for i in range(len(user_info)): + id = i % n + ids = '' + for j in range(n_gpu): + if j > 0: + ids+=',' + ids += str((i * n_gpu + j) % n) + u = user_info[i] + print('START', i, '/', len(user_info)) + assert 0 == run_cmd(f'docker run -itd --shm-size=8g --network host --name {u["name"]} -v {u["path"]}:/root --gpus \'"device={ids}"\' jittor/jittor-cuda bash') + # assert 0 == run_cmd(f'docker exec -it {u["name"]} bash -c \'apt update && apt install openssh-server -y\'') + assert 0 == run_cmd(f'docker cp {cwd}/setup.py {u["name"]}:/etc/ssh/setup.py') + assert 0 == run_cmd(f'docker cp {cwd}/motd {u["name"]}:/etc/motd') + assert 0 == run_cmd(f'docker exec -it {u["name"]} python3.7 /etc/ssh/setup.py passwd {u["passwd"]}') + assert 0 == run_cmd(f'docker exec -it {u["name"]} python3.7 /etc/ssh/setup.py ssh {u["port"]}') + assert 0 == run_cmd(f'docker exec -it {u["name"]} python3.7 -m pip install jittor -U') + assert 0 == run_cmd(f'docker exec -it {u["name"]} python3.7 -m jittor.test.test_example') + +def stop(): + assert os.path.exists(student_files_dir+'/user_info.json') + user_info = js.load(open(student_files_dir+'/user_info.json', 'r')) + for i in range(len(user_info)): + u = user_info[i] + print('STOP', i, '/', len(user_info)) + run_cmd(f'docker rm -f {u["name"]}') + +def report(): + assert os.path.exists(student_files_dir+'/user_info.json') + user_info = js.load(open(student_files_dir+'/user_info.json', 'r')) + hostname = open("/etc/hostname", 'r').read().strip() + ".randonl.me" + for i in range(len(user_info)): + u = user_info[i] + print(f"ssh -p {u['port']} root@{hostname} # passwd: {u['passwd']}") + +def restart(n, n_gpu): + stop() + start(n, n_gpu) + +args = sys.argv[1:] +if (args[0] == 'setup'): + assert(len(args) == 2) + assert(type(eval(args[1])) == int) + n = int(args[1]) + assert(n < 999) + setup(n) +elif (args[0] == 'start'): + assert(len(args) in [2,3]) + assert(type(eval(args[1])) == int) + n = int(args[1]) + if len(args) == 3: + assert(type(eval(args[2])) == int) + n_gpu = int(args[2]) + else: + n_gpu=1 + start(n, n_gpu) +elif (args[0] == 'stop'): + stop() +elif (args[0] == 'restart'): + assert(len(args) in [2,3]) + assert(type(eval(args[1])) == int) + n = int(args[1]) + if len(args) == 3: + assert(type(eval(args[2])) == int) + n_gpu = int(args[2]) + else: + n_gpu=1 + restart(n, n_gpu) +elif (args[0] == 'report'): + report() +else: + assert(False) + diff --git a/python/jittor_utils/clean_cache.py b/python/jittor_utils/clean_cache.py new file mode 100644 index 00000000..bbe028e5 --- /dev/null +++ b/python/jittor_utils/clean_cache.py @@ -0,0 +1,58 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: Dun Liang . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os, sys, shutil +import glob +import jittor_utils as jit_utils + +cache_path = os.path.join(jit_utils.home(), ".cache", "jittor") + +def callback(func, path, exc_info): + print(f"remove \"{path}\" failed.") + +def rmtree(path): + if os.path.isdir(path): + print(f"remove \"{path}\" recursive.") + shutil.rmtree(path, onerror=callback) + +def clean_all(): + rmtree(cache_path) + +def clean_core(): + rmtree(cache_path+"/default") + rmtree(cache_path+"/master") + fs = glob.glob(cache_path+"/jt*") + for f in fs: rmtree(f) + +def clean_cuda(): + rmtree(cache_path+"/jtcuda") + rmtree(cache_path+"/cutt") + rmtree(cache_path+"/cub") + rmtree(cache_path+"/nccl") + +def clean_dataset(): + rmtree(cache_path+"/dataset") + +def clean_swap(): + rmtree(cache_path+"/tmp") + +def print_help(): + msg = "|".join(keys) + print(f"Usage: {sys.executable} -m jittor_utils.clean_cache [{msg}]") + exit() + + +keys = [ k[6:] for k in globals() if k.startswith("clean_") ] + +if __name__ == "__main__": + if len(sys.argv)==1: + print_help() + else: + for k in sys.argv[1:]: + if k not in keys: + print_help() + func = globals()["clean_"+k] + func() \ No newline at end of file diff --git a/python/jittor_utils/config.py b/python/jittor_utils/config.py new file mode 100644 index 00000000..7b13dc37 --- /dev/null +++ b/python/jittor_utils/config.py @@ -0,0 +1,114 @@ +import os +import platform +import sys +import jittor_utils +from jittor_utils import LOG + + +def search_file(dirs, name): + for d in dirs: + fname = os.path.join(d, name) + if os.path.isfile(fname): + return fname + LOG.f(f"file {name} not found in {dirs}") + +if __name__ == "__main__": + help_msg = f"Usage: {sys.executable} -m jittor_utils.config --include-flags|--link-flags|--cxx-flags|--cxx-example|--help" + if len(sys.argv) <= 1: + print(help_msg) + sys.exit(1) + + s = "" + # base should be something like python3.7m python3.8 + base = jittor_utils.get_py3_include_path().split()[0] + base = "python3" + base.split("python3")[-1] + for arg in sys.argv[1:]: + if arg == "--include-flags": + s += jittor_utils.get_py3_include_path() + s += " -I"+os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "jittor", "src")) + s += " " + elif arg == "--libs-flags": + libext = { + 'Linux': 'so', + 'Darwin': 'dylib', + 'Windows': 'DLL', + }[platform.system()] + ldflags = jittor_utils.run_cmd(jittor_utils.get_py3_config_path() + " --ldflags") + libpaths = [l[2:] for l in ldflags.split(' ') if l.startswith("-L")] + for libbase in libpaths: + libpath = os.path.join(libbase, f"lib{base}.{libext}") + if os.path.isfile(libpath): + s += f" -L{libbase} -l{base} -ldl " + break + else: + raise RuntimeError("Python dynamic library not found") + if os.name == 'nt': + s = s.replace('-ldl', '') + elif arg == "--cxx-flags": + s += " --std=c++17 -fPIC " + elif arg == "--cxx-example": + cc_src = ''' +// please compile with: g++ a.cc $(python3 -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o a.out && ./a.out +#include +#include + +using namespace std; + +int main() { + jittor::Console console; + // run python code in console + console.run("print('hello jt console', flush=True)"); + + // set a python value: a = 1 + console.set("a", 1); + // get a python value + cout << console.get("a") << endl; + + // set a python string + console.set("b", "hello"); + cout << console.get("b") << endl; + + // set a python array + vector x{1,2,3,4}; + console.set("x", x); + auto x2 = console.get>("x"); + for (auto a : x2) cout << a << " "; cout << endl; + + // set and get a jittor array + jittor::array arr2({2,3}, {6,5,4,3,2,1}); + arr2(0,0) = -1; + console.set_array("arr2", arr2); + console.run("print(arr2, flush=True); arr3 = arr2**2;"); + auto arr3 = console.get_array("arr3"); + cout << arr3.shape[0] << ' ' << arr3.shape[1] << endl; + for (int i=0; i input({2, 3, 224, 224}); + memset(input.data.get(), 0, input.nbyte()); + console.set_array("input", input); + console.run(R"( +import jittor as jt +from jittor.models import resnet + +model = resnet.resnet18() +pred = model(input) + )"); + auto pred = console.get_array("pred"); + cout << "pred.shape " << pred.shape[0] << ' ' << pred.shape[1] << endl; + + return 0; +} + ''' + print(cc_src) + elif arg == "--help": + print(help_msg) + sys.exit(0) + else: + print(help_msg) + sys.exit(1) + print(s) diff --git a/python/jittor_utils/github_release.sh b/python/jittor_utils/github_release.sh new file mode 100644 index 00000000..f19e1cef --- /dev/null +++ b/python/jittor_utils/github_release.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +version=$1 +text=$2 +branch=$(git rev-parse --abbrev-ref HEAD) +# repo_full_name=$(git config --get remote.origin.url | sed 's/.*:\/\/github.com\///;s/.git$//') +repo_full_name=$(git config --get remote.origin.url | sed 's/.*github.com://;s/.git$//') +token=$(git config --global github.token) + +generate_post_data() +{ + cat <. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +import sys +import subprocess as sp +import jittor_utils as jit_utils +from jittor_utils import LOG +from jittor_utils.misc import download_url_to_local +import pathlib + +def get_cuda_driver_win(): + try: + import ctypes + cuda_driver = ctypes.CDLL(r"nvcuda") + driver_version = ctypes.c_int() + r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version)) + if r != 0: return None + v = driver_version.value + return [v//1000, v%1000//10, v%10] + except: + return None + +def get_cuda_driver(): + if os.name == 'nt': + return get_cuda_driver_win() + ret, out = sp.getstatusoutput("nvidia-smi -q -u") + if ret != 0: return None + try: + out = out.lower() + out = out.split('cuda version')[1] \ + .split(':')[1] \ + .splitlines()[0] \ + .strip() + out = [ int(s) for s in out.split('.')] + return out + except: + return None + +def has_installation(): + jtcuda_path = os.path.join(jit_utils.home(), ".cache", "jittor", "jtcuda") + return os.path.isdir(jtcuda_path) + +def check_cuda_env(): + if not has_installation(): + return + if os.name == "nt": + return + def fix_env(key): + env = os.environ.get(key, "") + env = env.replace(";",":").split(":") + new_env = [] + changed = False + for cp in env: + x = cp.lower() + if "cuda" in x and "jtcuda" not in x: + changed = True + continue + if "jtcuda" in x: + new_env.insert(0, x) + else: + new_env.append(x) + os.environ[key] = ":".join(new_env) + return changed + changed = fix_env("PATH") \ + + fix_env("LD_LIBRARY_PATH") \ + + fix_env("CUDA_HOME") + if changed: + try: + # LD_LIBRARY_PATH change must triggle restart + # because dyloader already setup + # with open("/proc/self/maps", "r") as f: + # cudart_loaded = "cudart" in f.read().lower() + # if cudart_loaded: + with open("/proc/self/cmdline", "r") as f: + argv = f.read().split("\x00") + if len(argv[-1]) == 0: del argv[-1] + LOG.i(f"restart {sys.executable} {argv[1:]}") + os.execl(sys.executable, sys.executable, *argv[1:]) + except: + pass + + +def install_cuda(): + if "nvcc_path" in os.environ and os.environ["nvcc_path"] == "": + return None + cuda_driver_version = get_cuda_driver() + if not cuda_driver_version: + return None + LOG.i("cuda_driver_version: ", cuda_driver_version) + if "JTCUDA_VERSION" in os.environ: + cuda_driver_version = list(map(int,os.environ["JTCUDA_VERSION"].split("."))) + LOG.i("JTCUDA_VERSION: ", cuda_driver_version) + + if os.name == 'nt': + # TODO: cuda11.4 has bug fit with + # current msvc, FIXME + # if cuda_driver_version >= [11,4]: + # cuda_tgz = "cuda11.4_cudnn8_win.zip" + # md5 = "06eed370d0d44bb2cc57809343911187" + if cuda_driver_version >= [11,2]: + cuda_tgz = "cuda11.2_cudnn8_win.zip" + md5 = "b5543822c21bc460c1a414af47754556" + elif cuda_driver_version >= [11,]: + cuda_tgz = "cuda11.0_cudnn8_win.zip" + md5 = "7a248df76ee5e79623236b0560f8d1fd" + elif cuda_driver_version >= [10,]: + cuda_tgz = "cuda10.2_cudnn7_win.zip" + md5 = "7dd9963833a91371299a2ba58779dd71" + else: + LOG.w(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.2") + return None + else: + if cuda_driver_version >= [12,2]: + cuda_tgz = "cuda12.2_cudnn8_linux.tgz" + md5 = "7afda9332a268f29354488f13b489f53" + elif cuda_driver_version >= [11,2]: + cuda_tgz = "cuda11.2_cudnn8_linux.tgz" + md5 = "b93a1a5d19098e93450ee080509e9836" + elif cuda_driver_version >= [11,]: + cuda_tgz = "cuda11.0_cudnn8_linux.tgz" + md5 = "5dbdb43e35b4db8249027997720bf1ca" + elif cuda_driver_version >= [10,2]: + cuda_tgz = "cuda10.2_cudnn7_linux.tgz" + md5 = "40f0563e8eb176f53e55943f6d212ad7" + elif cuda_driver_version >= [10,]: + cuda_tgz = "cuda10.0_cudnn7_linux.tgz" + md5 = "f16d3ff63f081031d21faec3ec8b7dac" + else: + LOG.w(f"Unsupport cuda driver version: {cuda_driver_version}, at least 10.0") + return None + jtcuda_path = os.path.join(jit_utils.home(), ".cache", "jittor", "jtcuda") + nvcc_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "bin", "nvcc") + if os.name=='nt': nvcc_path += '.exe' + nvcc_lib_path = os.path.join(jtcuda_path, cuda_tgz[:-4], "lib64") + sys.path.append(nvcc_lib_path) + new_ld_path = os.environ.get("LD_LIBRARY_PATH", "") + ":" + nvcc_lib_path + os.environ["LD_LIBRARY_PATH"] = new_ld_path + check_cuda_env() + + if os.path.isfile(nvcc_path): + return nvcc_path + + os.makedirs(jtcuda_path, exist_ok=True) + cuda_tgz_path = os.path.join(jtcuda_path, cuda_tgz) + download_url_to_local("https://cg.cs.tsinghua.edu.cn/jittor/assets/"+cuda_tgz, cuda_tgz, jtcuda_path, md5) + + + if cuda_tgz.endswith(".zip"): + import zipfile + zf = zipfile.ZipFile(cuda_tgz_path) + zf.extractall(path=cuda_tgz_path[:-4]) + else: + import tarfile + with tarfile.open(cuda_tgz_path, "r") as tar: + tar.extractall(cuda_tgz_path[:-4]) + + assert os.path.isfile(nvcc_path), nvcc_path + return nvcc_path + + +if __name__ == "__main__": + nvcc_path = install_cuda() + LOG.i("nvcc is installed at ", nvcc_path) diff --git a/python/jittor_utils/install_msvc.py b/python/jittor_utils/install_msvc.py new file mode 100644 index 00000000..8203ca74 --- /dev/null +++ b/python/jittor_utils/install_msvc.py @@ -0,0 +1,16 @@ +import os +import sys +from jittor_utils.misc import download_url_to_local +from jittor_utils import LOG + + +def install(path): + LOG.i("Installing MSVC...") + filename = "msvc.zip" + url = "https://cg.cs.tsinghua.edu.cn/jittor/assets/" + filename + md5sum = "55f0c175fdf1419b124e0fc498b659d2" + download_url_to_local(url, filename, path, md5sum) + fullname = os.path.join(path, filename) + import zipfile + with zipfile.ZipFile(fullname, "r") as f: + f.extractall(path) diff --git a/python/jittor_utils/load_pytorch.py b/python/jittor_utils/load_pytorch.py new file mode 100644 index 00000000..f27045a6 --- /dev/null +++ b/python/jittor_utils/load_pytorch.py @@ -0,0 +1,318 @@ +import pickle +import os +import io +import shutil +from zipfile import ZipFile +import jittor as jt +import numpy as np +from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO, List + +loaded_storages = {} +deserialized_objects = {} + +def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: + if isinstance(bytes_str, bytes): + return bytes_str.decode('ascii') + return bytes_str + +def load_tensor(contents, dtype, numel, key, location): + if dtype == np.uint16: dtype = "bfloat16" + name = os.path.join(prefix, "data", str(key)) + name = name.replace("\\", "/") + loaded_storages[key] = contents.read_var(name, dtype) + +def get_dtype_size(dtype): + return jt.NanoString(dtype).dsize() + +def persistent_load(saved_id): + global contents + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + assert typename == 'storage', \ + f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + storage_type, key, location, numel = data + dtype = storage_type.dtype + if key not in loaded_storages: + nbytes = numel + load_tensor(contents, dtype, nbytes, key, _maybe_decode_ascii(location)) + return loaded_storages[key] + +def _dtype_to_storage_type_map(): + return { + np.float16: 'HalfStorage', + # just fake np.uint16 as bfloat16 + np.uint16: 'BFloat16Storage', + np.float32: 'FloatStorage', + np.float64: 'DoubleStorage', + np.int64: 'LongStorage', + np.int32: 'IntStorage', + np.int16: 'ShortStorage', + np.int8: 'CharStorage', + np.bool_: 'BoolStorage' + } + +def _storage_type_to_dtype_map(): + dtype_map = { + val: key for key, val in _dtype_to_storage_type_map().items()} + return dtype_map + +def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): + try: + return _storage_type_to_dtype_map()[pickle_storage_type] + except KeyError: + raise KeyError( + f'pickle storage type "{pickle_storage_type}" is not recognized') + +class StorageType(): + def __init__(self, name): + self.dtype = _get_dtype_from_pickle_storage_type(name) + + def __str__(self): + return f'StorageType(dtype={self.dtype})' + +def jittor_rebuild(storage, storage_offset, size, stride, requires_grad, backward_hooks): + if len(size) == 0: + return jt.array(storage) + record_size = np.prod(size) + expect_stride = [1] + for i in range(len(size)-1, 0, -1): + expect_stride.append(expect_stride[-1]*size[i]) + expect_stride = tuple(expect_stride[::-1]) + if stride is not None and stride != expect_stride: + if len(stride) > 1: # reshape the memory layout based on stride + eval_list = [] + for idx in range(len(stride)): + eval_list.append(f"@e0({idx}) * i{idx}") + evals = "+".join(eval_list) + return jt.array(storage[storage_offset:storage_offset+record_size]).reindex(size, [evals], extras=[jt.array(stride)]) + return jt.array(storage[storage_offset:storage_offset+record_size]).reshape(size) + +def jittor_rebuild_var(data, requires_grad, backward_hooks): + v = jt.array(data) + v.requires_grad = requires_grad + return v + +class UnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if mod_name.startswith("transformers"): + return super().find_class("collections", "OrderedDict") + if type(name) is str and 'Storage' in name: + try: + return StorageType(name) + except KeyError: + pass + if type(name) is str and '_rebuild_tensor_v2' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild") + if type(name) is str and '_rebuild_parameter' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_var") + + return super().find_class(mod_name, name) + +class ArrayWrapper: + def __init__(self, storage, stride=None, size=None, requires_grad=None): + self.requires_grad = requires_grad + self.size = size + self.storage = storage + self.stride = stride + + def __str__(self): + return self.storage.__str__() + +def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks): + if len(size) == 0: + return ArrayWrapper(storage, stride=stride, size=size) + storage.reshape(size) + return ArrayWrapper(storage, stride=stride, size=size) + +def jittor_rebuild_var_direct(data, requires_grad, backward_hooks): + v = ArrayWrapper(storage, requires_grad=requires_grad) + return v + +def jittor_rebuild_direct_v0(storage, storage_offset, size, stride): + if len(size) == 0: + return ArrayWrapper(storage, stride=stride, size=size) + storage.reshape(size) + return ArrayWrapper(storage, stride=stride, size=size) + +class DirectUnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if mod_name.startswith("transformers"): + return super().find_class("collections", "OrderedDict") + + if type(name) is str and 'Storage' in name: + try: + return StorageType(name) + except KeyError: + print("wrong type: ", name) + pass + if type(name) is str and '_rebuild_tensor_v2' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_direct") + elif type(name) is str and '_rebuild_tensor' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_direct_v0") + elif type(name) is str and '_rebuild_parameter' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_var_direct") + return super().find_class(mod_name, name) + +def _check_seekable(f) -> bool: + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = (str(e) + ". You can only load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead.") + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + return False + +def extract_zip(input_zip): + input_zip = ZipFile(input_zip) + return {name: input_zip.read(name) for name in input_zip.namelist()} + +def _is_compressed_file(f): + compress_modules = ['gzip'] + try: + return f.__module__ in compress_modules + except AttributeError: + return False + +def _should_read_directly(f): + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + +def persistent_load_direct(saved_id): + global deserialized_objects + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + if typename == 'module': + # Ignore containers that don't have any sources saved + return data[0] + elif typename == 'storage': + data_type, root_key, location, size, view_metadata = data + location = _maybe_decode_ascii(location) + if root_key not in deserialized_objects: + deserialized_objects[root_key] = np.zeros(size, dtype=data_type) + storage = deserialized_objects[root_key] + if view_metadata is not None: + view_key, offset, view_size = view_metadata + if view_key not in deserialized_objects: + deserialized_objects[view_key] = storage[offset:offset + view_size] + return deserialized_objects[view_key] + else: + return storage + else: + raise RuntimeError("Unknown saved id type: %s" % saved_id[0]) + +def clean_globals(): + global contents, deserialized_objects, loaded_storages, prefix + loaded_storages = {} + deserialized_objects = {} + contents = None + prefix = "" + +def load_pytorch(fn_name): + def dfs_results(result): # dfs the result dict in case of nested state dicts. + if not isinstance(result, dict): + return result + for key, params in result.items(): + if isinstance(params, dict): # recursive + result[key] = dfs_results(params) + elif isinstance(params, ArrayWrapper): # process data + requires_grad = params.requires_grad + shape = params.size + result[key] = jt.array(params.storage) + if shape is not None and len(shape) > 0: + if len(params.stride) > 1: # reshape based on stride + eval_list = [] + for idx in range(len(params.stride)): + eval_list.append(f"@e0({idx}) * i{idx}") + evals = "+".join(eval_list) + result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)]) + else: # no need to reshape if only one dimension + result[key] = result[key].reshape(shape) + if requires_grad is not None: + result[key].requires_grad = requires_grad + return result + import jittor as jt + global contents, deserialized_objects, loaded_storages, prefix + loaded_storages = {} + deserialized_objects = {} + if not (fn_name.endswith(".pth") or fn_name.endswith(".pt") or fn_name.endswith(".bin")): + print("This function is designed to load pytorch pth format files.") + return None + else: + contents = jt.ZipFile(fn_name) + if contents.valid(): + loaded_storages = {} + deserialized_objects = {} + for name in contents.list(): + if "data.pkl" in name: + prefix = name[:-8] + break + else: + raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found") + + data_file = contents.read_var(prefix+"data.pkl") + #import pdb; pdb.set_trace(); + #print(data_file) + if data_file.dtype == "uint8": + data_file = data_file.numpy().tobytes() + else: + data_file = data_file.data.tobytes() + data_file = io.BytesIO(data_file) + pickle_load_args = {'encoding': 'utf-8'} + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + result = unpickler.load() + result = dfs_results(result) + else: + deserialized_objects = {} + f = open(fn_name, "rb") + f_should_read_directly = _should_read_directly(f) + MAGIC_NUMBER = 0x1950a86a20f9469cfc6c + PROTOCOL_VERSION = 1001 + pickle_load_args = {'encoding': 'utf-8'} + magic_number = pickle.load(f, **pickle_load_args) + if magic_number != MAGIC_NUMBER: + raise RuntimeError("Invalid magic number; corrupt file?") + protocol_version = pickle.load(f, **pickle_load_args) + if PROTOCOL_VERSION != protocol_version: + raise RuntimeError("Invalid protocal version.") + _sys_info = pickle.load(f, **pickle_load_args) + unpickler = DirectUnpicklerWrapper(f, **pickle_load_args) + unpickler.persistent_load = persistent_load_direct + result = unpickler.load() + offset = f.tell() if f_should_read_directly else None + deserialized_storage_keys = pickle.load(f, **pickle_load_args) + f.read(8) + for key in deserialized_storage_keys: + assert key in deserialized_objects + dtype = deserialized_objects[key].dtype + size = deserialized_objects[key].size * get_dtype_size(dtype) + byte_data = f.read(size) + deserialized_objects[key][:] = np.frombuffer(byte_data, dtype).copy() + f.read(8) + if offset is not None: + offset = f.tell() + + result = dfs_results(result) + clean_globals() + return result + +if __name__ == "__main__": + result = load_pytorch("van_base.pth") + for key, val in result.items(): + print(key, val.shape) \ No newline at end of file diff --git a/python/jittor_utils/load_pytorch_old.py b/python/jittor_utils/load_pytorch_old.py new file mode 100644 index 00000000..5c0d5197 --- /dev/null +++ b/python/jittor_utils/load_pytorch_old.py @@ -0,0 +1,271 @@ +import pickle +import os +import io +import shutil +from zipfile import ZipFile +import jittor as jt +import numpy as np +from typing import Any, BinaryIO, cast, Dict, Optional, Type, Tuple, Union, IO, List + +loaded_storages = {} +deserialized_objects = {} + +def _maybe_decode_ascii(bytes_str: Union[bytes, str]) -> str: + if isinstance(bytes_str, bytes): + return bytes_str.decode('ascii') + return bytes_str + +def load_tensor(contents, dtype, numel, key, location): + name = os.path.join(prefix, "data", str(key)) + loaded_storages[key] = contents.read_var(name, dtype) + +def get_dtype_size(dtype): + dtype = dtype.__str__() + if dtype == "float32" or dtype == "int32": + return 4 + if dtype == "float64" or dtype == "int64": + return 8 + if dtype == "float16" or dtype == "int16": + return 2 + return 1 + +def persistent_load(saved_id): + global contents + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + assert typename == 'storage', \ + f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" + storage_type, key, location, numel = data + dtype = storage_type.dtype + if key not in loaded_storages: + nbytes = numel + load_tensor(contents, dtype, nbytes, key, _maybe_decode_ascii(location)) + return loaded_storages[key] + +def _dtype_to_storage_type_map(): + return { + np.float16: 'HalfStorage', + np.float32: 'FloatStorage', + np.int64: 'LongStorage', + np.int32: 'IntStorage', + np.int16: 'ShortStorage', + np.int8: 'CharStorage' + } + +def _storage_type_to_dtype_map(): + dtype_map = { + val: key for key, val in _dtype_to_storage_type_map().items()} + return dtype_map + +def _get_dtype_from_pickle_storage_type(pickle_storage_type: str): + try: + return _storage_type_to_dtype_map()[pickle_storage_type] + except KeyError: + raise KeyError( + f'pickle storage type "{pickle_storage_type}" is not recognized') + +class StorageType(): + def __init__(self, name): + self.dtype = _get_dtype_from_pickle_storage_type(name) + + def __str__(self): + return f'StorageType(dtype={self.dtype})' + +def jittor_rebuild(storage, storage_offset, size, stride, requires_grad, backward_hooks): + if len(size) == 0: + return jt.array(storage) + record_size = np.prod(size) + return jt.array(storage[:record_size]).reshape(size) + +def jittor_rebuild_var(data, requires_grad, backward_hooks): + v = jt.array(data) + v.requires_grad = requires_grad + return v + +class UnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if type(name) is str and 'Storage' in name: + try: + return StorageType(name) + except KeyError: + pass + if type(name) is str and '_rebuild_tensor_v2' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild") + if type(name) is str and '_rebuild_parameter' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_var") + + return super().find_class(mod_name, name) + +class ArrayWrapper: + def __init__(self, storage, stride=None, size=None, requires_grad=None): + self.requires_grad = requires_grad + self.size = size + self.storage = storage + self.stride = stride + + def __str__(self): + return self.storage.__str__() + +def jittor_rebuild_direct(storage, storage_offset, size, stride, requires_grad, backward_hooks): + if len(size) == 0: + return ArrayWrapper(storage, stride=stride, size=size) + storage.reshape(size) + return ArrayWrapper(storage, stride=stride, size=size) + +def jittor_rebuild_var_direct(data, requires_grad, backward_hooks): + v = ArrayWrapper(storage, requires_grad=requires_grad) + return v + +class DirectUnpicklerWrapper(pickle.Unpickler): # type: ignore[name-defined] + def find_class(self, mod_name, name): + if type(name) is str and 'Storage' in name: + try: + return StorageType(name) + except KeyError: + pass + if type(name) is str and '_rebuild_tensor_v2' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_direct") + if type(name) is str and '_rebuild_parameter' in name: + return super().find_class("jittor_utils.load_pytorch", "jittor_rebuild_var_direct") + return super().find_class(mod_name, name) + +def _check_seekable(f) -> bool: + def raise_err_msg(patterns, e): + for p in patterns: + if p in str(e): + msg = (str(e) + ". You can only load from a file that is seekable." + + " Please pre-load the data into a buffer like io.BytesIO and" + + " try to load from it instead.") + raise type(e)(msg) + raise e + + try: + f.seek(f.tell()) + return True + except (io.UnsupportedOperation, AttributeError) as e: + raise_err_msg(["seek", "tell"], e) + return False + +def extract_zip(input_zip): + input_zip = ZipFile(input_zip) + return {name: input_zip.read(name) for name in input_zip.namelist()} + +def _is_compressed_file(f): + compress_modules = ['gzip'] + try: + return f.__module__ in compress_modules + except AttributeError: + return False + +def _should_read_directly(f): + if _is_compressed_file(f): + return False + try: + return f.fileno() >= 0 + except io.UnsupportedOperation: + return False + except AttributeError: + return False + +def persistent_load_direct(saved_id): + global deserialized_objects + assert isinstance(saved_id, tuple) + typename = _maybe_decode_ascii(saved_id[0]) + data = saved_id[1:] + if typename == 'module': + # Ignore containers that don't have any sources saved + return data[0] + elif typename == 'storage': + data_type, root_key, location, size, view_metadata = data + location = _maybe_decode_ascii(location) + if root_key not in deserialized_objects: + deserialized_objects[root_key] = np.zeros(size, dtype=data_type) + storage = deserialized_objects[root_key] + if view_metadata is not None: + view_key, offset, view_size = view_metadata + if view_key not in deserialized_objects: + deserialized_objects[view_key] = storage[offset:offset + view_size] + return deserialized_objects[view_key] + else: + return storage + else: + raise RuntimeError("Unknown saved id type: %s" % saved_id[0]) + +def load_pytorch(fn_name): + import jittor as jt + global contents, deserialized_objects, loaded_storages, prefix + loaded_storages = {} + deserialized_objects = {} + if not (fn_name.endswith(".pth") or fn_name.endswith(".pt") or fn_name.endswith(".bin")): + print("This function is designed to load pytorch pth format files.") + return None + else: + contents = jt.ZipFile(fn_name) + if contents.valid(): + loaded_storages = {} + deserialized_objects = {} + for name in contents.list(): + if "data.pkl" in name: + prefix = name[:-8] + break + else: + raise RuntimeError(f"zipfile <{fn_name}> format error, data.pkl not found") + with jt.flag_scope(use_cuda=0): + print("load??", fn_name) + data_file = contents.read_var(prefix+"data.pkl").data.tobytes() + data_file = io.BytesIO(data_file) + pickle_load_args = {'encoding': 'utf-8'} + unpickler = UnpicklerWrapper(data_file, **pickle_load_args) + unpickler.persistent_load = persistent_load + result = unpickler.load() + else: + deserialized_objects = {} + f = open(fn_name, "rb") + f_should_read_directly = _should_read_directly(f) + MAGIC_NUMBER = 0x1950a86a20f9469cfc6c + PROTOCOL_VERSION = 1001 + pickle_load_args = {'encoding': 'utf-8'} + magic_number = pickle.load(f, **pickle_load_args) + if magic_number != MAGIC_NUMBER: + raise RuntimeError("Invalid magic number; corrupt file?") + protocol_version = pickle.load(f, **pickle_load_args) + if PROTOCOL_VERSION != protocol_version: + raise RuntimeError("Invalid protocal version.") + _sys_info = pickle.load(f, **pickle_load_args) + unpickler = DirectUnpicklerWrapper(f, **pickle_load_args) + unpickler.persistent_load = persistent_load_direct + result = unpickler.load() + offset = f.tell() if f_should_read_directly else None + deserialized_storage_keys = pickle.load(f, **pickle_load_args) + f.read(8) + for key in deserialized_storage_keys: + assert key in deserialized_objects + dtype = deserialized_objects[key].dtype + size = deserialized_objects[key].size * get_dtype_size(dtype) + byte_data = f.read(size) + deserialized_objects[key][:] = np.frombuffer(byte_data, dtype).copy() + f.read(8) + if offset is not None: + offset = f.tell() + for key, params in result.items(): + requires_grad = params.requires_grad + shape = params.size + result[key] = jt.array(params.storage) + if shape is not None and len(shape) > 0: + if len(params.stride) > 1: + eval_list = [] + for idx in range(len(params.stride)): + eval_list.append(f"@e0({idx}) * i{idx}") + evals = "+".join(eval_list) + result[key] = result[key].reindex(params.size, [evals], extras=[jt.array(params.stride)]) + else: + result[key] = result[key].reshape(shape) + if requires_grad is not None: + result[key].requires_grad = requires_grad + return result + +if __name__ == "__main__": + result = load_pytorch("van_base.pth") + for key, val in result.items(): + print(key, val.shape) \ No newline at end of file diff --git a/python/jittor_utils/lock.py b/python/jittor_utils/lock.py new file mode 100644 index 00000000..aa0bb625 --- /dev/null +++ b/python/jittor_utils/lock.py @@ -0,0 +1,90 @@ +try: + import fcntl +except ImportError: + fcntl = None + try: + import win32file + import pywintypes + _OVERLAPPED = pywintypes.OVERLAPPED() + except: + raise Exception("""pywin32 package not found, please install it. +>>> python3.x -m pip install pywin32 +If conda is used, please install with command: +>>> conda install pywin32""") + +import os +from jittor_utils import cache_path, LOG + +disable_lock = os.environ.get("disable_lock", "0") == "1" + +class Lock: + def __init__(self, filename): + self.handle = open(filename, 'w') + LOG.v(f'OPEN LOCK path: {filename} PID: {os.getpid()}') + self.is_locked = False + + def lock(self): + if disable_lock: + return + if fcntl: + fcntl.flock(self.handle, fcntl.LOCK_EX) + else: + hfile = win32file._get_osfhandle(self.handle.fileno()) + win32file.LockFileEx(hfile, 2, 0, -0x10000, _OVERLAPPED) + self.is_locked = True + LOG.vv(f'LOCK PID: {os.getpid()}') + + def unlock(self): + if disable_lock: + return + if fcntl: + fcntl.flock(self.handle, fcntl.LOCK_UN) + else: + hfile = win32file._get_osfhandle(self.handle.fileno()) + win32file.UnlockFileEx(hfile, 0, -0x10000, _OVERLAPPED) + self.is_locked = False + LOG.vv(f'UNLOCK PID: {os.getpid()}') + + def __del__(self): + self.handle.close() + + +class _base_scope: + '''base_scope for support @xxx syntax''' + def __enter__(self): pass + def __exit__(self, *exc): pass + def __call__(self, func): + def inner(*args, **kw): + with self: + ret = func(*args, **kw) + return ret + return inner + +class lock_scope(_base_scope): + def __enter__(self): + self.is_locked = jittor_lock.is_locked + if not self.is_locked: + jittor_lock.lock() + + def __exit__(self, *exc): + if not self.is_locked: + jittor_lock.unlock() + +class unlock_scope(_base_scope): + def __enter__(self): + self.is_locked = jittor_lock.is_locked + if self.is_locked: + jittor_lock.unlock() + + def __exit__(self, *exc): + if self.is_locked: + jittor_lock.lock() + +lock_path = os.path.abspath(os.path.join(cache_path, "../jittor.lock")) +if not os.path.exists(lock_path): + LOG.i("Create lock file:", lock_path) + try: + os.mknod(lock_path) + except: + pass +jittor_lock = Lock(lock_path) diff --git a/python/jittor_utils/misc.py b/python/jittor_utils/misc.py new file mode 100644 index 00000000..bf1a6ba8 --- /dev/null +++ b/python/jittor_utils/misc.py @@ -0,0 +1,174 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Meng-Hao Guo +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import os +import hashlib +import urllib.request +from tqdm import tqdm +from jittor_utils import lock, LOG +import gzip +import tarfile +import zipfile +jittor_offline_path = None +try: + import jittor_offline + jittor_offline_path = os.path.dirname(jittor_offline.__file__) +except: + pass + + +def ensure_dir(dir_path): + if not os.path.isdir(dir_path): + os.makedirs(dir_path) + +def _progress(): + pbar = tqdm(total=None, + unit="B", + unit_scale=True, + unit_divisor=1024) + + def bar_update(block_num, block_size, total_size): + """ reporthook + @block_num: the num of downloaded data block + @block_size: the size of data block + @total_size: the total size of remote file + """ + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = block_num * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + +@lock.lock_scope() +def download_url_to_local(url, filename, root_folder, md5): + ensure_dir(root_folder) + file_path = os.path.join(root_folder, filename) + if check_file_exist(file_path, md5): + return + else: + if jittor_offline_path: + offpath = os.path.join(jittor_offline_path, filename) + if check_file_exist(offpath, md5): + import shutil + print('Using offline jittor', file_path) + shutil.copy(offpath, file_path) + return + print('Downloading ' + url + ' to ' + file_path) + try: + urllib.request.urlretrieve( + url, file_path, + reporthook=_progress() + ) + except Exception as e: + msg = f"{e}\nDownload File failed, url: {url}, path: {file_path}" + print(msg) + if os.path.isfile(file_path): + os.remove(file_path) + raise RuntimeError(msg) + if not check_file_exist(file_path, md5): + raise RuntimeError(f"MD5 mismatch between the server and the downloaded file {file_path}") + + + +def check_file_exist(file_path, md5): + if not os.path.isfile(file_path): + return False + if md5 is None: + return True + return check_md5(file_path, md5) + + +def calculate_md5(file_path, chunk_size=1024 * 1024): + md5 = hashlib.md5() + with open(file_path, 'rb') as f: + for chunk in iter(lambda: f.read(chunk_size), b''): + md5.update(chunk) + md5 = md5.hexdigest() + LOG.v(f"file {file_path} md5: {md5}") + return md5 + + +def check_md5(file_path, md5, **kwargs): + return md5 == calculate_md5(file_path, **kwargs) + + +def check_integrity(fpath, md5=None): + if not os.path.isfile(fpath): + return False + if md5 is None: + return True + return check_md5(fpath, md5) + + +def _is_tarxz(filename): + return filename.endswith(".tar.xz") + + +def _is_tar(filename): + return filename.endswith(".tar") + + +def _is_targz(filename): + return filename.endswith(".tar.gz") + + +def _is_tgz(filename): + return filename.endswith(".tgz") + + +def _is_gzip(filename): + return filename.endswith(".gz") and not filename.endswith(".tar.gz") + + +def _is_zip(filename): + return filename.endswith(".zip") + + +def extract_archive(from_path, to_path=None, remove_finished=False): + if to_path is None: + to_path = os.path.dirname(from_path) + + if _is_tar(from_path): + with tarfile.open(from_path, 'r') as tar: + tar.extractall(path=to_path) + elif _is_targz(from_path) or _is_tgz(from_path): + with tarfile.open(from_path, 'r:gz') as tar: + tar.extractall(path=to_path) + elif _is_tarxz(from_path): + # .tar.xz archive only supported in Python 3.x + with tarfile.open(from_path, 'r:xz') as tar: + tar.extractall(path=to_path) + elif _is_gzip(from_path): + to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0]) + with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f: + out_f.write(zip_f.read()) + elif _is_zip(from_path): + with zipfile.ZipFile(from_path, 'r') as z: + z.extractall(to_path) + else: + raise ValueError("Extraction of {} not supported".format(from_path)) + + if remove_finished: + os.remove(from_path) + + +def download_and_extract_archive(url, download_root, extract_root=None, filename=None, + md5=None, remove_finished=False): + download_root = os.path.expanduser(download_root) + if extract_root is None: + extract_root = download_root + if not filename: + filename = os.path.basename(url) + + download_url_to_local(url, filename, download_root, md5) + + archive = os.path.join(download_root, filename) + print("Extracting {} to {}".format(archive, extract_root)) + extract_archive(archive, extract_root, remove_finished) diff --git a/python/jittor_utils/pack_offline.py b/python/jittor_utils/pack_offline.py new file mode 100644 index 00000000..aca84807 --- /dev/null +++ b/python/jittor_utils/pack_offline.py @@ -0,0 +1,93 @@ +urls = [ + ("https://cg.cs.tsinghua.edu.cn/jittor/assets/dnnl_lnx_2.2.0_cpu_gomp.tgz", "dnnl_lnx_2.2.0_cpu_gomp.tgz"), + ("https://cg.cs.tsinghua.edu.cn/jittor/assets/dnnl_lnx_2.2.0_cpu_gomp_aarch64.tgz", "dnnl_lnx_2.2.0_cpu_gomp_aarch64.tgz"), + ("https://codeload.github.com/NVIDIA/cub/tar.gz/1.11.0", "cub-1.11.0.tgz"), + ("https://codeload.github.com/Jittor/cutt/zip/v1.2", "cutt-1.2.zip"), + ("https://codeload.github.com/NVIDIA/nccl/tar.gz/v2.8.4-1", "nccl.tgz"), + ("https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", "train-images-idx3-ubyte.gz"), + ("https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", "train-labels-idx1-ubyte.gz"), + ("https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", "t10k-images-idx3-ubyte.gz"), + ("https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", "t10k-labels-idx1-ubyte.gz") +] + +import urllib +from pathlib import Path +import os +import glob +import shutil +import sys + +cpath = os.path.join(str(Path.home()), ".cache", "jittor", "offpack") +os.makedirs(cpath+"/python/jittor_offline", exist_ok=True) + + +for url, file_path in urls: + file_path = os.path.join(cpath, "python/jittor_offline", file_path) + print("download", url, file_path) + urllib.request.urlretrieve( + url, file_path + ) + +with open(os.path.join(cpath, "MANIFEST.in"), "w") as f: + f.write("include python/jittor_offline/*") +with open(os.path.join(cpath, "__init__.py"), "w") as f: + f.write("") +with open(os.path.join(cpath, "setup.py"), "w") as f: + f.write(""" +import setuptools + + +setuptools.setup( + name="jittor_offline", + version="0.0.7", + author="jittor", + author_email="jittor@qq.com", + description="jittor project", + long_description="jittor_offline", + long_description_content_type="text/markdown", + url="https://github.com/jittor/jittor", + project_urls={ + "Bug Tracker": "https://github.com/jittor/jittor/issues", + }, + classifiers=[ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent", + ], + packages=["jittor_offline"], + package_dir={"": "python"}, + package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']}, + python_requires=">=3.7", + install_requires=[ + "jittor>=1.3.4.16", + ], +) +""") + + +def callback(func, path, exc_info): + print(f"remove \"{path}\" failed.") + +def rmtree(path): + if os.path.isdir(path): + print(f"remove \"{path}\" recursive.") + shutil.rmtree(path, onerror=callback) + +def remove_tmpfile(): + dist_file = home_path+"/dist" + egg_file = glob.glob(home_path+"/**/*egg-info") + rmtree(dist_file) + for e in egg_file: + rmtree(e) + +def run_cmd(cmd): + print("[CMD]", cmd) + assert os.system(cmd)==0 + +home_path = cpath +os.chdir(cpath) +remove_tmpfile() + +run_cmd(f"{sys.executable} ./setup.py sdist") +run_cmd(f"{sys.executable} -m twine upload dist/*") + +remove_tmpfile() \ No newline at end of file diff --git a/python/jittor_utils/pip_publish.py b/python/jittor_utils/pip_publish.py new file mode 100644 index 00000000..80a45dce --- /dev/null +++ b/python/jittor_utils/pip_publish.py @@ -0,0 +1,34 @@ +import os +import glob +import shutil +import sys + +home_path = os.path.join(os.path.dirname(__file__), "..", "..") +home_path = os.path.abspath(home_path) + +def callback(func, path, exc_info): + print(f"remove \"{path}\" failed.") + +def rmtree(path): + if os.path.isdir(path): + print(f"remove \"{path}\" recursive.") + shutil.rmtree(path, onerror=callback) + +def remove_tmpfile(): + dist_file = home_path+"/dist" + egg_file = glob.glob(home_path+"/**/*egg-info") + rmtree(dist_file) + for e in egg_file: + rmtree(e) + +def run_cmd(cmd): + print("[CMD]", cmd) + assert os.system(cmd)==0 + +os.chdir(home_path) +remove_tmpfile() + +run_cmd(f"{sys.executable} ./setup.py sdist") +run_cmd(f"{sys.executable} -m twine upload dist/*") + +remove_tmpfile() \ No newline at end of file diff --git a/python/jittor_utils/query_cuda_cc.py b/python/jittor_utils/query_cuda_cc.py new file mode 100644 index 00000000..75205fe4 --- /dev/null +++ b/python/jittor_utils/query_cuda_cc.py @@ -0,0 +1,25 @@ +import ctypes +import os +if "CUDA_VISIBLE_DEVICES" in os.environ: + del os.environ["CUDA_VISIBLE_DEVICES"] +if os.name == 'nt': + cuda_driver = ctypes.CDLL("nvcuda") +else: + cuda_driver = ctypes.CDLL("libcuda.so") +driver_version = ctypes.c_int() +r = cuda_driver.cuDriverGetVersion(ctypes.byref(driver_version)) +assert r == 0 +v = driver_version.value + +dcount = ctypes.c_int() +cuda_driver.cuInit(0) +r = cuda_driver.cuDeviceGetCount(ctypes.byref(dcount)) + +for i in range(dcount.value): + dev = ctypes.c_void_p() + major = ctypes.c_int() + minor = ctypes.c_int() + assert 0 == cuda_driver.cuDeviceGet(ctypes.byref(dev), i) + assert 0 == cuda_driver.cuDeviceGetAttribute(ctypes.byref(major), 75, dev) + assert 0 == cuda_driver.cuDeviceGetAttribute(ctypes.byref(minor), 76, dev) + print(major.value*10+minor.value) diff --git a/python/jittor_utils/ring_buffer.py b/python/jittor_utils/ring_buffer.py new file mode 100644 index 00000000..57f9f389 --- /dev/null +++ b/python/jittor_utils/ring_buffer.py @@ -0,0 +1,268 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import multiprocessing as mp +import numpy as np +import ctypes +import random +import pickle +import ctypes + +recv_raw_call = 0.0 + +class RingBufferAllocator: + def __init__(self, size): + self.size = size + self.l = mp.Value(ctypes.c_longlong, 0, lock=False) + self.r = mp.Value(ctypes.c_longlong, 0, lock=False) + self.is_full = mp.Value(ctypes.c_bool, False, lock=False) + self.lock = mp.Lock() + self.cv = mp.Condition(self.lock) + + def __repr__(self): + l = self.l.value + r = self.r.value + is_full = self.is_full.value + if is_full: + cap = 0 + else: + cap = (r - l) / self.size + if cap<=0: cap += 1 + return f"Buffer(free={cap*100:.3f}% l={l} r={r} size={self.size})" + + def alloc_with_lock(self, size): + with self.lock: + while True: + location = self.alloc(size) + if location is not None: break + self.cv.wait() + return location + + def free_with_lock(self, size): + with self.lock: + location = self.free(size) + self.cv.notify() + return location + + def clear(self): + with self.lock: + self.l.value = 0 + self.r.value = 0 + self.is_full.value = False + + def alloc(self, size): + if size > self.size: + raise RuntimeError(f"Buffer size too small {self.size}<{size}") + l = self.l.value + r = self.r.value + is_full = self.is_full.value + if is_full: return None + if l == r and l > 0: + self.l.value = self.r.value = l = r = 0 + # [l, r) + if r > l: + freed = r - l + if freed < size: + # |----l......r---| + # |----#########--| + return None + # |----l......r---| + # |----#####------| + location = l + self.l.value = l = l + size + else: + freed = self.size - l + if freed < size: + # |.....r------l...| + # |------------####### + if size > r: + # |.....r------l...| + # |#######----------- + return None + # |.....r------l...| + # |#####----------- + if size == r: + self.is_full.value = is_full= True + location = 0 + self.l.value = l = size + else: + # |.....r------l...| + # |------------##--| + location = l + if freed == size: + self.l.value = l = 0 + else: + self.l.value = l = l + size + if l == r: + self.is_full.value = is_full = True + return location + + def free(self, size): + l = self.l.value + r = self.r.value + is_full = self.is_full.value + if size==0: return r + if is_full: + self.is_full.value = is_full = False + elif l == r: + return None + location = r + self.r.value = r = r + size + if r > self.size: + location = 0 + self.r.value = r = size + elif r == self.size: + self.r.value = r = 0 + return location + +def str_to_char_array(s, array_len): + if len(s) > array_len: s = s[:array_len] + a = np.array(s, dtype='c') + if len(s) < array_len: + a = np.pad(a, (0,array_len-len(s)), constant_values=' ') + return a + +def char_array_to_str(a): + return str(a.tobytes(), 'ascii').strip() + +class RingBuffer: + def __init__(self, buffer): + self.allocator = RingBufferAllocator(len(buffer)) + self.buffer = buffer + + def clear(self): self.allocator.clear() + + def send_int(self, data): + # int: int64[1] + # data + self.send_raw(np.array([data], dtype='int64')) + def recv_int(self): + return int(self.recv_raw(8, (1,), 'int64')[0]) + + def send_float(self, data): + # float: float64[1] + # data + self.send_raw(np.array([data], dtype='float64')) + def recv_float(self): + return float(self.recv_raw(8, (1,), 'float64')[0]) + + def send_str(self, data): + # str: int64[1] char[len] + # len data + data = np.array(data, dtype='c') + self.send_int(data.nbytes) + self.send_raw(data) + def recv_str(self): + nbytes = self.recv_int() + data = self.recv_raw(nbytes, nbytes, 'c') + return str(data.tostring(), 'ascii') + + def send_ndarray(self, data): + # str: int64[1] char[8] int64[1] int64[slen] char[nbytes] + # slen dtype nbytes shape data + shape = data.shape + slen = len(shape) + self.send_int(slen) + self.send_fix_len_str(str(data.dtype)) + self.send_int(data.nbytes) + self.send_raw(np.array(shape, dtype='int64')) + self.send_raw(data) + + def recv_ndarray(self): + slen = self.recv_int() + dtype = self.recv_fix_len_str() + nbytes = self.recv_int() + shape = self.recv_raw(slen*8, slen, 'int64') + data = self.recv_raw(nbytes, shape, dtype) + return data + + def send_tuple(self, data): + # tuple: int64[1] .... + # len + length = len(data) + self.send_int(length) + for a in data: + self.send(a) + def recv_tuple(self): + length = self.recv_int() + return tuple(self.recv() for i in range(length)) + + def send_list(self, data): + # list: int64[1] .... + # len + length = len(data) + self.send_int(length) + for a in data: + self.send(a) + + def recv_list(self): + length = self.recv_int() + return [self.recv() for i in range(length)] + + def send_pickle(self, data): + # pickle: int64[1] char[len] + # len data + data = pickle.dumps(data) + data = np.frombuffer(data, dtype='c') + self.send_int(data.nbytes) + self.send_raw(data) + + def recv_pickle(self): + nbytes = self.recv_int() + data = self.recv_raw(nbytes, nbytes, 'c') + return pickle.loads(data.tostring()) + + def __repr__(self): + return f"{self.allocator}@0x{hex(ctypes.addressof(self.buffer))}" + + def send_raw(self, data): + assert isinstance(data, np.ndarray) # and data.flags.c_contiguous + with self.allocator.lock: + location = self.allocator.alloc(data.nbytes) + while location is None: + self.allocator.cv.wait() + location = self.allocator.alloc(data.nbytes) + window = np.ndarray(shape=data.shape, dtype=data.dtype, + buffer=self.buffer, offset=location) + window[:] = data + self.allocator.cv.notify() + assert window.nbytes == data.nbytes + + def recv_raw(self, nbytes, shape, dtype): + global recv_raw_call + recv_raw_call += 1 + with self.allocator.lock: + location = self.allocator.free(nbytes) + while location is None: + self.allocator.cv.wait() + location = self.allocator.free(nbytes) + data = np.ndarray(shape=shape, dtype=dtype, + buffer=self.buffer, offset=location).copy() + self.allocator.cv.notify() + assert data.nbytes == nbytes + return data + + def send_fix_len_str(self, s, array_len=8): + data = str_to_char_array(s, array_len) + self.send_raw(data) + + def recv_fix_len_str(self, array_len=8): + data = self.recv_raw(8, 8, 'c') + return char_array_to_str(data) + + def send(self, data): + ts = type(data).__name__ + send = getattr(self, "send_"+ts, self.send_pickle) + self.send_fix_len_str(ts) + send(data) + + def recv(self): + ts = self.recv_fix_len_str() + recv = getattr(self, "recv_"+ts, self.recv_pickle) + return recv() + diff --git a/python/jittor_utils/save_pytorch.py b/python/jittor_utils/save_pytorch.py new file mode 100644 index 00000000..ecbb49a6 --- /dev/null +++ b/python/jittor_utils/save_pytorch.py @@ -0,0 +1,151 @@ +import jittor as jt +from jittor import nn +import io +import pickle +import sys +import torch + +class HalfStorage: pass +class BFloat16Storage: pass +class FloatStorage: pass +class LongStorage: pass +class IntStorage: pass +class ShortStorage: pass +class CharStorage: pass +class BoolStorage: pass +HalfStorage.__module__ = "torch" +BFloat16Storage.__module__ = "torch" +FloatStorage.__module__ = "torch" +LongStorage.__module__ = "torch" +IntStorage.__module__ = "torch" +ShortStorage.__module__ = "torch" +CharStorage.__module__ = "torch" +BoolStorage.__module__ = "torch" +def _rebuild_tensor_v2(*args): pass +_rebuild_tensor_v2.__module__ = "torch._utils" + +targets = [HalfStorage, BFloat16Storage, FloatStorage, LongStorage, IntStorage, ShortStorage, CharStorage, BoolStorage, _rebuild_tensor_v2] + +def swap_targets(targets): + original_targets = [] + for target in targets: + original_targets.append(sys.modules[target.__module__].__dict__.get(target.__name__, target)) + sys.modules[target.__module__].__dict__[target.__name__] = target + return original_targets + +class TensorStorage: + def __init__(self, data): + self.data = data + +class TensorWrapper: + def __init__(self, data): + self.data = data + def __reduce__(self): + a = tuple(self.data.shape) + # calc stride + stride = [1] + for i in range(len(a)-1, 0, -1): + stride.append(stride[-1]*a[i]) + stride = stride[::-1] + + return (_rebuild_tensor_v2, ( + TensorStorage(self.data), + 0, + tuple(self.data.shape), + tuple(stride), + False, + {} + )) + +dtype_map = { + "float16": HalfStorage, + "bfloat16": BFloat16Storage, + "float32": FloatStorage, + "int64": LongStorage, + "int32": IntStorage, + "int16": ShortStorage, + "int8": CharStorage, + "bool": BoolStorage, +} + +def save_pytorch(path, obj): + + serialized_storages = [] + # dfs and wrap jt.Var into TensorWrapper + def dfs(obj): + if isinstance(obj, jt.Var): + return TensorWrapper(obj) + elif isinstance(obj, dict): + return {k: dfs(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [dfs(x) for x in obj] + elif isinstance(obj, tuple): + return tuple(dfs(x) for x in obj) + else: + return obj + + def persistent_id(obj): + if isinstance(obj, TensorStorage): + storage_type = dtype_map[str(obj.data.dtype)] + storage_key = len(serialized_storages) + serialized_storages.append(obj.data) + storage_numel = obj.data.numel() + location = 'cpu' + return ('storage', + storage_type, + storage_key, + location, + storage_numel) + return None + + obj = dfs(obj) + data_buf = io.BytesIO() + pickle_protocol = 2 + pickler = pickle.Pickler(data_buf, protocol=pickle_protocol) + pickler.persistent_id = persistent_id + global targets + targets = swap_targets(targets) + pickler.dump(obj) + targets = swap_targets(targets) + data_value = data_buf.getvalue() + + # use previous pytorch code to save data + # from torch.serialization import _open_zipfile_writer + # with _open_zipfile_writer(path) as zip_file: + # print(data_value) + # zip_file.write_record('data.pkl', data_value, len(data_value)) + # zip_file.write_record('byteorder', sys.byteorder, len(sys.byteorder)) + # for i, v in enumerate(serialized_storages): + # b = v.numpy().tobytes() + # zip_file.write_record(f'data/{i}', b, len(b)) + + import os + path_base_name = os.path.basename(path).split(".")[0] + contents = jt.ZipFile(path, "w") + def write(name, data): + if isinstance(data, str): + write(name, data.encode()) + elif isinstance(data, bytes): + import ctypes + pointer = ctypes.cast(data, ctypes.c_void_p).value + contents.write(path_base_name+'/'+name, pointer, len(data)) + elif isinstance(data, jt.Var): + contents.write(path_base_name+'/'+name, data.raw_ptr, data.nbytes) + else: + raise TypeError(f"unsupported type {type(data)}") + write("data.pkl", data_value) + write("byteorder", sys.byteorder) + for i, v in enumerate(serialized_storages): + write(f"data/{i}", v) + write("version", "3") + del contents + + +if __name__ == "__main__": + linear = nn.Linear(3, 3) + save_pytorch("linear.bin", linear.state_dict()) + + import torch + res = torch.load("linear.bin") + print(res) + print(linear.state_dict()) \ No newline at end of file diff --git a/python/jittor_utils/student_queue.py b/python/jittor_utils/student_queue.py new file mode 100644 index 00000000..b55e2c1e --- /dev/null +++ b/python/jittor_utils/student_queue.py @@ -0,0 +1,68 @@ +# *************************************************************** +# Copyright (c) 2023 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from socketserver import ThreadingTCPServer +import socket +import os +import sys +import threading + +key_queue = {} + +def handle_connect(req:socket.socket, c_addr, server): + print("get connect", c_addr, req) + skey = req.recv(1024).decode() + print("get skey", skey) + with lock: + if skey not in key_queue: + key_queue[skey] = [] + queue = key_queue[skey] + queue.append(req) + + req.send(str(len(queue)-1).encode()) + while True: + buf = req.recv(1024).decode() + print(buf) + with lock: + if len(buf) == 0: + for i,r in enumerate(queue): + if r is req: + for j in range(i+1, len(queue)): + queue[j].send(str(j-1).encode()) + del queue[i] + print("queue size", len(queue)) + break + break + + +def wait_queue(): + global s + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.connect(("127.0.0.1", 8900)) + s.sendall(skey.encode()) + while True: + buf = s.recv(1024).decode() + if len(buf) == 0: + print("Cannot connect to queue server, please report this issue to admin.") + sys.exit(1) + if buf == '0': + print("Begin") + os.system(f"sleep {os.environ.get('SWAIT', '60')} && bash -c ' if kill -9 {os.getpid()} 2>/dev/null; then echo Timeout; fi; ' &") + break + else: + print("Pending", buf) + + + +if "SKEY" in os.environ: + skey = os.environ["SKEY"] + wait_queue() +else: + lock = threading.Lock() + server = ThreadingTCPServer(("127.0.0.1", 8900), handle_connect) + server.serve_forever() diff --git a/python/jittor_utils/translator.py b/python/jittor_utils/translator.py new file mode 100644 index 00000000..39a71803 --- /dev/null +++ b/python/jittor_utils/translator.py @@ -0,0 +1,99 @@ +#!python3 +import os, json +from pathlib import Path +dirname = os.path.dirname(__file__) + +jittor_root = os.path.join(dirname, "..", "..") +print(jittor_root) + +all_src_md = [] + +for r, _, f in os.walk(jittor_root): + for fname in f: + if not fname.endswith(".src.md"): continue + all_src_md.append(os.path.realpath(os.path.join(r, fname))) + +def check_is_en(src): + en_cnt = 0 + for c in src: en_cnt += str.isascii(c) + return en_cnt == len(src) + +def check_is_both(src): + if src.startswith("!"): + return True + return len(src) < 2 + +def splite_markdown_blocks(src): + ''' split markdown document into text, code, table blocks + ''' + blocks = [] + block = "" + status = "text" + + def commit_block(): + blocks.append((block, status)) + + for line in src.split('\n'): + line = line + "\n" + if line.startswith("```"): + assert status in ["text", "code"] + if status == "text": + commit_block() + status = "code" + block = line + elif status == "code": + block += line + commit_block() + status = "text" + block = "" + elif line.strip().startswith('|') and line.strip().endswith('|'): + assert status in ["text", "table"] + if status == "text": + commit_block() + status = "table" + block = line + else: + block += line + else: + if status == "table": + commit_block() + status = "text" + block = line + else: + block += line + if status != "code": + commit_block() + return blocks + +for mdname in all_src_md: + print(mdname) + with open(mdname, "r", encoding='utf8') as f: + src = f.read() + + src_blocks = splite_markdown_blocks(src) + + en_src = "" + cn_src = "" + for block, status in src_blocks: + if status == "code" or status == "table": + en_src += block + cn_src += block + else: + en_s = [] + cn_s = [] + for line in block.split('\n'): + if check_is_both(line): + en_s.append(line) + cn_s.append(line) + elif check_is_en(line): + en_s.append(line) + else: + cn_s.append(line) + en_src += "\n".join(en_s) + cn_src += "\n".join(cn_s) + + with open(mdname.replace(".src.md", ".md"), 'w', encoding='utf8') as f: + f.write(en_src) + with open(mdname.replace(".src.md", ".cn.md"), 'w', encoding='utf8') as f: + f.write(cn_src) + \ No newline at end of file diff --git a/resnet.py b/resnet.py new file mode 100644 index 00000000..07b98e27 --- /dev/null +++ b/resnet.py @@ -0,0 +1,16 @@ +import jittor as jt +from jittor import nn +from jittor.models import resnet50 +import time + +jt.flags.use_cuda = 1 + +net = resnet50() +x = jt.ones(2, 3, 224, 224) +y = net(x) +y.sync() +start = time.time() +for i in range(100): + y = net(x) + y.sync() +print(time.time() - start) diff --git a/run.sh b/run.sh new file mode 100644 index 00000000..c1a9a24a --- /dev/null +++ b/run.sh @@ -0,0 +1 @@ +python test.py 1 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..beec50f6 --- /dev/null +++ b/setup.py @@ -0,0 +1,74 @@ +error_msg = """Jittor only supports Linux and macOS currently. +For other OS, use Jittor may be risky. +If you insist on installing, please set the environment variable : export FORCE_INSTALL=1 +We strongly recommend docker installation: + +# CPU only (Linux) +>>> docker run -it --network host jittor/jittor +# CPU and CUDA (Linux) +>>> docker run -it --network host jittor/jittor-cuda +# CPU only (Mac and Windows) +>>> docker run -it -p 8888:8888 jittor/jittor + +Reference: +1. Windows/Mac/Linux install Jittor in Docker: https://cg.cs.tsinghua.edu.cn/jittor/tutorial/2020-5-15-00-00-docker/ +""" +from warnings import warn +import os +import platform + +if not platform.system() in ['Linux', 'Darwin']: + assert os.environ.get("FORCE_INSTALL", '0') != '1', error_msg + +import setuptools +from setuptools import setup, find_packages +import os +import sys + +path = os.path.dirname(__file__) +with open(os.path.join(path, "README.md"), "r", encoding='utf8') as fh: + long_description = fh.read() + +with open(os.path.join(path, "python/jittor/__init__.py"), "r", encoding='utf8') as fh: + for line in fh: + if line.startswith('__version__'): + version = line.split("'")[1] + break + else: + raise RuntimeError("Unable to find version string.") + +version_require = (3,7) +if os.name == 'nt': + version_require = (3,8) +if sys.version_info < version_require: + raise RuntimeError("Python version not match, require %s, current %s" + %(version_require, sys.version_info)) + +setuptools.setup( + name='jittor', + version=version, + # scripts=[], + author="Jittor Group", + author_email="ran.donglang@gmail.com", + description="a Just-in-time(JIT) deep learning framework", + long_description=long_description, + long_description_content_type="text/markdown", + url="http://jittor.org", + # packages=setuptools.find_packages(), + python_requires='>=3.7', + + packages=["jittor", "jittor.test", "jittor.models", "jittor.utils", "jittor_utils"], + package_dir={'': 'python'}, + package_data={'': ['*', '*/*', '*/*/*','*/*/*/*','*/*/*/*/*','*/*/*/*/*/*']}, + # include_package_data=True, + install_requires=[ + "numpy<2.0", + "tqdm", + "pillow", + "astunparse", + 'pywin32 >= 1.0 ; platform_system=="Windows"' + ], + ) + +# upload to pip: +# rm -rf dist && python3.7 ./setup.py sdist && python3.7 -m twine upload dist/*