mirror of https://github.com/Jittor/Jittor
add sphinx doc
parent: e5265505df
commit: 126f92a6f3
@ -21,4 +21,5 @@ venv/
!README.md
!README.cn.md
python/jittor.egg-info
dist/
+!doc/source/*
@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
@ -0,0 +1,18 @@
# sudo python3.7 -m pip install \
#     recommonmark \
#     sphinx sphinx-autobuild sphinx_rtd_theme \
#     sphinx-autobuild \
#     --timeout 100

bpath=$(readlink -f "${BASH_SOURCE[0]}")
bpath=$(dirname "${bpath}")

jittor_path=$(readlink -f "${bpath}/..")

echo "[doc path] $bpath"
echo "[jittor path] $jittor_path"

export PYTHONPATH=$jittor_path/python
cd $bpath
sphinx-autobuild -b html source build
@ -0,0 +1 @@
../../README.cn.md
@ -0,0 +1,98 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
jittor_path = os.path.abspath('../../python')
print(f"[jittor_path] {jittor_path}")
sys.path.insert(0, jittor_path)

import jittor


# -- Project information -----------------------------------------------------

project = 'Jittor'
copyright = '2020, Jittor'
author = 'Jittor'

# The full version, including alpha/beta/rc tags
release = '1.1.3.1'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = 'zh_CN'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'recommonmark',
    'sphinx.ext.autodoc',
    # Auto-generate section labels.
    'sphinx.ext.autosectionlabel',
    'sphinx.ext.viewcode',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

import sphinx_rtd_theme
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

source_suffix = {
    '.rst': 'restructuredtext',
    '.txt': 'markdown',
    '.md': 'markdown',
}

import recommonmark
from recommonmark.transform import AutoStructify

# At the bottom of conf.py
def setup(app):
    app.add_config_value('recommonmark_config', {
        # 'url_resolver': lambda url: github_doc_root + url,
        'auto_toc_tree_section': 'Contents',
    }, True)
    app.add_transform(AutoStructify)


# Prefix document path to section labels, otherwise autogenerated labels would look like 'heading'
# rather than 'path/to/file:heading'
autosectionlabel_prefix_document = True
@ -0,0 +1,40 @@
.. Jittor documentation master file, created by
   sphinx-quickstart on Mon May 18 23:05:53 2020.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to the Jittor documentation
===================================

.. toctree::
   :maxdepth: 2
   :caption: Overview:

   README.cn.md

.. toctree::
   :maxdepth: 2
   :caption: Module API:

   jittor
   jittor.nn
   jittor.models
   jittor.optim
   jittor.init
   jittor.contrib
   jittor.dataset
   jittor.transform


.. toctree::
   :maxdepth: 1
   :caption: Other:

   todo

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
@ -0,0 +1,10 @@
jittor.contrib
=====================

This is the API documentation for Jittor's contrib module. The code in this module may not be fully mature yet; it will keep being improved in future iterations. You can access this module with `from jittor import contrib`.

```eval_rst
.. automodule:: jittor.contrib
    :members:
    :undoc-members:
```
@ -0,0 +1,11 @@
jittor.dataset
=====================

This is the API documentation for Jittor's dataset module. You can access this module with `from jittor import dataset`.

```eval_rst
.. automodule:: jittor.dataset
    :imported-members:
    :members:
    :undoc-members:
```
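For orientation, a minimal usage sketch of the dataset API (the dataset name, shapes, and the `set_attrs` configuration below are illustrative assumptions, not part of this commit):

```python
import numpy as np
from jittor.dataset import Dataset

class ToyData(Dataset):
    # A tiny in-memory dataset; real datasets would load files instead.
    def __init__(self):
        super().__init__()
        self.x = np.random.rand(100, 3, 32, 32).astype("float32")
        self.y = np.random.randint(0, 10, 100)
        # total_len / batch_size are configured through set_attrs in jittor's Dataset.
        self.set_attrs(total_len=100, batch_size=16, shuffle=True)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

for images, labels in ToyData():
    pass  # iterate over mini-batches
```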
@ -0,0 +1,10 @@
jittor.init
=====================

This is the API documentation for Jittor's parameter initialization module. You can access this module with `from jittor import init`.

```eval_rst
.. automodule:: jittor.init
    :members:
    :undoc-members:
```
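A short sketch of the initializers referenced elsewhere in this commit (`relu_invariant_gauss`, `invariant_uniform`, `uniform`); the shapes below are illustrative:

```python
from jittor import init

# Conv-style weight with the relu-invariant gaussian used in this repository.
w = init.relu_invariant_gauss([64, 3, 3, 3], "float32", mode="fan_out")
# Uniform initializers, as used for linear layers in the test code.
w2 = init.invariant_uniform([256, 128], "float32")
b = init.uniform([256], "float32", -0.1, 0.1)
print(w.shape, w2.shape, b.shape)
```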
@ -0,0 +1,44 @@
jittor
=====================

## jittor

This is the API documentation for the main Jittor module. You can access it with `import jittor`.

```eval_rst
.. automodule:: jittor
    :members:
    :undoc-members:
```

## jittor.core

The following are Jittor's core APIs. They can be accessed directly as `jittor.core.XXX` or `jittor.XXX`.

```eval_rst
.. automodule:: jittor_core
    :imported-members:
    :members:
    :undoc-members:
```

## jittor.ops

This is the API documentation for Jittor's basic operators. They can be accessed directly as `jittor.ops.XXX` or `jittor.XXX`.

```eval_rst
.. automodule:: jittor_core.ops
    :members:
    :undoc-members:
```

## jittor.Var

This is the API documentation for Jittor's basic variable class. Its methods can be accessed directly as `my_jittor_var.XXX`.

```eval_rst
.. automodule:: jittor_core.Var
    :members:
    :undoc-members:
```
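A brief sketch of the core API surface described above (array creation, operators, and autograd); the values are illustrative:

```python
import jittor as jt

x = jt.random([3, 3])        # an operator exposed on the jittor module
y = (x * 2 + 1).sum()        # Var methods are available directly on variables
g = jt.grad(y, x)            # autodiff through the computation graph
print(y.data, g.data)        # .data fetches the numpy result
```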
@ -0,0 +1,13 @@
jittor.models
=====================

This is the API documentation for Jittor's backbone network (model zoo) module. You can access this module with `from jittor import models`.

```eval_rst

.. automodule:: jittor.models
    :members:
    :imported-members:
    :undoc-members:
```
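A usage sketch, assuming a `resnet18` constructor is exported by `jittor.models` (the constructor name is an assumption, not confirmed by this commit):

```python
import jittor as jt
from jittor import models

net = models.resnet18()            # assumed constructor name
x = jt.random([2, 3, 224, 224])    # illustrative input batch
y = net(x)
print(y.shape)
```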
@ -0,0 +1,24 @@
jittor.nn
=====================

This is the API documentation for Jittor's neural network module. You can access this module with `from jittor import nn`.

```eval_rst
.. automodule:: jittor.nn
    :members:
    :undoc-members:

.. automodule:: jittor.nn
    :imported-members:
    :members: Pool, pool, AdaptiveAvgPool2d
    :undoc-members:

.. autoclass:: jittor.nn.ReLU
    :members:
.. autoclass:: jittor.nn.ReLU6
    :members:
.. autoclass:: jittor.nn.LeakyReLU
    :members:
.. autoclass:: jittor.nn.Softmax
    :members:
```
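For orientation, a minimal sketch built from layers used elsewhere in this commit (`nn.Conv`, `nn.BatchNorm`, `nn.ReLU`, `nn.Pool`, `nn.Sequential`); the input shape is illustrative:

```python
import jittor as jt
from jittor import nn

net = nn.Sequential(
    nn.Conv(3, 64, 7, 2, 3),   # in_channels, out_channels, kernel, stride, padding
    nn.BatchNorm(64),
    nn.ReLU(),
    nn.Pool(3, 2, 1),
)
x = jt.random([1, 3, 224, 224])
y = net(x)
print(y.shape)
```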
@ -0,0 +1,10 @@
jittor.optim
=====================

This is the API documentation for Jittor's optimizer module. You can access this module with `from jittor import optim`.

```eval_rst
.. automodule:: jittor.optim
    :members:
    :undoc-members:
```
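A usage sketch, assuming the `SGD` optimizer and its `step(loss)` interface (both are assumptions here, not shown in this commit):

```python
import jittor as jt
from jittor import nn, optim

model = nn.Sequential(nn.Conv(3, 8, 3, 1, 1), nn.ReLU())
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)  # assumed signature

x = jt.random([4, 3, 32, 32])
loss = model(x).sum()
opt.step(loss)   # assumed: step(loss) computes gradients and updates parameters
```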
@ -0,0 +1,10 @@
jittor.transform
=====================

This is the API documentation for Jittor's data transform module. You can access this module with `from jittor import transform`.

```eval_rst
.. automodule:: jittor.transform
    :members:
    :undoc-members:
```
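A usage sketch; the transform names below (`Resize`, `ImageNormalize`) are assumptions about the module's contents, not confirmed by this commit:

```python
import numpy as np
from PIL import Image
import jittor.transform as trans

# Assumed transforms: resize a PIL image, then normalize it to a float array.
img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
resized = trans.Resize(32)(img)                               # assumed API
arr = trans.ImageNormalize(mean=[0.5], std=[0.5])(resized)    # assumed API
```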
@ -0,0 +1,12 @@
TODO
=====================

## Documentation

* Documentation syntax conventions
* Add tutorial links to the documentation
* MPI interface documentation
* Automatic documentation updates
* Link from the homepage to the documentation
* Documentation for the model zoo (GAN, segmentation, detection, ...)
* Complete the documentation; add usage examples to the important classes
@ -17,5 +17,6 @@ namespace jittor {

extern ncclComm_t comm;
extern ncclUniqueId id;
+extern int nccl_device_id;

} // jittor
@ -6,6 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
+#include "misc/cuda_flags.h"
#include "nccl_warper.h"
#include "event_queue.h"

@ -17,25 +18,33 @@ namespace jittor {

ncclComm_t comm;
ncclUniqueId id;
+int nccl_device_id = 0;


struct nccl_initer {

nccl_initer() {
-    if (!get_device_count()) return;
+    int device_count = get_device_count();
+    if (!device_count) return;
+    if (!inside_mpi) return;
    if (mpi_world_rank == 0)
        checkCudaErrors(ncclGetUniqueId(&id));
    MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
-    LOGv << "NCCL init in device" << mpi_local_rank;
-    checkCudaErrors(cudaSetDevice(mpi_local_rank));
+    if (mpi_local_rank >= device_count)
+        LOGf << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
+            >>device_count>>")";
+    nccl_device_id = mpi_local_rank; // % device_count;
+    LOGv << "NCCL init in device" << nccl_device_id << "local_rank" << mpi_local_rank;
+    checkCudaErrors(cudaSetDevice(nccl_device_id));
    event_queue.run_sync([]() {
-        checkCudaErrors(cudaSetDevice(mpi_local_rank));
+        checkCudaErrors(cudaSetDevice(nccl_device_id));
    });
    checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
}

~nccl_initer() {
    if (!get_device_count()) return;
+    if (!inside_mpi) return;
    checkCudaErrors(ncclCommDestroy(comm));
}
@ -27,6 +27,7 @@ namespace jittor {
extern int mpi_world_size;
extern int mpi_world_rank;
extern int mpi_local_rank;
+extern bool inside_mpi;

// @pyjt(world_size)
int _mpi_world_size();
@ -30,6 +30,7 @@ namespace jittor {
int mpi_world_size = 1;
int mpi_world_rank = 0;
int mpi_local_rank = 0;
+bool inside_mpi = false;

int _mpi_world_size() {
    return mpi_world_size;

@ -73,6 +74,8 @@ static void getHostName(char* hostname, int maxlen) {
struct mpi_initer {

mpi_initer() {
+    inside_mpi = !!getenv("OMPI_COMM_WORLD_SIZE");
+    if (!inside_mpi) return;
    LOGvv << "MPI init...";
    MPI_CHECK(MPI_Init(NULL, NULL));
    MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));

@ -95,6 +98,7 @@ mpi_initer() {
}

~mpi_initer() {
+    if (!inside_mpi) return;
    MPI_CHECK(MPI_Finalize());
}
@ -17,6 +17,8 @@ with lock.lock_scope():
    from jittor_core.ops import *
    from . import compile_extern
    from .compile_extern import mkl_ops, mpi, mpi_ops
+    if core.get_device_count() == 0:
+        has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
    if has_cuda:
        from .compile_extern import cudnn, curand, cublas
@ -29,92 +31,6 @@ import pickle
import sys
import traceback

-def dfs(scope, vars):
-    for v in scope.children.values():
-        if type(v) == Scope:
-            dfs(v, vars)
-        else:
-            vars.append(v)
-
-def dfs_records(scope, records):
-    for v in scope.children.values():
-        if type(v) == Scope:
-            dfs_records(v, records)
-    for v in scope.records.values():
-        records.append(v)
-
-class Scope:
-    def __init__(self, parent=None, name=None):
-        self.children = OrderedDict()
-        self.index = {}
-        self.records = OrderedDict()
-        if name == None:
-            self.name = self.full_name = ""
-        else:
-            self.name = name
-            self.full_name = parent.full_name + name + "/"
-
-    def get_scope(self, name, unique=True):
-        if not unique:
-            index = self.index.get(name, 0)
-            self.index[name] = index+1
-            name = name + f'_{index}'
-        if name not in self.children:
-            sub_scope = Scope(self, name)
-            self.children[name] = sub_scope
-        else:
-            sub_scope = self.children[name]
-            assert type(sub_scope) == Scope, f"Name {name} is a Var: {sub_scope}"
-        return sub_scope
-
-    def make_var(self, shape, dtype, init, name, unique):
-        if not unique:
-            index = self.index.get(name, 0)
-            self.index[name] = index+1
-            name = name + f'_{index}'
-        if name in self.children:
-            var = self.children[name]
-            assert type(var) == core.Var, f"Name {name} exist: {var}"
-            assert (shape is None or var.shape == shape) and var.dtype == dtype, f"Shape or dtype not match {var} != {dtype}{shape}"
-            return var
-        else:
-            full_name = self.full_name + name
-            if type(init) != core.Var:
-                if callable(init):
-                    var = init(shape, dtype)
-                    if type(var) != core.Var:
-                        var = array(var)
-                else:
-                    assert init != None
-                    var = array(init)
-            else:
-                var = init
-            var.stop_fuse()
-            self.children[name] = var
-            var.name(full_name)
-            return var
-
-    def clean_index(self): self.index.clear()
-
-    def clean(self):
-        self.children.clear()
-        self.records.clear()
-        self.index.clear()
-
-current_scope = Scope()
-root_scope = current_scope
-
-class _call_record_scope:
-    def __enter__(self): pass
-    def __exit__(self, *exc): pass
-    def __call__(self, func):
-        def inner(*args, **kw):
-            with self:
-                ret = func(*args, **kw)
-            record_in_scope(ret, "output")
-            return ret
-        return inner
-
class _call_no_record_scope:
    def __enter__(self): pass
    def __exit__(self, *exc): pass
@ -143,30 +59,6 @@ class flag_scope(_call_no_record_scope):
        for k,v in self.flags_bk.items():
            setattr(flags, k, v)

-class var_scope(_call_record_scope):
-    def __init__(self, name="scope", unique=False, **jt_flags):
-        self.fs = flag_scope(**jt_flags)
-        self.name = name
-        self.unique = unique
-
-    def __enter__(self):
-        global current_scope
-        self.prev = current_scope
-        try:
-            current_scope = current_scope.get_scope(self.name, self.unique)
-            current_scope.clean_index()
-            self.fs.__enter__()
-        except:
-            current_scope = self.prev
-            del self.prev
-            raise
-
-    def __exit__(self, *exc):
-        self.fs.__exit__(*exc)
-        global current_scope
-        current_scope = self.prev
-        del self.prev
-
single_log_capture = None

class log_capture_scope(_call_no_record_scope):
@ -229,75 +121,7 @@ class profile_scope(_call_no_record_scope):
        profiler.stop()
        self.report.extend(profiler.report())

-def make_var(shape=None, dtype="float32", init=None, name='var', unique=False):
-    return current_scope.make_var(shape, dtype, init, name, unique)
-
-def find_vars(path=None):
-    scope = current_scope
-    if path is not None:
-        assert isinstance(path, str)
-        ns = path.split("/")
-        if ns[-1] == "":
-            ns.pop()
-        for n in ns: scope = scope.children[n]
-    if not isinstance(scope, Scope):
-        return [scope]
-    vars = []
-    dfs(scope, vars)
-    return vars
-
-def find_var(path):
-    scope = current_scope
-    if path is not None:
-        assert isinstance(path, str)
-        ns = path.split("/")
-        for n in ns: scope = scope.children[n]
-    assert not isinstance(scope, Scope)
-    return scope
-
-def find_records(path=None):
-    scope = current_scope
-    if path is not None:
-        assert isinstance(path, str)
-        ns = path.split("/")
-        if ns[-1] == "":
-            ns.pop()
-        for n in ns: scope = scope.children[n]
-    assert isinstance(scope, Scope)
-    records = []
-    dfs_records(scope, records)
-    return records
-
-def find_record(path):
-    scope = current_scope
-    assert isinstance(path, str)
-    ns = path.split("/")
-    for n in ns[:-1]: scope = scope.children[n]
-    assert isinstance(scope, Scope)
-    return scope.records[ns[-1]]
-
-def find_scope(path):
-    scope = current_scope
-    if path is not None:
-        assert isinstance(path, str)
-        ns = path.split("/")
-        if ns[-1] == "":
-            ns.pop()
-        for n in ns: scope = scope.children[n]
-    assert isinstance(scope, Scope)
-    return scope
-
-def record_in_scope(self, name):
-    current_scope.records[name] = self
-    if isinstance(self, Var):
-        full_name = current_scope.full_name + name
-        self.name(full_name)
-    return self
-
-Var.record_in_scope = record_in_scope
-
def clean():
-    current_scope.clean()
    import gc
    # make sure python do a full collection
    gc.collect()
@ -411,11 +235,8 @@ def squeeze(x, dim):
Var.squeeze = squeeze

def clamp(x, min_v, max_v):
-    # TODO: change to x.maximum(min_v).minimum(max_v)
    assert min_v <= max_v
-    min_b = (x < min_v).int()
-    max_b = (x > max_v).int()
-    return x * (1 - min_b - max_b) + min_v * min_b + max_v * max_b
+    return x.maximum(min_v).minimum(max_v)
Var.clamp = clamp

def type_as(a, b):
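For reference, a short sketch of the simplified `clamp` above (the values are illustrative):

```python
import jittor as jt

x = jt.array([-2.0, 0.5, 3.0])
# clamp is now implemented as x.maximum(min_v).minimum(max_v)
y = x.clamp(0.0, 1.0)
print(y.data)   # expected: [0.  0.5 1. ]
```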
@ -456,32 +277,6 @@ def display_memory_info():
    fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
    core.display_memory_info(fileline)

-def import_vars(data):
-    ''' Load variables into current scopes
-    example:
-        import_vars({"w":[1.0,2.0,3.0]})
-        jt.get_var([3], "float64", name="w", gen_index=False)
-    '''
-    for k in data:
-        v = data[k]
-        if type(v) != core.Var:
-            v = array(v).stop_fuse()
-        scopes = k.split("/")
-        scope = current_scope
-        for i in range(len(scopes)-1):
-            scope = scope.get_scope(scopes[i])
-        vname = scopes[-1]
-        assert vname not in scope.children, f"Var {k} exists. Please load_vars at the beginning"
-        v.name(k)
-        scope.children[vname] = v
-
-def export_vars():
-    ''' Export all vars into a dictionary
-    return: a dictionary, key is var name, value is numpy array
-    '''
-    data = { v.name():v.fetch_sync() for v in find_vars() }
-    return data
-
def load(path):
    pkl_file = open(path, 'rb')
    model_dict = pickle.load(pkl_file)
@ -489,7 +284,7 @@ def load(path):

class Module:
    def __init__(self, *args, **kw):
-        __doc__ == 'doc'
+        pass
    def execute(self, *args, **kw):
        pass
    def __call__(self, *args, **kw):
@ -498,8 +293,6 @@ class Module:
        return self.__str__()
    def _get_name(self):
        return self.__class__.__name__
-    def __doc__(self):
-        pass
    def __name__(self):
        pass
@ -670,16 +463,12 @@ def make_module(func, exec_n_args=1):
        def __init__(self, *args, **kw):
            self.args = args
            self.kw = kw
-            self.__doc__ == 'doc'
        def execute(self, *args):
            return func(*args, *self.args, **self.kw)
        def __str__(self):
-            return 'str'
+            return f"{func.__name__}({self.extra_repr()})"
+        def __repr__(self):
+            return self.__str__()
        def extra_repr(self):
-            return ''
+            return ",".join(map(str, self.args))

    return MakeModule
@ -6,7 +6,7 @@
import os, sys, shutil
from .compiler import *
from jittor_utils import run_cmd, get_version
-from jittor.dataset.utils import download_url_to_local
+from jittor.utils.misc import download_url_to_local

def search_file(dirs, name):
    for d in dirs:
@ -256,6 +256,11 @@ def install_nccl(root_folder):
        LOG.i("Downloading nccl...")
        download_url_to_local(url, filename, root_folder, true_md5)

+        if core.get_device_count() == 0:
+            return
+        if not inside_mpi():
+            return
+
        import tarfile
        with tarfile.open(fullname, "r") as tar:
            tar.extractall(root_folder)
@ -269,7 +274,7 @@ def setup_nccl():
    global nccl_ops, use_nccl
    use_nccl = os.environ.get("use_nccl", "1")=="1"
    nccl_ops = None
-    if not has_cuda or mpi is None:
+    if not has_cuda or not has_mpi:
        use_nccl = False
        return
    if not use_nccl: return
@ -284,6 +289,7 @@ def setup_nccl():

    make_cache_dir(nccl_path)
    nccl_home = install_nccl(nccl_path)
+    if nccl_home is None: return
    nccl_include_path = os.path.join(nccl_home, "build", "include")
    nccl_lib_path = os.path.join(nccl_home, "build", "lib")
@ -343,8 +349,6 @@ def setup_mpi():
    else:
        use_mpi = True
        has_mpi = True
-    if not inside_mpi():
-        use_mpi = False
    if not use_mpi:
        return
@ -388,6 +392,8 @@ def setup_mpi():
        setattr(core.Var, k, warper(mpi_ops.__dict__[k]))

setup_mpi()
+if not inside_mpi():
+    mpi = None
setup_nccl()

setup_cutt()
@ -187,19 +187,3 @@ def setitem(x, slices, value):

jt.Var.__getitem__ = jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem
-
-def adam(model, loss, lr=3e-4, betas=[0.9, 0.999], eps=1e-8):
-    ps = jt.find_vars(model)
-    gs = jt.grad(loss, ps)
-    with jt.var_scope('_'.join([model, 'adam']), unique=True):
-        adam_step = jt.make_var([1], init=jt.zeros)
-        adam_step += 1
-        for p,g in zip(ps,gs):
-            m = jt.make_var(p.shape, init=jt.zeros)
-            v = jt.make_var(p.shape, init=jt.zeros)
-
-            m.assign(betas[0] * m + (1-betas[0]) * g)
-            v.assign(betas[1] * v + (1-betas[1]) * g * g)
-            step_size = lr * jt.sqrt(1-betas[1]**adam_step) / (1-betas[0] ** adam_step)
-            p -= m * step_size / (jt.sqrt(v) + eps)
@ -0,0 +1,4 @@

from .dataset import Dataset, ImageFolder
from .mnist import MNIST
from .voc import VOC
@ -23,7 +23,7 @@ import jittor as jt

dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
-mpi = jt.compile_extern.mpi
+mpi = jt.mpi

class Worker:
    def __init__(self, target, args, buffer_size):
@ -12,7 +12,7 @@ import gzip
from PIL import Image
# our lib jittor import
from jittor.dataset.dataset import Dataset, dataset_root
-from jittor.dataset.utils import ensure_dir, download_url_to_local
+from jittor.utils.misc import ensure_dir, download_url_to_local
import jittor as jt
import jittor.transform as trans
@ -8,74 +8,9 @@
# ***************************************************************

import jittor as jt
-import os
-from six.moves import urllib
-import hashlib
-from tqdm import tqdm
import numpy as np
from collections.abc import Sequence, Mapping
from PIL import Image
-from .. import lock
-
-def ensure_dir(dir_path):
-    if not os.path.isdir(dir_path):
-        os.makedirs(dir_path)
-
-def _progress():
-    pbar = tqdm(total=None)
-
-    def bar_update(block_num, block_size, total_size):
-        """ reporthook
-        @block_num: the num of downloaded data block
-        @block_size: the size of data block
-        @total_size: the total size of remote file
-        """
-        if pbar.total is None and total_size:
-            pbar.total = total_size
-        progress_bytes = block_num * block_size
-        pbar.update(progress_bytes - pbar.n)
-
-    return bar_update
-
-@lock.lock_scope()
-def download_url_to_local(url, filename, root_folder, md5):
-    ensure_dir(root_folder)
-    file_path = os.path.join(root_folder, filename)
-    if check_file_exist(file_path, md5):
-        print("Data file has been downloaded and verified")
-    else:
-        try:
-            print('Downloading ' + url + ' to ' + file_path)
-            urllib.request.urlretrieve(
-                url, file_path,
-                reporthook=_progress()
-            )
-        except(urllib.error.URLError, IOError) as e:
-            raise e
-    if not check_file_exist(file_path, md5):
-        raise RuntimeError("File downloads failed.")
-
-
-def check_file_exist(file_path, md5):
-    if not os.path.isfile(file_path):
-        return False
-    if md5 is None:
-        return True
-    return check_md5(file_path, md5)
-
-
-def calculate_md5(file_path, chunk_size=1024 * 1024):
-    md5 = hashlib.md5()
-    with open(file_path, 'rb') as f:
-        for chunk in iter(lambda: f.read(chunk_size), b''):
-            md5.update(chunk)
-    return md5.hexdigest()
-
-
-def check_md5(file_path, md5, **kwargs):
-    return md5 == calculate_md5(file_path, **kwargs)
-
-
def get_random_list(n):
    return list(np.random.permutation(range(n)))
@ -12,7 +12,7 @@ import os
from PIL import Image
import matplotlib.pyplot as plt
import cv2
-from dataset import Dataset, dataset_root
+from .dataset import Dataset, dataset_root

class VOC(Dataset):
    NUM_CLASSES = 21
@ -1,52 +0,0 @@
import jittor as jt
import numpy as np
import sys, os
f32 = jt.float32

@jt.var_scope('linear')
def linear(x, n):
    w = jt.make_var([x.shape[-1], n], init=lambda *a:
        (jt.random(*a)-f32(0.5)) / f32(x.shape[-1])**f32(0.5))
    b = jt.make_var([n], init=lambda *a: jt.random(*a)-f32(0.5))
    return jt.matmul(x, w) + b

def relu(x): return jt.maximum(x, f32(0))

@jt.var_scope('model', unique=True)
def model(x):
    x = linear(x, 10)
    x = relu(x)
    x = linear(x, 1)
    return x

np.random.seed(0)
jt.set_seed(3)
n = 1000
batch_size = 50
base_lr = 0.05
# we need to stop grad of global value to prevent memory leak
lr = f32(base_lr).name("lr").stop_grad()

def get_data(n):
    for i in range(n):
        x = np.random.rand(batch_size, 1)
        y = x*x
        yield np.float32(x), np.float32(y)

for i,(x,y) in enumerate(get_data(n)):
    pred_y = model(x).name("pred_y")
    loss = ((pred_y - y)**f32(2)).name("loss")
    loss_mean = loss.mean()

    ps = jt.find_vars('model')
    gs = jt.grad(loss_mean, ps)
    for p,g in zip(ps, gs):
        p -= g * lr
    if i>2:
        assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
    prev = jt.liveness_info()
    print(f"step {i}, loss = {loss_mean().sum()}")

# result is 0.0009948202641680837
result = 0.0009948202641680837
assert abs(loss_mean.data - result) < 1e-6
@ -43,38 +43,6 @@ jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
def get_init_var_rand(shape, dtype):
    return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))

-@jt.var_scope('conv')
-def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1, init_method=None):
-    Kw = kernel_size
-    Kh = kernel_size
-    _C = in_planes
-    Kc = out_planes
-    N,C,H,W = x.shape
-
-    assert C==_C
-    if init_method==None:
-        w = jt.make_var([Kc, _C, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
-    else:
-        w = jt.make_var([Kc, _C, Kh, Kw], init=init_method)
-    xx = x.reindex([N,Kc,C,(H+padding*2-kernel_size)//stride+1,(W+padding*2-kernel_size)//stride+1,Kh,Kw], [
-        'i0', # Nid
-        'i2', # Cid
-        f'i3*{stride}-{padding}+i5', # Hid+Khid
-        f'i4*{stride}-{padding}+i6', # Wid+KWid
-    ])
-    ww = w.broadcast(xx.shape, [0,3,4])
-    yy = xx*ww
-    y = yy.sum([2,5,6]) # C, Kh, Kw
-    return y
-
-@jt.var_scope('linear')
-def linear(x, n):
-    w = jt.make_var([n, x.shape[-1]], init=lambda *a: init.invariant_uniform(*a))
-    w = w.reindex([w.shape[1], w.shape[0]],["i1","i0"])
-    bound = 1.0/math.sqrt(w.shape[0])
-    b = jt.make_var([n], init=lambda *a: init.uniform(*a,-bound,bound))
-    return jt.matmul(x, w) + b
-
def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)
@ -30,6 +30,7 @@ def check(jt_model, torch_model, shape, near_data):
        data = list(range(8)) * int((shape[0] * shape[1] * shape[2] * shape[3]) / 8)
        random.shuffle(data)
        x = jt.array(data).float32().reshape(shape)
+        x.data
    else:
        x = jt.random(shape)
    y = jt_model(x)
@ -12,7 +12,7 @@ import numpy as np
from jittor import compile_extern
from .test_log import find_log_with_re
import copy
-if compile_extern.has_cuda:
+if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
else:
    cublas_ops = cudnn_ops = cub_ops = None
@ -11,7 +11,7 @@ import jittor as jt
import numpy as np
from jittor import compile_extern
from .test_log import find_log_with_re
-if compile_extern.has_cuda:
+if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
else:
    cublas_ops = cudnn_ops = cub_ops = None
@ -30,7 +30,7 @@ class TestArray(unittest.TestCase):
        a.data = jt.array([7,8,9])
        assert (a.fetch_sync()==[7,8,9]).all()

-    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_memcopy_overlap(self):
        import time
@ -95,13 +95,13 @@ class TestArray(unittest.TestCase):
        with jt.flag_scope(use_cuda=1):
            assert (jt.array([1,2,3]).reshape((1,3)).data==[1,2,3]).all()

-    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    def test_array_dual(self):
        with jt.flag_scope(use_cuda=1):
            a = jt.array(np.float32([1,2,3]))
            assert (a.data==[1,2,3]).all()

-    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
+    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    def test_array_migrate(self):
        with jt.flag_scope(use_cuda=1):
            a = jt.array(np.float32([1,2,3]))
@ -18,7 +18,7 @@ class TestCompileOptions(unittest.TestCase):
        assert a.compile_options=={"compile_shapes":1}
        b = a+a
        assert b.compile_options=={}
-        with jt.var_scope(compile_options={"compile_shapes":1}):
+        with jt.flag_scope(compile_options={"compile_shapes":1}):
            c = a+b
        assert c.compile_options=={"compile_shapes":1}
        with jt.profile_scope() as report:
@ -14,7 +14,7 @@ from jittor import compile_extern
# TODO: compare with pytorch

from jittor.test.test_log import find_log_with_re
-if compile_extern.has_cuda:
+if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops
else:
    cublas_ops = cudnn_ops = None

@ -28,10 +28,7 @@ def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride = 1, dilati

    assert C==_C
    if w_ is None:
-        if init_method==None:
-            w = jt.make_var([Kc, _C, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
-        else:
-            w = jt.make_var([Kc, _C, Kh, Kw], init=init_method)
+        assert 0
    else:
        w = w_
    oh = (H-Kh*dilation+dilation-1+padding*2)//stride+1

@ -56,10 +53,7 @@ def conv_nhwc(x, in_planes, out_planes, kernel_size, padding, stride = 1, dilati

    assert C==_C
    if w_ is None:
-        if init_method==None:
-            w = jt.make_var([Kc, _C, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
-        else:
-            w = jt.make_var([Kc, _C, Kh, Kw], init=init_method)
+        assert 0
    else:
        w = w_
    oh = (H-Kh*dilation+dilation-1+padding*2)//stride+1
|
||||||
import jittor as jt
|
import jittor as jt
|
||||||
import os
|
import os
|
||||||
from jittor import compile_extern
|
from jittor import compile_extern
|
||||||
if compile_extern.has_cuda:
|
if jt.has_cuda:
|
||||||
from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
|
from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
|
||||||
else:
|
else:
|
||||||
cublas_ops = cudnn_ops = cub_ops = None
|
cublas_ops = cudnn_ops = cub_ops = None
|
||||||
|
|
|
@ -10,7 +10,7 @@ import numpy as np
from jittor import compile_extern

from jittor.test.test_log import find_log_with_re
-if compile_extern.has_cuda:
+if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops
else:
    cublas_ops = cudnn_ops = None
@ -12,7 +12,7 @@ import numpy as np
from jittor import compile_extern
from .test_log import find_log_with_re
import copy
-if compile_extern.has_cuda:
+if jt.has_cuda:
    from jittor.compile_extern import cutt_ops
else:
    cutt_ops = None
@ -11,7 +11,7 @@ from .test_grad import ngrad
from itertools import permutations
from jittor import compile_extern
from .test_log import find_log_with_re
-if compile_extern.has_cuda:
+if jt.has_cuda:
    from jittor.compile_extern import cutt_ops
else:
    cutt_ops = None
@ -21,7 +21,7 @@ class TestFlags(unittest.TestCase):

    def test_scope(self):
        prev = jt.flags.log_v
-        with jt.var_scope(log_v=1):
+        with jt.flag_scope(log_v=1):
            assert jt.flags.log_v == 1
        assert jt.flags.log_v == prev
@ -32,7 +32,7 @@ def performance_test_scope(warmup=0, rerun=0, **args):
    jt.profiler.start(warmup, rerun)
    report = []
    try:
-        with jt.var_scope(**args):
+        with jt.flag_scope(**args):
            yield report
    finally:
        jt.profiler.stop()
@ -14,7 +14,7 @@ from jittor import compile_extern
# TODO: compare with pytorch

from jittor.test.test_log import find_log_with_re
-if compile_extern.has_cuda:
+if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops
else:
    cublas_ops = cudnn_ops = None

@ -34,10 +34,7 @@ def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=1, dilation
    ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1

    if w_ is None:
-        if init_method==None:
-            w = jt.make_var([oc, C // G, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
-        else:
-            w = jt.make_var([oc, C // G, Kh, Kw], init=init_method)
+        assert 0
    else:
        w = w_
@ -29,7 +29,7 @@ def find_log_with_re(logs, pattern=None, **args):
class TestLog(unittest.TestCase):
    def test_log_capture(self):
        LOG.log_capture_start()
-        with jt.var_scope(log_v=1000, log_vprefix=""):
+        with jt.flag_scope(log_v=1000, log_vprefix=""):
            LOG.v("1")
            LOG.vv("2")
            LOG.i("3")
@ -16,16 +16,6 @@ import numpy as np
def get_init_var(shape, dtype):
    return jt.random(shape, dtype)

-def batch_norm(x):
-    xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-    x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-    norm_x = (x-xmean.broadcast_var(x))/(jt.sqrt(x2mean-xmean*xmean+jt.float32(1e-5)).broadcast_var(x))
-    w = jt.make_var([x.shape[1]], init=get_init_var)
-    b = jt.make_var([x.shape[1]], init=get_init_var)
-    w = w.broadcast([1, w.shape[0],1,1], [0,2,3])
-    b = b.broadcast([1, b.shape[0],1,1], [0,2,3])
-    return norm_x * w + b
-
def pool(x, size, op, padding, stride = 1): # TODO: stride, padding
    N,C,H,W = x.shape
    h = (H+padding*2-size)//stride+1

@ -43,41 +33,25 @@ def pool(x, size, op, padding, stride = 1): # TODO: stride, padding
        "i3", # Wid
    ])

-def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1):
-    Kw = kernel_size
-    Kh = kernel_size
-    _C = in_planes
-    Kc = out_planes
-    N,C,H,W = x.shape
-    assert C==_C
-    w = jt.make_var([Kc, _C, Kh, Kw], init=get_init_var)
-    xx = x.reindex([N,Kc,C,(H+padding*2-kernel_size)//stride+1,(W+padding*2-kernel_size)//stride+1,Kh,Kw], [
-        'i0', # Nid
-        'i2', # Cid
-        f'i3*{stride}-{padding}+i5', # Hid+Khid
-        f'i4*{stride}-{padding}+i6', # Wid+KWid
-    ])
-    ww = w.broadcast(xx.shape, [0,3,4])
-    yy = xx*ww
-    y = yy.sum([2,5,6]) # Kc, Kh, Kw
-    return y
-
def relu(x): return jt.maximum(x, jt.float32(0))

-@jt.var_scope('resnet_fake', unique=True)
-def resnet_fake(x):
-    x = conv(x, 3, 64, 7, 3, 2)
-    x = batch_norm(x)
-    x = relu(x)
-    x = pool(x, 3, "maximum", 1, 2)
-    return x
+def resnet_fake():
+    from jittor import nn
+    net = nn.Sequential(
+        nn.Conv(3, 64, 7, 2, 3),
+        nn.BatchNorm(64),
+        nn.ReLU(),
+        nn.Pool(3, 2, 1)
+    )
+    return net

class TestLongestDisFuse(unittest.TestCase):

    def test_longest_dis_fuse(self):
        x = jt.array(np.random.rand(1,3,224,224).astype(np.float32))
-        loss = jt.sum(resnet_fake(x))
-        ps = jt.find_vars('resnet_fake')
+        net = resnet_fake()
+        loss = jt.sum(net(x))
+        ps = net.parameters()
        gs = jt.grad(loss, ps)
        jt.sync(gs)
        # assert not alloc big tensor
@ -12,13 +12,6 @@ import numpy as np
from .test_log import find_log_with_re
f32 = jt.float32
from jittor import nn, Module

-@jt.var_scope('linear')
-def linear(x, n):
-    w = jt.make_var([x.shape[-1], n], init=lambda *a:
-        (jt.random(*a)-f32(0.5)) / f32(x.shape[-1])**f32(0.5))
-    b = jt.make_var([n], init=lambda *a: jt.random(*a)-f32(0.5))
-    return jt.nn.matmul(x, w) + b
-
def relu(x): return jt.maximum(x, f32(0))
@ -87,8 +87,7 @@ class TestSingleArray(unittest.TestCase):
        jt.set_seed(3)

        x = f32(np.random.rand(1, 1))
-        w = jt.make_var([x.shape[-1], 10], init=lambda *a:
-            (jt.random(*a)-f32(0.5)) / f32(x.shape[-1])**f32(0.5))
+        w = (jt.random([x.shape[-1], 10])-f32(0.5)) / f32(x.shape[-1])**f32(0.5)
        jt.nn.matmul(x, w).data

    def test4(self):
@ -59,7 +59,7 @@ class TestMklConvOp(unittest.TestCase):

        a_jt = jt.array(a)
        b_jt = jt.array(b)
-        with jt.var_scope(enable_tuner=0,compile_options={"test_mkl_conv":1}):
+        with jt.flag_scope(enable_tuner=0,compile_options={"test_mkl_conv":1}):
            c_jt = conv(a_jt, b_jt, 3, 2).data
        with jt.log_capture_scope(
            enable_tuner=1,

@ -84,7 +84,7 @@ class TestMklConvOp(unittest.TestCase):

        a_jt = jt.array(a)
        b_jt = jt.array(b)
-        with jt.var_scope(enable_tuner=0,
+        with jt.flag_scope(enable_tuner=0,
            compile_options={"test_mkl_conv":uid[0]}):
            c_jt = conv_nhwc_hwio(a_jt, b_jt, stride, pad).data
        with jt.log_capture_scope(

@ -118,7 +118,7 @@ class TestMklConvOp(unittest.TestCase):
        a_jt = jt.array(a)
        b_jt = jt.array(b)

-        with jt.var_scope(
+        with jt.flag_scope(
            enable_tuner=0,
            compile_options={"test_mkl_conv":1}
        ):

@ -164,7 +164,7 @@ class TestMklConvOp(unittest.TestCase):
        a_jt = jt.array(a)
        b_jt = jt.array(b)

-        with jt.var_scope(
+        with jt.flag_scope(
            enable_tuner=0,
            compile_options={"test_mkl_conv":1}
        ):
@@ -15,6 +15,8 @@ import numpy as np
 from jittor.test.test_mpi import run_mpi_test
 
 mpi = jt.compile_extern.mpi
+if mpi:
+    n = mpi.world_size()
 
 class FakeMpiBatchNorm(nn.Module):
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):

@@ -57,7 +59,8 @@ class TestMpiBatchnorm(unittest.TestCase):
         mpi = jt.compile_extern.mpi
         data = np.random.rand(30,3,10,10).astype("float32")
         x1 = jt.array(data)
-        x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
+        stride = 30//n
+        x2 = jt.array(data[mpi.world_rank()*stride:(mpi.world_rank()+1)*stride,...])
 
         bn1 = nn.BatchNorm(3, sync=False)
         bn2 = nn.BatchNorm(3, sync=True)

@@ -75,7 +78,8 @@ class TestMpiBatchnorm(unittest.TestCase):
         mpi = jt.compile_extern.mpi
         data = np.random.rand(30,3,10,10).astype("float32")
         global_x = jt.array(data)
-        x = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
+        stride = 30//n
+        x = jt.array(data[mpi.world_rank()*stride:(mpi.world_rank()+1)*stride,...])
 
         bn1 = nn.BatchNorm(3, sync=True)
         bn2 = FakeMpiBatchNorm(3)

@@ -98,7 +102,7 @@ class TestMpiBatchnorm(unittest.TestCase):
 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
 class TestMpiBatchnormEntry(unittest.TestCase):
     def test(self):
-        run_mpi_test(3, "test_mpi_batchnorm")
+        run_mpi_test(2, "test_mpi_batchnorm")
 
 if __name__ == "__main__":
     unittest.main()

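These batch-norm hunks replace the hard-coded per-rank slice of 10 samples with a stride derived from the MPI world size, and the test entry now launches 2 processes instead of 3. A self-contained sketch of that slicing rule, assuming only numpy (the helper name rank_slice is illustrative, not part of the change):

import numpy as np

def rank_slice(data, rank, world_size):
    # stride = global batch size // number of MPI processes, as in the updated test
    stride = data.shape[0] // world_size
    return data[rank*stride:(rank+1)*stride, ...]

data = np.random.rand(30, 3, 10, 10).astype("float32")
local = rank_slice(data, rank=0, world_size=2)
print(local.shape)  # (15, 3, 10, 10)
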
@@ -13,6 +13,8 @@ import numpy as np
 from jittor.test.test_mpi import run_mpi_test
 
 mpi = jt.compile_extern.mpi
+if mpi:
+    n = mpi.world_size()
 
 @unittest.skipIf(mpi is None, "no inside mpirun")
 class TestMpiOps(unittest.TestCase):

@@ -24,9 +26,9 @@ class TestMpiOps(unittest.TestCase):
     def test_all_reduce(self):
         x = jt.random([5, 5])
         y = x.mpi_all_reduce()
-        assert np.allclose(y.data, (x*3).data)
+        assert np.allclose(y.data, (x*n).data)
         g = jt.grad(y,x)
-        assert np.allclose(g.data, np.ones([5,5])*3)
+        assert np.allclose(g.data, np.ones([5,5])*n)
 
     def test_all_reduce_mean(self):
         x = jt.random([5, 5])

@@ -45,7 +47,7 @@ class TestMpiOps(unittest.TestCase):
         assert np.allclose(y.data, data.data)
         g = jt.grad(y,x)
         if mpi.world_rank() == 0:
-            assert np.allclose(g.data, np.ones([5,5])*3)
+            assert np.allclose(g.data, np.ones([5,5])*n)
         else:
             assert np.allclose(g.data, np.zeros([5,5]))
 

@@ -54,7 +56,7 @@ class TestMpiOps(unittest.TestCase):
         y = x.mpi_reduce(root=0)
         y.sync()
         if mpi.world_rank() == 0:
-            assert np.allclose(y.data, (x*3).data)
+            assert np.allclose(y.data, (x*n).data)
         else:
             assert np.allclose(y.data, np.zeros([5,5]))
         g = jt.grad(y,x)

@@ -64,7 +66,7 @@ class TestMpiOps(unittest.TestCase):
 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
 class TestMpiOpsEntry(unittest.TestCase):
     def test(self):
-        run_mpi_test(3, "test_mpi_op")
+        run_mpi_test(2, "test_mpi_op")
 
 if __name__ == "__main__":
     unittest.main()

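The MPI op tests follow the same idea: expected values now scale with mpi.world_size() instead of a literal 3, so the suite passes under any process count. A hedged sketch of the pattern, assuming a Jittor build with the MPI extension and a launch under mpirun:

import jittor as jt
import numpy as np

mpi = jt.compile_extern.mpi
if mpi:
    n = mpi.world_size()          # number of processes, not hard-coded
    x = jt.random([5, 5])
    y = x.mpi_all_reduce()        # element-wise sum of x over all ranks
    # every rank draws the same x here, as the updated test asserts
    assert np.allclose(y.data, (x*n).data)
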
@@ -1,105 +0,0 @@
-# ***************************************************************
-# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
-# This file is subject to the terms and conditions defined in
-# file 'LICENSE.txt', which is part of this source code package.
-# ***************************************************************
-import unittest
-import jittor as jt
-import numpy as np
-from .test_core import expect_error
-ops = jt.ops
-
-@jt.var_scope('linear')
-def linear(x, n):
-    w = jt.make_var([x.shape[-1], n], init=ops.random)
-    return jt.matmul(x, w)
-
-@jt.var_scope('model', unique=True)
-def model(x):
-    x = linear(x, 10)
-    # x = relu(x)
-    x = linear(x, 10)
-    # x = relu(x)
-    x = linear(x, 1)
-    return x
-
-class TestScope(unittest.TestCase):
-    def test_name(self):
-        jt.clean()
-        @jt.var_scope('model', unique=True)
-        def model():
-            with jt.var_scope('a'):
-                assert jt.current_scope.full_name == "model/a_0/"
-            with jt.var_scope('b'):
-                with jt.var_scope('b'):
-                    assert jt.current_scope.full_name == "model/b_0/b_0/"
-            with jt.var_scope('c'):
-                assert jt.current_scope.full_name == "model/c_0/"
-        model()
-        model()
-        model()
-        jt.clean()
-
-    def test_var(self):
-        jt.clean()
-        for i in range(2):
-            x = jt.array([[1]])
-            y = model(x)
-            params = jt.find_vars()
-            assert len(params) == 3
-            names = [ p.name() for p in params ]
-            assert names == [
-                "model/linear_0/var_0",
-                "model/linear_1/var_0",
-                "model/linear_2/var_0",
-            ], str(names)
-            jt.find_var("model/linear_0/var_0")
-            expect_error(lambda: jt.find_var("model/linear_0"))
-            expect_error(lambda: jt.find_var("model/linear"))
-            assert len(jt.find_vars("model/linear_0/var_0")) == 1
-            assert len(jt.find_vars("model/linear_0/")) == 1
-            assert len(jt.find_vars("model/")) == 3
-        jt.clean()
-
-    def test_get_var_unique(self):
-        jt.clean()
-        x = jt.make_var([1], init=ops.random)
-        y = jt.make_var([1], init=ops.random)
-        z = jt.make_var([1], init=ops.random)
-        assert x.name() == "var_0"
-        assert y.name() == "var_1", y.name()
-        assert z.name() == "var_2"
-        x = jt.make_var([1], name="x", unique=True, init=ops.random)
-        y = jt.make_var([1], name="y", unique=True, init=ops.random)
-        z = jt.make_var([1], name="z", unique=True, init=ops.random)
-        assert x.name() == "x"
-        assert y.name() == "y"
-        assert z.name() == "z"
-        expect_error(lambda: jt.make_var([2], name="x", unique=True, init=ops.random))
-        jt.clean()
-
-    def test_record_scope(self):
-        jt.clean()
-        @jt.var_scope("func")
-        def func(a):
-            b = a+1
-            jt.record_in_scope(b, "b")
-            c = b*2
-            return c
-        a = jt.array([1,2,3])
-        func(a)
-        assert np.allclose(jt.find_record("func_0/output").data, (a.data+1)*2)
-        assert np.allclose(jt.find_scope("func_0").records["output"].data, (a.data+1)*2)
-        recs = jt.find_records()
-        rec_names = [ r.name() for r in recs ]
-        assert len(recs)==2 and rec_names==["func_0/b","func_0/output"]
-
-    def test_get_var_init(self):
-        jt.clean()
-        assert (jt.make_var(init=[1,2,3]).data == [1,2,3]).all()
-        assert (jt.make_var(shape=[3], init=np.zeros).data == [0,0,0]).all()
-        assert (jt.make_var(init=jt.array([1,2,3]) == [1,2,3]).data).all()
-        jt.clean()
-
-if __name__ == "__main__":
-    unittest.main()

@@ -12,7 +12,7 @@ class TestTracer(unittest.TestCase):
 
         # force use addr2line
         jt.flags.gdb_path = ""
-        with jt.var_scope(gdb_path=""):
+        with jt.flag_scope(gdb_path=""):
             jt.print_trace()
 
 

@@ -0,0 +1,73 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. Authors:
+#     Meng-Hao Guo <guomenghao1997@gmail.com>
+#     Dun Liang <randonlang@gmail.com>.
+# All Rights Reserved.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import jittor as jt
+import os
+from six.moves import urllib
+import hashlib
+from tqdm import tqdm
+from .. import lock
+
+def ensure_dir(dir_path):
+    if not os.path.isdir(dir_path):
+        os.makedirs(dir_path)
+
+def _progress():
+    pbar = tqdm(total=None)
+
+    def bar_update(block_num, block_size, total_size):
+        """ reporthook
+        @block_num: the num of downloaded data block
+        @block_size: the size of data block
+        @total_size: the total size of remote file
+        """
+        if pbar.total is None and total_size:
+            pbar.total = total_size
+        progress_bytes = block_num * block_size
+        pbar.update(progress_bytes - pbar.n)
+
+    return bar_update
+
+@lock.lock_scope()
+def download_url_to_local(url, filename, root_folder, md5):
+    ensure_dir(root_folder)
+    file_path = os.path.join(root_folder, filename)
+    if check_file_exist(file_path, md5):
+        print("Data file has been downloaded and verified")
+    else:
+        try:
+            print('Downloading ' + url + ' to ' + file_path)
+            urllib.request.urlretrieve(
+                url, file_path,
+                reporthook=_progress()
+            )
+        except(urllib.error.URLError, IOError) as e:
+            raise e
+    if not check_file_exist(file_path, md5):
+        raise RuntimeError("File downloads failed.")
+
+
+
+def check_file_exist(file_path, md5):
+    if not os.path.isfile(file_path):
+        return False
+    if md5 is None:
+        return True
+    return check_md5(file_path, md5)
+
+
+def calculate_md5(file_path, chunk_size=1024 * 1024):
+    md5 = hashlib.md5()
+    with open(file_path, 'rb') as f:
+        for chunk in iter(lambda: f.read(chunk_size), b''):
+            md5.update(chunk)
+    return md5.hexdigest()
+
+
+def check_md5(file_path, md5, **kwargs):
+    return md5 == calculate_md5(file_path, **kwargs)

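A hedged usage sketch for the new download helper. Only the signature download_url_to_local(url, filename, root_folder, md5) comes from the diff; the module path, URL, and paths below are placeholders assumed for illustration:

from jittor.dataset.utils import download_url_to_local  # assumed module path

# Downloads the file into root_folder unless it is already present and its
# checksum matches; md5=None skips the checksum comparison entirely.
download_url_to_local(
    url="https://example.com/data.tar.gz",   # placeholder URL
    filename="data.tar.gz",
    root_folder="/tmp/jittor_data",
    md5=None,
)
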
@@ -23,7 +23,7 @@ int current_seed;
 
 static void init_cuda_devices() {
 #ifdef HAS_CUDA
-    int count;
+    int count=0;
     cudaGetDeviceCount(&count);
     for (int i=0; i<count; i++) {
         cudaDeviceProp devProp;

@@ -91,7 +91,14 @@ void Op::do_jit_prepare() {
     memcheck_all_exist();
     jk << name();
     jit_prepare();
-    if (!jk.empty()) {
+    if (jk.empty()) {
+        // not a jit op
+        bool has_cuda = flags.get(NodeFlags::_cuda);
+        bool has_cpu = flags.get(NodeFlags::_cpu);
+        CHECK(has_cuda || has_cpu);
+        if (has_cuda && has_cpu && !use_cuda)
+            flags.set(NodeFlags::_cuda, 0);
+    } else {
         // check use int64_t as index_t if array is too big
         int in_id=0, out_id=0;
         bool use_int64_t = false;