add sphinx doc

Dun Liang 2020-05-21 20:32:01 +08:00
parent e5265505df
commit 126f92a6f3
54 changed files with 498 additions and 588 deletions

.gitignore vendored

@ -22,3 +22,4 @@ venv/
!README.cn.md
python/jittor.egg-info
dist/
!doc/source/*

doc/Makefile Normal file

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

doc/build_doc.sh Normal file

@ -0,0 +1,18 @@
# sudo python3.7 -m pip install \
#     recommonmark sphinx sphinx-autobuild sphinx_rtd_theme \
#     --timeout 100
bpath=$(readlink -f "${BASH_SOURCE[0]}")
bpath=$(dirname "${bpath}")
jittor_path=$(readlink -f "${bpath}/..")
echo "[doc path] $bpath"
echo "[jittor path] $jittor_path"
export PYTHONPATH=$jittor_path/python
cd "$bpath"
sphinx-autobuild -b html source build

doc/source/README.cn.md Symbolic link

@ -0,0 +1 @@
../../README.cn.md

doc/source/conf.py Normal file

@ -0,0 +1,98 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
jittor_path = os.path.abspath('../../python')
print(f"[jittor_path] {jittor_path}")
sys.path.insert(0, jittor_path)
import jittor
# -- Project information -----------------------------------------------------
project = 'Jittor'
copyright = '2020, Jittor'
author = 'Jittor'
# The full version, including alpha/beta/rc tags
release = '1.1.3.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = 'zh_CN'
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'recommonmark',
    'sphinx.ext.autodoc',
    # Auto-generate section labels.
    'sphinx.ext.autosectionlabel',
    'sphinx.ext.viewcode',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
import sphinx_rtd_theme
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
source_suffix = {
    '.rst': 'restructuredtext',
    '.txt': 'markdown',
    '.md': 'markdown',
}
import recommonmark
from recommonmark.transform import AutoStructify
# At the bottom of conf.py
def setup(app):
    app.add_config_value('recommonmark_config', {
        # 'url_resolver': lambda url: github_doc_root + url,
        'auto_toc_tree_section': 'Contents',
    }, True)
    app.add_transform(AutoStructify)
# Prefix document path to section labels, otherwise autogenerated labels would look like 'heading'
# rather than 'path/to/file:heading'
autosectionlabel_prefix_document = True
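Since `sphinx.ext.autodoc` is enabled above, the module pages below are rendered from Python docstrings. A minimal, purely illustrative sketch (not part of this commit) of a docstring that autodoc would pick up:

```python
def relu(x):
    '''Apply the rectified linear unit, max(x, 0), elementwise.

    Example::

        >>> relu(-1.0)
        0.0
    '''
    return max(x, 0.0)
```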

doc/source/index.rst Normal file

@ -0,0 +1,40 @@
.. Jittor documentation master file, created by
   sphinx-quickstart on Mon May 18 23:05:53 2020.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to the Jittor Documentation
===================================

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   README.cn.md

.. toctree::
   :maxdepth: 2
   :caption: Module API:

   jittor
   jittor.nn
   jittor.models
   jittor.optim
   jittor.init
   jittor.contrib
   jittor.dataset
   jittor.transform

.. toctree::
   :maxdepth: 1
   :caption: Miscellaneous:

   todo

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

doc/source/jittor.contrib.md Normal file

@ -0,0 +1,10 @@
jittor.contrib
=====================

This is the API documentation for Jittor's contrib module. The code in this module is not yet fully mature and will keep improving in later iterations. You can access this module with `from jittor import contrib`.

```eval_rst
.. automodule:: jittor.contrib
   :members:
   :undoc-members:
```

doc/source/jittor.dataset.md Normal file

@ -0,0 +1,11 @@
jittor.dataset
=====================

This is the API documentation for Jittor's dataset module. You can access this module with `from jittor import dataset`.

```eval_rst
.. automodule:: jittor.dataset
   :imported-members:
   :members:
   :undoc-members:
```
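A short usage sketch of the module documented above; the `MNIST` class is part of this commit's dataset package, while the `set_attrs` batching interface is assumed from the `Dataset` base class:

```python
from jittor.dataset.mnist import MNIST

# Dataset subclasses handle batching and worker processes internally.
train_loader = MNIST(train=True).set_attrs(batch_size=64, shuffle=True)
for images, labels in train_loader:
    print(images.shape, labels.shape)
    break
```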

doc/source/jittor.init.md Normal file

@ -0,0 +1,10 @@
jittor.init
=====================

This is the API documentation for Jittor's parameter initialization module. You can access this module with `from jittor import init`.

```eval_rst
.. automodule:: jittor.init
   :members:
   :undoc-members:
```
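A hedged sketch using initializers that appear elsewhere in this commit (`relu_invariant_gauss`, `uniform`); the positional `(shape, dtype)` signature is inferred from how the old `make_var` helpers invoked them:

```python
from jittor import init

# Fan-out gaussian for a conv weight, uniform for a bias (assumed signatures).
w = init.relu_invariant_gauss([64, 3, 7, 7], "float32", mode="fan_out")
b = init.uniform([64], "float32", -0.1, 0.1)
print(w.shape, b.shape)
```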

doc/source/jittor.md Normal file

@ -0,0 +1,44 @@
jittor
=====================

## jittor

This is the API documentation for the main Jittor module. You can access it with `import jittor`.

```eval_rst
.. automodule:: jittor
   :members:
   :undoc-members:
```

## jittor.core

Below is Jittor's core API, which can be accessed directly as `jittor.core.XXX` or `jittor.XXX`.

```eval_rst
.. automodule:: jittor_core
   :imported-members:
   :members:
   :undoc-members:
```

## jittor.ops

This is the API documentation for Jittor's basic operator module. These APIs can be accessed directly as `jittor.ops.XXX` or `jittor.XXX`.

```eval_rst
.. automodule:: jittor_core.ops
   :members:
   :undoc-members:
```

## jittor.Var

This is the API documentation for Jittor's basic variable class. These APIs can be accessed directly as `my_jittor_var.XXX`.

```eval_rst
.. automodule:: jittor_core.Var
   :members:
   :undoc-members:
```
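A quick sketch of the three access paths described above: top-level `jittor.XXX`, the operator namespace `jittor.ops.XXX`, and `Var` methods:

```python
import jittor as jt

a = jt.array([1.0, 2.0, 3.0])  # jittor.XXX
b = jt.ops.random([3])         # jittor.ops.XXX, also aliased as jt.random
c = a.maximum(b)               # my_jittor_var.XXX, a Var method
print(c.data)
```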

doc/source/jittor.models.md Normal file

@ -0,0 +1,13 @@
jittor.models
=====================

This is the API documentation for Jittor's backbone network (models) module. You can access this module with `from jittor import models`.

```eval_rst
.. automodule:: jittor.models
   :members:
   :imported-members:
   :undoc-members:
```
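A hedged usage sketch; the `resnet18` constructor is assumed to exist in `jittor.models` and is not shown in this diff:

```python
import jittor as jt
from jittor import models

net = models.resnet18()          # assumed constructor name
x = jt.random([1, 3, 224, 224])  # dummy NCHW batch
y = net(x)
print(y.shape)
```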

doc/source/jittor.nn.md Normal file

@ -0,0 +1,24 @@
jittor.nn
=====================

This is the API documentation for Jittor's neural network module. You can access this module with `from jittor import nn`.

```eval_rst
.. automodule:: jittor.nn
   :members:
   :undoc-members:
.. automodule:: jittor.nn
   :imported-members:
   :members: Pool, pool, AdaptiveAvgPool2d
   :undoc-members:
.. autoclass:: jittor.nn.ReLU
   :members:
.. autoclass:: jittor.nn.ReLU6
   :members:
.. autoclass:: jittor.nn.LeakyReLU
   :members:
.. autoclass:: jittor.nn.Softmax
   :members:
```
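A small sketch combining layers documented above; the constructor arguments follow the `resnet_fake` helper rewritten later in this commit (`nn.Conv(3, 64, 7, 2, 3)` and friends):

```python
import jittor as jt
from jittor import nn

net = nn.Sequential(
    nn.Conv(3, 64, 7, 2, 3),  # channels in/out, kernel, stride, padding
    nn.BatchNorm(64),
    nn.ReLU(),
    nn.Pool(3, 2, 1),
)
x = jt.random([1, 3, 224, 224])
print(net(x).shape)
```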

doc/source/jittor.optim.md Normal file

@ -0,0 +1,10 @@
jittor.optim
=====================

This is the API documentation for Jittor's optimizer module. You can access this module with `from jittor import optim`.

```eval_rst
.. automodule:: jittor.optim
   :members:
   :undoc-members:
```
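A hedged training-step sketch; `optim.SGD` and its `step(loss)` interface are assumed from Jittor's usual optimizer API and are not part of this diff:

```python
import jittor as jt
from jittor import nn, optim

model = nn.Sequential(nn.Conv(3, 8, 3, 1, 1), nn.ReLU())
opt = optim.SGD(model.parameters(), lr=0.05)  # assumed constructor

x = jt.random([2, 3, 32, 32])
loss = (model(x) ** 2).mean()
opt.step(loss)  # assumed: backward + parameter update in one call
```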

doc/source/jittor.transform.md Normal file

@ -0,0 +1,10 @@
jittor.transform
=====================

This is the API documentation for Jittor's data transform module. You can access this module with `from jittor import transform`.

```eval_rst
.. automodule:: jittor.transform
   :members:
   :undoc-members:
```

doc/source/todo.md Normal file

@ -0,0 +1,12 @@
TODO
=====================

## Documentation

* Documentation style guidelines
* Add tutorial links to the documentation
* MPI interface documentation
* Automatic documentation updates
* Link from the project home page to the documentation
* Documentation for the model zoo (GAN, segmentation, detection, ...)
* Fill in missing documentation; add usage examples to the important classes

nccl_warper.h

@ -17,5 +17,6 @@ namespace jittor {
extern ncclComm_t comm;
extern ncclUniqueId id;
extern int nccl_device_id;
} // jittor

nccl_warper.cc

@ -6,6 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "misc/cuda_flags.h"
#include "nccl_warper.h"
#include "event_queue.h"
@ -17,25 +18,33 @@ namespace jittor {
ncclComm_t comm;
ncclUniqueId id;
int nccl_device_id = 0;
struct nccl_initer {
    nccl_initer() {
        if (!get_device_count()) return;
        int device_count = get_device_count();
        if (!device_count) return;
        if (!inside_mpi) return;
        if (mpi_world_rank == 0)
            checkCudaErrors(ncclGetUniqueId(&id));
        MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
        LOGv << "NCCL init in device" << mpi_local_rank;
        checkCudaErrors(cudaSetDevice(mpi_local_rank));
        if (mpi_local_rank >= device_count)
            LOGf << "mpi_local_rank(">>mpi_local_rank>>") is larger than device_count("
                >>device_count>>")";
        nccl_device_id = mpi_local_rank; // % device_count;
        LOGv << "NCCL init in device" << nccl_device_id << "local_rank" << mpi_local_rank;
        checkCudaErrors(cudaSetDevice(nccl_device_id));
        event_queue.run_sync([]() {
            checkCudaErrors(cudaSetDevice(mpi_local_rank));
            checkCudaErrors(cudaSetDevice(nccl_device_id));
        });
        checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
    }

    ~nccl_initer() {
        if (!get_device_count()) return;
        if (!inside_mpi) return;
        checkCudaErrors(ncclCommDestroy(comm));
    }

mpi_warper.h

@ -27,6 +27,7 @@ namespace jittor {
extern int mpi_world_size;
extern int mpi_world_rank;
extern int mpi_local_rank;
extern bool inside_mpi;
// @pyjt(world_size)
int _mpi_world_size();

mpi_warper.cc

@ -30,6 +30,7 @@ namespace jittor {
int mpi_world_size = 1;
int mpi_world_rank = 0;
int mpi_local_rank = 0;
bool inside_mpi = false;
int _mpi_world_size() {
    return mpi_world_size;
@ -73,6 +74,8 @@ static void getHostName(char* hostname, int maxlen) {

struct mpi_initer {
    mpi_initer() {
        inside_mpi = !!getenv("OMPI_COMM_WORLD_SIZE");
        if (!inside_mpi) return;
        LOGvv << "MPI init...";
        MPI_CHECK(MPI_Init(NULL, NULL));
        MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
@ -95,6 +98,7 @@ mpi_initer() {
    }

    ~mpi_initer() {
        if (!inside_mpi) return;
        MPI_CHECK(MPI_Finalize());
    }
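The initializer above keys everything off the `OMPI_COMM_WORLD_SIZE` environment variable that mpirun exports. A sketch of the same check in Python, roughly what the `inside_mpi()` helper used by `compile_extern.py` has to do:

```python
import os

def inside_mpi():
    # OpenMPI exports OMPI_COMM_WORLD_SIZE into every process it launches,
    # so its presence means we are running under mpirun.
    return "OMPI_COMM_WORLD_SIZE" in os.environ
```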

python/jittor/__init__.py

@ -17,6 +17,8 @@ with lock.lock_scope():
    from jittor_core.ops import *
    from . import compile_extern
    from .compile_extern import mkl_ops, mpi, mpi_ops
    if core.get_device_count() == 0:
        has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
    if has_cuda:
        from .compile_extern import cudnn, curand, cublas
@ -29,92 +31,6 @@ import pickle
import sys
import traceback
def dfs(scope, vars):
    for v in scope.children.values():
        if type(v) == Scope:
            dfs(v, vars)
        else:
            vars.append(v)

def dfs_records(scope, records):
    for v in scope.children.values():
        if type(v) == Scope:
            dfs_records(v, records)
    for v in scope.records.values():
        records.append(v)

class Scope:
    def __init__(self, parent=None, name=None):
        self.children = OrderedDict()
        self.index = {}
        self.records = OrderedDict()
        if name == None:
            self.name = self.full_name = ""
        else:
            self.name = name
            self.full_name = parent.full_name + name + "/"
    def get_scope(self, name, unique=True):
        if not unique:
            index = self.index.get(name, 0)
            self.index[name] = index+1
            name = name + f'_{index}'
        if name not in self.children:
            sub_scope = Scope(self, name)
            self.children[name] = sub_scope
        else:
            sub_scope = self.children[name]
            assert type(sub_scope) == Scope, f"Name {name} is a Var: {sub_scope}"
        return sub_scope
    def make_var(self, shape, dtype, init, name, unique):
        if not unique:
            index = self.index.get(name, 0)
            self.index[name] = index+1
            name = name + f'_{index}'
        if name in self.children:
            var = self.children[name]
            assert type(var) == core.Var, f"Name {name} exist: {var}"
            assert (shape is None or var.shape == shape) and var.dtype == dtype, f"Shape or dtype not match {var} != {dtype}{shape}"
            return var
        else:
            full_name = self.full_name + name
            if type(init) != core.Var:
                if callable(init):
                    var = init(shape, dtype)
                    if type(var) != core.Var:
                        var = array(var)
                else:
                    assert init != None
                    var = array(init)
            else:
                var = init
            var.stop_fuse()
            self.children[name] = var
            var.name(full_name)
            return var
    def clean_index(self): self.index.clear()
    def clean(self):
        self.children.clear()
        self.records.clear()
        self.index.clear()

current_scope = Scope()
root_scope = current_scope

class _call_record_scope:
    def __enter__(self): pass
    def __exit__(self, *exc): pass
    def __call__(self, func):
        def inner(*args, **kw):
            with self:
                ret = func(*args, **kw)
                record_in_scope(ret, "output")
            return ret
        return inner

class _call_no_record_scope:
    def __enter__(self): pass
    def __exit__(self, *exc): pass
@ -143,30 +59,6 @@ class flag_scope(_call_no_record_scope):
        for k,v in self.flags_bk.items():
            setattr(flags, k, v)

class var_scope(_call_record_scope):
    def __init__(self, name="scope", unique=False, **jt_flags):
        self.fs = flag_scope(**jt_flags)
        self.name = name
        self.unique = unique
    def __enter__(self):
        global current_scope
        self.prev = current_scope
        try:
            current_scope = current_scope.get_scope(self.name, self.unique)
            current_scope.clean_index()
            self.fs.__enter__()
        except:
            current_scope = self.prev
            del self.prev
            raise
    def __exit__(self, *exc):
        self.fs.__exit__(*exc)
        global current_scope
        current_scope = self.prev
        del self.prev

single_log_capture = None

class log_capture_scope(_call_no_record_scope):
@ -229,75 +121,7 @@ class profile_scope(_call_no_record_scope):
        profiler.stop()
        self.report.extend(profiler.report())

def make_var(shape=None, dtype="float32", init=None, name='var', unique=False):
    return current_scope.make_var(shape, dtype, init, name, unique)
def find_vars(path=None):
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        ns = path.split("/")
        if ns[-1] == "":
            ns.pop()
        for n in ns: scope = scope.children[n]
    if not isinstance(scope, Scope):
        return [scope]
    vars = []
    dfs(scope, vars)
    return vars

def find_var(path):
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        ns = path.split("/")
        for n in ns: scope = scope.children[n]
    assert not isinstance(scope, Scope)
    return scope

def find_records(path=None):
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        ns = path.split("/")
        if ns[-1] == "":
            ns.pop()
        for n in ns: scope = scope.children[n]
    assert isinstance(scope, Scope)
    records = []
    dfs_records(scope, records)
    return records

def find_record(path):
    scope = current_scope
    assert isinstance(path, str)
    ns = path.split("/")
    for n in ns[:-1]: scope = scope.children[n]
    assert isinstance(scope, Scope)
    return scope.records[ns[-1]]

def find_scope(path):
    scope = current_scope
    if path is not None:
        assert isinstance(path, str)
        ns = path.split("/")
        if ns[-1] == "":
            ns.pop()
        for n in ns: scope = scope.children[n]
    assert isinstance(scope, Scope)
    return scope

def record_in_scope(self, name):
    current_scope.records[name] = self
    if isinstance(self, Var):
        full_name = current_scope.full_name + name
        self.name(full_name)
    return self
Var.record_in_scope = record_in_scope

def clean():
    current_scope.clean()
    import gc
    # make sure python do a full collection
    gc.collect()
@ -411,11 +235,8 @@ def squeeze(x, dim):
Var.squeeze = squeeze

def clamp(x, min_v, max_v):
    # TODO: change to x.maximum(min_v).minimum(max_v)
    assert min_v <= max_v
    min_b = (x < min_v).int()
    max_b = (x > max_v).int()
    return x * (1 - min_b - max_b) + min_v * min_b + max_v * max_b
    return x.maximum(min_v).minimum(max_v)
Var.clamp = clamp
def type_as(a, b):
@ -456,32 +277,6 @@ def display_memory_info():
        fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
        core.display_memory_info(fileline)

def import_vars(data):
    ''' Load variables into current scopes

    example:
        import_vars({"w":[1.0,2.0,3.0]})
        jt.get_var([3], "float64", name="w", gen_index=False)
    '''
    for k in data:
        v = data[k]
        if type(v) != core.Var:
            v = array(v).stop_fuse()
        scopes = k.split("/")
        scope = current_scope
        for i in range(len(scopes)-1):
            scope = scope.get_scope(scopes[i])
        vname = scopes[-1]
        assert vname not in scope.children, f"Var {k} exists. Please load_vars at the beginning"
        v.name(k)
        scope.children[vname] = v

def export_vars():
    ''' Export all vars into a dictionary

    return: a dictionary, key is var name, value is numpy array
    '''
    data = { v.name():v.fetch_sync() for v in find_vars() }
    return data
def load(path):
    pkl_file = open(path, 'rb')
    model_dict = pickle.load(pkl_file)
@ -489,7 +284,7 @@ load(path):

class Module:
    def __init__(self, *args, **kw):
        __doc__ == 'doc'
        pass
    def execute(self, *args, **kw):
        pass
    def __call__(self, *args, **kw):
@ -498,8 +293,6 @@ class Module:
        return self.__str__()
    def _get_name(self):
        return self.__class__.__name__
    def __doc__(self):
        pass
    def __name__(self):
        pass
@ -670,16 +463,12 @@ def make_module(func, exec_n_args=1):
        def __init__(self, *args, **kw):
            self.args = args
            self.kw = kw
            self.__doc__ == 'doc'
        def execute(self, *args):
            return func(*args, *self.args, **self.kw)
        def __str__(self):
            return 'str'
        def __repr__(self):
            return self.__str__()
            return f"{func.__name__}({self.extra_repr()})"
        def extra_repr(self):
            return ''
            return ",".join(map(str, self.args))
    return MakeModule
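Among the changes above, `clamp` drops the mask arithmetic in favor of two elementwise ops; a standalone sketch of the new behavior:

```python
import jittor as jt

def clamp(x, min_v, max_v):
    assert min_v <= max_v
    # Elementwise min(max(x, min_v), max_v).
    return x.maximum(min_v).minimum(max_v)

print(clamp(jt.array([-2.0, 0.5, 3.0]), 0.0, 1.0).data)  # [0.  0.5 1. ]
```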

python/jittor/compile_extern.py

@ -6,7 +6,7 @@
import os, sys, shutil
from .compiler import *
from jittor_utils import run_cmd, get_version
from jittor.dataset.utils import download_url_to_local
from jittor.utils.misc import download_url_to_local
def search_file(dirs, name):
    for d in dirs:
@ -256,6 +256,11 @@ def install_nccl(root_folder):
        LOG.i("Downloading nccl...")
        download_url_to_local(url, filename, root_folder, true_md5)
    if core.get_device_count() == 0:
        return
    if not inside_mpi():
        return
    import tarfile
    with tarfile.open(fullname, "r") as tar:
        tar.extractall(root_folder)
@ -269,7 +274,7 @@ def setup_nccl():
    global nccl_ops, use_nccl
    use_nccl = os.environ.get("use_nccl", "1")=="1"
    nccl_ops = None
    if not has_cuda or mpi is None:
    if not has_cuda or not has_mpi:
        use_nccl = False
        return
    if not use_nccl: return
@ -284,6 +289,7 @@ def setup_nccl():
    make_cache_dir(nccl_path)
    nccl_home = install_nccl(nccl_path)
    if nccl_home is None: return

    nccl_include_path = os.path.join(nccl_home, "build", "include")
    nccl_lib_path = os.path.join(nccl_home, "build", "lib")
@ -343,8 +349,6 @@ def setup_mpi():
    else:
        use_mpi = True
        has_mpi = True
    if not inside_mpi():
        use_mpi = False
    if not use_mpi:
        return
@ -388,6 +392,8 @@ def setup_mpi():
        setattr(core.Var, k, warper(mpi_ops.__dict__[k]))

setup_mpi()
if not inside_mpi():
    mpi = None
setup_nccl()
setup_cutt()

python/jittor/contrib.py

@ -187,19 +187,3 @@ def setitem(x, slices, value):
jt.Var.__getitem__ = jt.Var.slice_var = slice_var
jt.Var.__setitem__ = setitem

def adam(model, loss, lr=3e-4, betas=[0.9, 0.999], eps=1e-8):
    ps = jt.find_vars(model)
    gs = jt.grad(loss, ps)
    with jt.var_scope('_'.join([model, 'adam']), unique=True):
        adam_step = jt.make_var([1], init=jt.zeros)
        adam_step += 1
        for p,g in zip(ps,gs):
            m = jt.make_var(p.shape, init=jt.zeros)
            v = jt.make_var(p.shape, init=jt.zeros)
            m.assign(betas[0] * m + (1-betas[0]) * g)
            v.assign(betas[1] * v + (1-betas[1]) * g * g)
            step_size = lr * jt.sqrt(1-betas[1]**adam_step) / (1-betas[0] ** adam_step)
            p -= m * step_size / (jt.sqrt(v) + eps)

python/jittor/dataset/__init__.py

@ -0,0 +1,4 @@
from .dataset import Dataset, ImageFolder
from .mnist import MNIST
from .voc import VOC

python/jittor/dataset/dataset.py

@ -23,7 +23,7 @@ import jittor as jt
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
mpi = jt.compile_extern.mpi
mpi = jt.mpi
class Worker:
    def __init__(self, target, args, buffer_size):

python/jittor/dataset/mnist.py

@ -12,7 +12,7 @@ import gzip
from PIL import Image
# our lib jittor import
from jittor.dataset.dataset import Dataset, dataset_root
from jittor.dataset.utils import ensure_dir, download_url_to_local
from jittor.utils.misc import ensure_dir, download_url_to_local
import jittor as jt
import jittor.transform as trans

python/jittor/dataset/utils.py

@ -8,74 +8,9 @@
# ***************************************************************
import jittor as jt
import os
from six.moves import urllib
import hashlib
from tqdm import tqdm
import numpy as np
from collections.abc import Sequence, Mapping
from PIL import Image
from .. import lock
def ensure_dir(dir_path):
    if not os.path.isdir(dir_path):
        os.makedirs(dir_path)

def _progress():
    pbar = tqdm(total=None)
    def bar_update(block_num, block_size, total_size):
        """ reporthook
        @block_num: the num of downloaded data block
        @block_size: the size of data block
        @total_size: the total size of remote file
        """
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = block_num * block_size
        pbar.update(progress_bytes - pbar.n)
    return bar_update

@lock.lock_scope()
def download_url_to_local(url, filename, root_folder, md5):
    ensure_dir(root_folder)
    file_path = os.path.join(root_folder, filename)
    if check_file_exist(file_path, md5):
        print("Data file has been downloaded and verified")
    else:
        try:
            print('Downloading ' + url + ' to ' + file_path)
            urllib.request.urlretrieve(
                url, file_path,
                reporthook=_progress()
            )
        except (urllib.error.URLError, IOError) as e:
            raise e
        if not check_file_exist(file_path, md5):
            raise RuntimeError("File downloads failed.")

def check_file_exist(file_path, md5):
    if not os.path.isfile(file_path):
        return False
    if md5 is None:
        return True
    return check_md5(file_path, md5)

def calculate_md5(file_path, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()

def check_md5(file_path, md5, **kwargs):
    return md5 == calculate_md5(file_path, **kwargs)

def get_random_list(n):
    return list(np.random.permutation(range(n)))

python/jittor/dataset/voc.py

@ -12,7 +12,7 @@ import os
from PIL import Image
import matplotlib.pyplot as plt
import cv2
from dataset import Dataset, dataset_root
from .dataset import Dataset, dataset_root
class VOC(Dataset):
    NUM_CLASSES = 21


@ -1,52 +0,0 @@
import jittor as jt
import numpy as np
import sys, os
f32 = jt.float32
@jt.var_scope('linear')
def linear(x, n):
w = jt.make_var([x.shape[-1], n], init=lambda *a:
(jt.random(*a)-f32(0.5)) / f32(x.shape[-1])**f32(0.5))
b = jt.make_var([n], init=lambda *a: jt.random(*a)-f32(0.5))
return jt.matmul(x, w) + b
def relu(x): return jt.maximum(x, f32(0))
@jt.var_scope('model', unique=True)
def model(x):
x = linear(x, 10)
x = relu(x)
x = linear(x, 1)
return x
np.random.seed(0)
jt.set_seed(3)
n = 1000
batch_size = 50
base_lr = 0.05
# we need to stop grad of global value to prevent memory leak
lr = f32(base_lr).name("lr").stop_grad()
def get_data(n):
for i in range(n):
x = np.random.rand(batch_size, 1)
y = x*x
yield np.float32(x), np.float32(y)
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x).name("pred_y")
loss = ((pred_y - y)**f32(2)).name("loss")
loss_mean = loss.mean()
ps = jt.find_vars('model')
gs = jt.grad(loss_mean, ps)
for p,g in zip(ps, gs):
p -= g * lr
if i>2:
assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
prev = jt.liveness_info()
print(f"step {i}, loss = {loss_mean().sum()}")
# result is 0.0009948202641680837
result = 0.0009948202641680837
assert abs(loss_mean.data - result) < 1e-6

python/jittor/nn.py

@ -43,38 +43,6 @@ jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
def get_init_var_rand(shape, dtype):
    return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))

@jt.var_scope('conv')
def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1, init_method=None):
    Kw = kernel_size
    Kh = kernel_size
    _C = in_planes
    Kc = out_planes
    N,C,H,W = x.shape
    assert C==_C
    if init_method==None:
        w = jt.make_var([Kc, _C, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
    else:
        w = jt.make_var([Kc, _C, Kh, Kw], init=init_method)
    xx = x.reindex([N,Kc,C,(H+padding*2-kernel_size)//stride+1,(W+padding*2-kernel_size)//stride+1,Kh,Kw], [
        'i0', # Nid
        'i2', # Cid
        f'i3*{stride}-{padding}+i5', # Hid+Khid
        f'i4*{stride}-{padding}+i6', # Wid+KWid
    ])
    ww = w.broadcast(xx.shape, [0,3,4])
    yy = xx*ww
    y = yy.sum([2,5,6]) # C, Kh, Kw
    return y

@jt.var_scope('linear')
def linear(x, n):
    w = jt.make_var([n, x.shape[-1]], init=lambda *a: init.invariant_uniform(*a))
    w = w.reindex([w.shape[1], w.shape[0]],["i1","i0"])
    bound = 1.0/math.sqrt(w.shape[0])
    b = jt.make_var([n], init=lambda *a: init.uniform(*a,-bound,bound))
    return jt.matmul(x, w) + b

def relu(x): return jt.maximum(x, 0)
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)


@ -30,6 +30,7 @@ def check(jt_model, torch_model, shape, near_data):
        data = list(range(8)) * int((shape[0] * shape[1] * shape[2] * shape[3]) / 8)
        random.shuffle(data)
        x = jt.array(data).float32().reshape(shape)
        x.data
    else:
        x = jt.random(shape)
    y = jt_model(x)


@ -12,7 +12,7 @@ import numpy as np
from jittor import compile_extern
from .test_log import find_log_with_re
import copy
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
else:
    cublas_ops = cudnn_ops = cub_ops = None


@ -11,7 +11,7 @@ import jittor as jt
import numpy as np
from jittor import compile_extern
from .test_log import find_log_with_re
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
else:
    cublas_ops = cudnn_ops = cub_ops = None

python/jittor/test/test_array.py

@ -30,7 +30,7 @@ class TestArray(unittest.TestCase):
        a.data = jt.array([7,8,9])
        assert (a.fetch_sync()==[7,8,9]).all()

    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    @jt.flag_scope(use_cuda=1)
    def test_memcopy_overlap(self):
        import time
@ -95,13 +95,13 @@ class TestArray(unittest.TestCase):
        with jt.flag_scope(use_cuda=1):
            assert (jt.array([1,2,3]).reshape((1,3)).data==[1,2,3]).all()

    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    def test_array_dual(self):
        with jt.flag_scope(use_cuda=1):
            a = jt.array(np.float32([1,2,3]))
            assert (a.data==[1,2,3]).all()

    @unittest.skipIf(not compile_extern.has_cuda, "Cuda not found")
    @unittest.skipIf(not jt.has_cuda, "Cuda not found")
    def test_array_migrate(self):
        with jt.flag_scope(use_cuda=1):
            a = jt.array(np.float32([1,2,3]))

python/jittor/test/test_compile_options.py

@ -18,7 +18,7 @@ class TestCompileOptions(unittest.TestCase):
        assert a.compile_options=={"compile_shapes":1}
        b = a+a
        assert b.compile_options=={}
        with jt.var_scope(compile_options={"compile_shapes":1}):
        with jt.flag_scope(compile_options={"compile_shapes":1}):
            c = a+b
        assert c.compile_options=={"compile_shapes":1}
        with jt.profile_scope() as report:


@ -14,7 +14,7 @@ from jittor import compile_extern
# TODO: compare with pytorch
from jittor.test.test_log import find_log_with_re
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops
else:
    cublas_ops = cudnn_ops = None
@ -28,10 +28,7 @@ def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride = 1, dilati
    assert C==_C
    if w_ is None:
        if init_method==None:
            w = jt.make_var([Kc, _C, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
        else:
            w = jt.make_var([Kc, _C, Kh, Kw], init=init_method)
        assert 0
    else:
        w = w_
    oh = (H-Kh*dilation+dilation-1+padding*2)//stride+1
@ -56,10 +53,7 @@ def conv_nhwc(x, in_planes, out_planes, kernel_size, padding, stride = 1, dilati
    assert C==_C
    if w_ is None:
        if init_method==None:
            w = jt.make_var([Kc, _C, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
        else:
            w = jt.make_var([Kc, _C, Kh, Kw], init=init_method)
        assert 0
    else:
        w = w_
    oh = (H-Kh*dilation+dilation-1+padding*2)//stride+1


@ -7,7 +7,7 @@ import unittest
import jittor as jt
import os
from jittor import compile_extern
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
else:
    cublas_ops = cudnn_ops = cub_ops = None


@ -10,7 +10,7 @@ import numpy as np
from jittor import compile_extern
from jittor.test.test_log import find_log_with_re
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops
else:
    cublas_ops = cudnn_ops = None


@ -12,7 +12,7 @@ import numpy as np
from jittor import compile_extern
from .test_log import find_log_with_re
import copy
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cutt_ops
else:
    cutt_ops = None


@ -11,7 +11,7 @@ from .test_grad import ngrad
from itertools import permutations
from jittor import compile_extern
from .test_log import find_log_with_re
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cutt_ops
else:
    cutt_ops = None

python/jittor/test/test_flags.py

@ -21,7 +21,7 @@ class TestFlags(unittest.TestCase):
    def test_scope(self):
        prev = jt.flags.log_v
        with jt.var_scope(log_v=1):
        with jt.flag_scope(log_v=1):
            assert jt.flags.log_v == 1
        assert jt.flags.log_v == prev


@ -32,7 +32,7 @@ def performance_test_scope(warmup=0, rerun=0, **args):
    jt.profiler.start(warmup, rerun)
    report = []
    try:
        with jt.var_scope(**args):
        with jt.flag_scope(**args):
            yield report
    finally:
        jt.profiler.stop()


@ -14,7 +14,7 @@ from jittor import compile_extern
# TODO: compare with pytorch
from jittor.test.test_log import find_log_with_re
if compile_extern.has_cuda:
if jt.has_cuda:
    from jittor.compile_extern import cublas_ops, cudnn_ops
else:
    cublas_ops = cudnn_ops = None
@ -34,10 +34,7 @@ def conv_nchw(x, in_planes, out_planes, kernel_size, padding, stride=1, dilation
    ow = (W+padding[1]*2-Kw*dilation[1]+dilation[1]-1)//stride[1]+1
    if w_ is None:
        if init_method==None:
            w = jt.make_var([oc, C // G, Kh, Kw], init=lambda *a: init.relu_invariant_gauss(*a, mode="fan_out"))
        else:
            w = jt.make_var([oc, C // G, Kh, Kw], init=init_method)
        assert 0
    else:
        w = w_

python/jittor/test/test_log.py

@ -29,7 +29,7 @@ def find_log_with_re(logs, pattern=None, **args):
class TestLog(unittest.TestCase):
    def test_log_capture(self):
        LOG.log_capture_start()
        with jt.var_scope(log_v=1000, log_vprefix=""):
        with jt.flag_scope(log_v=1000, log_vprefix=""):
            LOG.v("1")
            LOG.vv("2")
            LOG.i("3")

python/jittor/test/test_longest_dis_fuse.py

@ -16,16 +16,6 @@ import numpy as np
def get_init_var(shape, dtype):
    return jt.random(shape, dtype)

def batch_norm(x):
    xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
    x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
    norm_x = (x-xmean.broadcast_var(x))/(jt.sqrt(x2mean-xmean*xmean+jt.float32(1e-5)).broadcast_var(x))
    w = jt.make_var([x.shape[1]], init=get_init_var)
    b = jt.make_var([x.shape[1]], init=get_init_var)
    w = w.broadcast([1, w.shape[0],1,1], [0,2,3])
    b = b.broadcast([1, b.shape[0],1,1], [0,2,3])
    return norm_x * w + b

def pool(x, size, op, padding, stride = 1): # TODO: stride, padding
    N,C,H,W = x.shape
    h = (H+padding*2-size)//stride+1
@ -43,41 +33,25 @@ def pool(x, size, op, padding, stride = 1): # TODO: stride, padding
        "i3", # Wid
    ])

def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1):
    Kw = kernel_size
    Kh = kernel_size
    _C = in_planes
    Kc = out_planes
    N,C,H,W = x.shape
    assert C==_C
    w = jt.make_var([Kc, _C, Kh, Kw], init=get_init_var)
    xx = x.reindex([N,Kc,C,(H+padding*2-kernel_size)//stride+1,(W+padding*2-kernel_size)//stride+1,Kh,Kw], [
        'i0', # Nid
        'i2', # Cid
        f'i3*{stride}-{padding}+i5', # Hid+Khid
        f'i4*{stride}-{padding}+i6', # Wid+KWid
    ])
    ww = w.broadcast(xx.shape, [0,3,4])
    yy = xx*ww
    y = yy.sum([2,5,6]) # Kc, Kh, Kw
    return y

def relu(x): return jt.maximum(x, jt.float32(0))

@jt.var_scope('resnet_fake', unique=True)
def resnet_fake(x):
    x = conv(x, 3, 64, 7, 3, 2)
    x = batch_norm(x)
    x = relu(x)
    x = pool(x, 3, "maximum", 1, 2)
    return x
def resnet_fake():
    from jittor import nn
    net = nn.Sequential(
        nn.Conv(3, 64, 7, 2, 3),
        nn.BatchNorm(64),
        nn.ReLU(),
        nn.Pool(3, 2, 1)
    )
    return net

class TestLongestDisFuse(unittest.TestCase):
    def test_longest_dis_fuse(self):
        x = jt.array(np.random.rand(1,3,224,224).astype(np.float32))
        loss = jt.sum(resnet_fake(x))
        ps = jt.find_vars('resnet_fake')
        net = resnet_fake()
        loss = jt.sum(net(x))
        ps = net.parameters()
        gs = jt.grad(loss, ps)
        jt.sync(gs)
        # assert not alloc big tensor


@ -13,13 +13,6 @@ from .test_log import find_log_with_re
f32 = jt.float32
from jittor import nn, Module

@jt.var_scope('linear')
def linear(x, n):
    w = jt.make_var([x.shape[-1], n], init=lambda *a:
        (jt.random(*a)-f32(0.5)) / f32(x.shape[-1])**f32(0.5))
    b = jt.make_var([n], init=lambda *a: jt.random(*a)-f32(0.5))
    return jt.nn.matmul(x, w) + b

def relu(x): return jt.maximum(x, f32(0))

class Model(Module):

python/jittor/test/test_single_array.py

@ -87,8 +87,7 @@ class TestSingleArray(unittest.TestCase):
        jt.set_seed(3)
        x = f32(np.random.rand(1, 1))
        w = jt.make_var([x.shape[-1], 10], init=lambda *a:
            (jt.random(*a)-f32(0.5)) / f32(x.shape[-1])**f32(0.5))
        w = (jt.random([x.shape[-1], 10])-f32(0.5)) / f32(x.shape[-1])**f32(0.5)
        jt.nn.matmul(x, w).data

    def test4(self):

python/jittor/test/test_mkl_conv_op.py

@ -59,7 +59,7 @@ class TestMklConvOp(unittest.TestCase):
        a_jt = jt.array(a)
        b_jt = jt.array(b)
        with jt.var_scope(enable_tuner=0,compile_options={"test_mkl_conv":1}):
        with jt.flag_scope(enable_tuner=0,compile_options={"test_mkl_conv":1}):
            c_jt = conv(a_jt, b_jt, 3, 2).data
        with jt.log_capture_scope(
            enable_tuner=1,
@ -84,7 +84,7 @@ class TestMklConvOp(unittest.TestCase):
        a_jt = jt.array(a)
        b_jt = jt.array(b)
        with jt.var_scope(enable_tuner=0,
        with jt.flag_scope(enable_tuner=0,
                compile_options={"test_mkl_conv":uid[0]}):
            c_jt = conv_nhwc_hwio(a_jt, b_jt, stride, pad).data
        with jt.log_capture_scope(
@ -118,7 +118,7 @@ class TestMklConvOp(unittest.TestCase):
        a_jt = jt.array(a)
        b_jt = jt.array(b)
        with jt.var_scope(
        with jt.flag_scope(
                enable_tuner=0,
                compile_options={"test_mkl_conv":1}
        ):
@ -164,7 +164,7 @@ class TestMklConvOp(unittest.TestCase):
        a_jt = jt.array(a)
        b_jt = jt.array(b)
        with jt.var_scope(
        with jt.flag_scope(
                enable_tuner=0,
                compile_options={"test_mkl_conv":1}
        ):

python/jittor/test/test_mpi_batchnorm.py

@ -15,6 +15,8 @@ import numpy as np
from jittor.test.test_mpi import run_mpi_test
mpi = jt.compile_extern.mpi
if mpi:
    n = mpi.world_size()

class FakeMpiBatchNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
@ -57,7 +59,8 @@ class TestMpiBatchnorm(unittest.TestCase):
        mpi = jt.compile_extern.mpi
        data = np.random.rand(30,3,10,10).astype("float32")
        x1 = jt.array(data)
        x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
        stride = 30//n
        x2 = jt.array(data[mpi.world_rank()*stride:(mpi.world_rank()+1)*stride,...])
        bn1 = nn.BatchNorm(3, sync=False)
        bn2 = nn.BatchNorm(3, sync=True)
@ -75,7 +78,8 @@ class TestMpiBatchnorm(unittest.TestCase):
        mpi = jt.compile_extern.mpi
        data = np.random.rand(30,3,10,10).astype("float32")
        global_x = jt.array(data)
        x = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
        stride = 30//n
        x = jt.array(data[mpi.world_rank()*stride:(mpi.world_rank()+1)*stride,...])
        bn1 = nn.BatchNorm(3, sync=True)
        bn2 = FakeMpiBatchNorm(3)
@ -98,7 +102,7 @@ class TestMpiBatchnorm(unittest.TestCase):
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiBatchnormEntry(unittest.TestCase):
    def test(self):
        run_mpi_test(3, "test_mpi_batchnorm")
        run_mpi_test(2, "test_mpi_batchnorm")
if __name__ == "__main__":
unittest.main()
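The fix above derives each rank's slice from the world size instead of a hard-coded stride of 10. The sharding pattern in isolation:

```python
import numpy as np

def shard(data, rank, world_size):
    # Same arithmetic as the test: stride = 30 // n.
    stride = data.shape[0] // world_size
    return data[rank * stride:(rank + 1) * stride]

data = np.random.rand(30, 3, 10, 10).astype("float32")
print(shard(data, rank=1, world_size=2).shape)  # (15, 3, 10, 10)
```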

python/jittor/test/test_mpi_op.py

@ -13,6 +13,8 @@ import numpy as np
from jittor.test.test_mpi import run_mpi_test
mpi = jt.compile_extern.mpi
if mpi:
    n = mpi.world_size()

@unittest.skipIf(mpi is None, "no inside mpirun")
class TestMpiOps(unittest.TestCase):
@ -24,9 +26,9 @@ class TestMpiOps(unittest.TestCase):
    def test_all_reduce(self):
        x = jt.random([5, 5])
        y = x.mpi_all_reduce()
        assert np.allclose(y.data, (x*3).data)
        assert np.allclose(y.data, (x*n).data)
        g = jt.grad(y,x)
        assert np.allclose(g.data, np.ones([5,5])*3)
        assert np.allclose(g.data, np.ones([5,5])*n)

    def test_all_reduce_mean(self):
        x = jt.random([5, 5])
@ -45,7 +47,7 @@ class TestMpiOps(unittest.TestCase):
        assert np.allclose(y.data, data.data)
        g = jt.grad(y,x)
        if mpi.world_rank() == 0:
            assert np.allclose(g.data, np.ones([5,5])*3)
            assert np.allclose(g.data, np.ones([5,5])*n)
        else:
            assert np.allclose(g.data, np.zeros([5,5]))
@ -54,7 +56,7 @@ class TestMpiOps(unittest.TestCase):
        y = x.mpi_reduce(root=0)
        y.sync()
        if mpi.world_rank() == 0:
            assert np.allclose(y.data, (x*3).data)
            assert np.allclose(y.data, (x*n).data)
        else:
            assert np.allclose(y.data, np.zeros([5,5]))
        g = jt.grad(y,x)
@ -64,7 +66,7 @@ class TestMpiOps(unittest.TestCase):
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiOpsEntry(unittest.TestCase):
    def test(self):
        run_mpi_test(3, "test_mpi_op")
        run_mpi_test(2, "test_mpi_op")
if __name__ == "__main__":
unittest.main()
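The assertions now scale with the world size n: all-reducing identical tensors across n ranks yields x*n, and the gradient reaching each rank is a tensor of n's. A numpy sketch of that expectation:

```python
import numpy as np

n = 2                      # world size passed to run_mpi_test above
x = np.random.rand(5, 5)
ranks = [x] * n            # every rank holds the same data in this test
y = np.sum(ranks, axis=0)  # what mpi_all_reduce computes
assert np.allclose(y, x * n)
```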

python/jittor/test/test_scope.py

@ -1,105 +0,0 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from .test_core import expect_error
ops = jt.ops

@jt.var_scope('linear')
def linear(x, n):
    w = jt.make_var([x.shape[-1], n], init=ops.random)
    return jt.matmul(x, w)

@jt.var_scope('model', unique=True)
def model(x):
    x = linear(x, 10)
    # x = relu(x)
    x = linear(x, 10)
    # x = relu(x)
    x = linear(x, 1)
    return x

class TestScope(unittest.TestCase):
    def test_name(self):
        jt.clean()
        @jt.var_scope('model', unique=True)
        def model():
            with jt.var_scope('a'):
                assert jt.current_scope.full_name == "model/a_0/"
            with jt.var_scope('b'):
                with jt.var_scope('b'):
                    assert jt.current_scope.full_name == "model/b_0/b_0/"
            with jt.var_scope('c'):
                assert jt.current_scope.full_name == "model/c_0/"
        model()
        model()
        model()
        jt.clean()

    def test_var(self):
        jt.clean()
        for i in range(2):
            x = jt.array([[1]])
            y = model(x)
            params = jt.find_vars()
            assert len(params) == 3
            names = [ p.name() for p in params ]
            assert names == [
                "model/linear_0/var_0",
                "model/linear_1/var_0",
                "model/linear_2/var_0",
            ], str(names)
        jt.find_var("model/linear_0/var_0")
        expect_error(lambda: jt.find_var("model/linear_0"))
        expect_error(lambda: jt.find_var("model/linear"))
        assert len(jt.find_vars("model/linear_0/var_0")) == 1
        assert len(jt.find_vars("model/linear_0/")) == 1
        assert len(jt.find_vars("model/")) == 3
        jt.clean()

    def test_get_var_unique(self):
        jt.clean()
        x = jt.make_var([1], init=ops.random)
        y = jt.make_var([1], init=ops.random)
        z = jt.make_var([1], init=ops.random)
        assert x.name() == "var_0"
        assert y.name() == "var_1", y.name()
        assert z.name() == "var_2"
        x = jt.make_var([1], name="x", unique=True, init=ops.random)
        y = jt.make_var([1], name="y", unique=True, init=ops.random)
        z = jt.make_var([1], name="z", unique=True, init=ops.random)
        assert x.name() == "x"
        assert y.name() == "y"
        assert z.name() == "z"
        expect_error(lambda: jt.make_var([2], name="x", unique=True, init=ops.random))
        jt.clean()

    def test_record_scope(self):
        jt.clean()
        @jt.var_scope("func")
        def func(a):
            b = a+1
            jt.record_in_scope(b, "b")
            c = b*2
            return c
        a = jt.array([1,2,3])
        func(a)
        assert np.allclose(jt.find_record("func_0/output").data, (a.data+1)*2)
        assert np.allclose(jt.find_scope("func_0").records["output"].data, (a.data+1)*2)
        recs = jt.find_records()
        rec_names = [ r.name() for r in recs ]
        assert len(recs)==2 and rec_names==["func_0/b","func_0/output"]

    def test_get_var_init(self):
        jt.clean()
        assert (jt.make_var(init=[1,2,3]).data == [1,2,3]).all()
        assert (jt.make_var(shape=[3], init=np.zeros).data == [0,0,0]).all()
        assert (jt.make_var(init=jt.array([1,2,3]) == [1,2,3]).data).all()
        jt.clean()

if __name__ == "__main__":
    unittest.main()

python/jittor/test/test_tracer.py

@ -12,7 +12,7 @@ class TestTracer(unittest.TestCase):
        # force use addr2line
        jt.flags.gdb_path = ""
        with jt.var_scope(gdb_path=""):
        with jt.flag_scope(gdb_path=""):
            jt.print_trace()

python/jittor/utils/misc.py

@ -0,0 +1,73 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Meng-Hao Guo <guomenghao1997@gmail.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import os
from six.moves import urllib
import hashlib
from tqdm import tqdm
from .. import lock
def ensure_dir(dir_path):
    if not os.path.isdir(dir_path):
        os.makedirs(dir_path)

def _progress():
    pbar = tqdm(total=None)
    def bar_update(block_num, block_size, total_size):
        """ reporthook
        @block_num: the num of downloaded data block
        @block_size: the size of data block
        @total_size: the total size of remote file
        """
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = block_num * block_size
        pbar.update(progress_bytes - pbar.n)
    return bar_update

@lock.lock_scope()
def download_url_to_local(url, filename, root_folder, md5):
    ensure_dir(root_folder)
    file_path = os.path.join(root_folder, filename)
    if check_file_exist(file_path, md5):
        print("Data file has been downloaded and verified")
    else:
        try:
            print('Downloading ' + url + ' to ' + file_path)
            urllib.request.urlretrieve(
                url, file_path,
                reporthook=_progress()
            )
        except (urllib.error.URLError, IOError) as e:
            raise e
        if not check_file_exist(file_path, md5):
            raise RuntimeError("File downloads failed.")

def check_file_exist(file_path, md5):
    if not os.path.isfile(file_path):
        return False
    if md5 is None:
        return True
    return check_md5(file_path, md5)

def calculate_md5(file_path, chunk_size=1024 * 1024):
    md5 = hashlib.md5()
    with open(file_path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()

def check_md5(file_path, md5, **kwargs):
    return md5 == calculate_md5(file_path, **kwargs)
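A hedged usage sketch of the relocated helper; the URL and folder here are placeholders, and `md5=None` skips checksum verification as the code above allows:

```python
from jittor.utils.misc import download_url_to_local

download_url_to_local(
    url="https://example.com/data/mnist.tgz",  # placeholder URL
    filename="mnist.tgz",
    root_folder="/tmp/jittor_dataset",
    md5=None,  # accepted above: an existing file passes without a checksum
)
```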

src/init.cc

@ -23,7 +23,7 @@ int current_seed;
static void init_cuda_devices() {
#ifdef HAS_CUDA
    int count;
    int count=0;
    cudaGetDeviceCount(&count);
    for (int i=0; i<count; i++) {
        cudaDeviceProp devProp;

src/op.cc

@ -91,7 +91,14 @@ void Op::do_jit_prepare() {
    memcheck_all_exist();
    jk << name();
    jit_prepare();
    if (!jk.empty()) {
    if (jk.empty()) {
        // not a jit op
        bool has_cuda = flags.get(NodeFlags::_cuda);
        bool has_cpu = flags.get(NodeFlags::_cpu);
        CHECK(has_cuda || has_cpu);
        if (has_cuda && has_cpu && !use_cuda)
            flags.set(NodeFlags::_cuda, 0);
    } else {
        // check use int64_t as index_t if array is too big
        int in_id=0, out_id=0;
        bool use_int64_t = false;