data gz and register_hook and requires_grads_

This commit is contained in:
Dun Liang 2021-07-26 20:58:15 +08:00
parent 0cf77ea11c
commit ee020b60f7
21 changed files with 123 additions and 50 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.79'
__version__ = '1.2.3.80'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -827,6 +827,11 @@ class Module:
self.dfs([], "", callback, callback_leave)
return ms
def requires_grad_(self, requires_grad=True):
self._requires_grad = requires_grad
self._place_hooker()
return self
def __hooked_call__(self, *args, **kw):
if hasattr(self, "__fhook2__"):
if len(kw):
@ -837,7 +842,11 @@ class Module:
if len(kw):
LOG.w("backward hook not support kw")
args = grad_hooker(args, self.__bihook__)
ret = self.__hooked_call__(*args, **kw)
if hasattr(self, "_requires_grad") and not self._requires_grad:
with jt.no_grad():
ret = self.__hooked_call__(*args, **kw)
else:
ret = self.__hooked_call__(*args, **kw)
if hasattr(self, "__bohook__"):
if len(kw):
LOG.w("backward hook not support kw")
@ -1204,6 +1213,29 @@ def grad_hooker(args, hook):
hooker = GradHooker(hook)
return hooker(*args)
def register_hook(v, hook):
""" register hook of any jittor Variables, if hook return not None,
the gradient of this variable will be alter, Example::
x = jt.array([0.0, 0.0])
y = x * [1,2]
y.register_hook(lambda g: g*2)
dx = jt.grad(y, x)
print(dx)
# will be [2, 4]
"""
def _hook(grads):
g = hook(grads[0])
if g is not None:
return (g,)
return None
hooker = GradHooker(_hook)
v.swap(hooker(v)[0])
return v
Var.register_hook = register_hook
def make_module(func, exec_n_args=1):
class MakeModule(Module):
def __init__(self, *args, **kw):

View File

@ -23,6 +23,7 @@ from . import pyjt_compiler
from jittor_utils import lock
from jittor_utils import install_cuda
from jittor import __version__
import hashlib
def find_jittor_path():
return os.path.dirname(__file__)
@ -641,7 +642,6 @@ def compile_custom_ops(
if gen_name_ != "":
gen_name = gen_name_
if len(gen_name) > 100:
import hashlib
gen_name = gen_name[:80] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()
includes = sorted(list(set(includes)))
@ -1078,30 +1078,38 @@ os_type = {
"macos": "macos",
}
version_file = os.path.join(jittor_path, "version")
if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path, "src", "__data__")):
with open(version_file, 'r') as f:
version = f.read().strip()
# key = f"{version}-{cc_type}-{'cuda' if has_cuda else 'cpu'}.o"
key = f"{version}-g++-cpu"
os_id = os_release["ID"]
os_key = os_type.get(os_id, "ubuntu")
os_key += '-' + os_arch if os_arch else ''
if platform.machine()=='aarch64':
os_key += '-aarch64'
if platform.machine()=='sw_64':
os_key += '-sw_64'
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
if "os_key" in os.environ:
os_key = os.environ['os_key']
LOG.i("OS type:", os_id, " OS key:", os_key)
key += '-' + os_key + '.o'
# TODO: open the website
extra_obj = os.path.join(cache_path, key)
url = os.path.join("https://cg.cs.tsinghua.edu.cn/jittor/assets/build/"+key)
jit_utils.download(url, extra_obj)
files.append(extra_obj)
if platform.machine()=='sw_64':
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
data_gz_path = os.path.join(jittor_path, "utils", "data.gz")
use_data_gz = os.path.isfile(data_gz_path)
if os.environ.get("use_data_gz", "1") == "0":
use_data_gz = False
if use_data_gz:
import gzip
with gzip.open(data_gz_path, 'rb') as f:
data = f.read()
md5 = hashlib.md5(data).hexdigest()
target_md5 = None
data_gz_md5_path = os.path.join(cache_path, "data.md5")
if os.path.isfile(data_gz_md5_path):
with open(data_gz_md5_path, 'r') as f:
target_md5 = f.read()
data_o_path = os.path.join(cache_path, "data.o")
if target_md5 != md5:
data_s_path = os.path.join(cache_path, "data.cc")
with open(data_s_path, "w") as f:
f.write(data.decode("utf8"))
dflags = (cc_flags+opt_flags)\
.replace("-Wall", "") \
.replace("-Werror", "")
run_cmd(f"{cc_path} {dflags} -D_P\\(...\\)= {data_s_path} -c -o {data_o_path}")
os.remove(data_s_path)
with open(data_gz_md5_path, 'w') as f:
f.write(md5)
files.append(data_o_path)
files = [f for f in files if "__data__" not in f]
compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)

View File

@ -1,5 +1,5 @@
from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset
from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset, VarDataset
from .mnist import MNIST
from .cifar import CIFAR10, CIFAR100
from .voc import VOC

View File

@ -11,6 +11,7 @@
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_emu.h"
#include "common.h"

View File

@ -58,6 +58,7 @@
// CUDA and CUBLAS functions
#include <helper_functions.h>
#include "utils/log.h"
#include "helper_cuda.h"
#ifndef min

View File

@ -8,6 +8,7 @@
#include <cuda_runtime.h>
#include <cudnn.h>
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_emu.h"
#include "common.h"

View File

@ -65,6 +65,7 @@
#include <assert.h>
#include <cudnn.h>
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_dev.h"
#include "fp16_emu.h"

View File

@ -1,4 +1,5 @@
#include <cudnn.h>
#include "utils/log.h"
#include "helper_cuda.h"
const char *_cudaGetErrorEnum(cudnnStatus_t error) {

View File

@ -11,6 +11,7 @@
#include <cuda_runtime.h>
#include <cublas.h>
#include "utils/log.h"
#include "helper_cuda.h"
#include <curand.h>

View File

@ -12,6 +12,7 @@
#include <cuda_runtime.h>
#include <nccl.h>
#include "utils/log.h"
#include "helper_cuda.h"
namespace jittor {

View File

@ -9,6 +9,7 @@
*
*/
#include "utils/log.h"
#include "helper_cuda.h"
#include "fp16_dev.h"

View File

@ -13,6 +13,7 @@
// These are CUDA Helper functions for initialization and error checking
#include <cuda_runtime.h>
#include "utils/log.h"
#include "helper_cuda.h"
#ifdef _CUFFT_H_

View File

@ -778,7 +778,7 @@ void simple_net(int times = 100) {
s.wait();
}
// extern "C" int mkl_test_entry();
// extern int mkl_test_entry();
int mkl_test_entry() {
try {

View File

@ -15,25 +15,25 @@ def eye(shape, dtype):
return jt.array(np.identity(shape[0])).unary(dtype)
def eye_(var):
var.assign(eye(var.shape, var.dtype))
return var.assign(eye(var.shape, var.dtype))
def constant(shape, dtype, value=0.0):
return jt.array(value).unary(dtype).broadcast(shape)
def constant_(var, value=0.0):
var.assign(constant(var.shape, var.dtype, value))
return var.assign(constant(var.shape, var.dtype, value))
def uniform(shape, dtype, low, high):
return jt.random(shape, dtype) * (low - high) + high
def uniform_(var, low, high):
var.assign(uniform(var.shape, var.dtype, low, high))
return var.assign(uniform(var.shape, var.dtype, low, high))
def gauss(shape, dtype, mean=0.0, std=1.0):
return jt.random(shape, dtype, "normal") * std + mean
def gauss_(var, mean=0.0, std=1.0):
var.assign(gauss(var.shape, var.dtype, mean, std))
return var.assign(gauss(var.shape, var.dtype, mean, std))
def invariant_uniform(shape, dtype, mode="fan_in"):
assert len(shape)>1
@ -61,7 +61,7 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"):
return gauss(shape, dtype, 0, std)
def relu_invariant_gauss_(var, mode="fan_in"):
var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
return var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
def calculate_std(var,mode,nonlinearity,param=0.01):
mode = mode.lower()
@ -112,7 +112,7 @@ def xavier_uniform(shape, dtype, gain=1.0):
return uniform(shape, dtype, -bound, bound)
def xavier_uniform_(var, gain=1.0):
var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain))
return var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain))
def xavier_gauss(shape, dtype, gain=1.0):
assert len(shape)>1
@ -125,4 +125,4 @@ def xavier_gauss(shape, dtype, gain=1.0):
return gauss(shape, dtype, 0, std)
def xavier_gauss_(var, gain=1.0):
var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
return var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))

View File

@ -10,7 +10,7 @@
#include <functional>
#include "utils/log.h"
#define JIT_TEST(name) extern "C" void jit_test_ ## name ()
#define JIT_TEST(name) extern void jit_test_ ## name ()
void expect_error(std::function<void()> func);
#define VAR_MEMBER_NAME_AND_OFFSET(name, op) { #name , offsetof(struct op, name) }

View File

@ -33,7 +33,7 @@ int get_seed();
void add_set_seed_callback(set_seed_callback callback);
extern "C"
extern
std::default_random_engine* get_random_engine();
// things need to be clean before python exit

View File

@ -13,7 +13,7 @@
namespace jittor {
DECLARE_FLAG(int, para_opt_level);
DEFINE_FLAG(int, para_opt_level, 3, "para_opt_level");
void LoopVarAnalyzePass::run() {
// loop_vars: opi_xx->shape[j]

View File

@ -34,9 +34,9 @@ constexpr int32_t basename_index(const char * const path, const int32_t index =
#define PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0))
extern "C" uint32_t get_tid();
extern "C" bool g_supports_color;
extern "C" void print_prefix(std::ostream* out);
extern uint32_t get_tid();
extern bool g_supports_color;
extern void print_prefix(std::ostream* out);
constexpr char green[] = "\033[38;5;2m";
constexpr char red[] = "\033[38;5;1m";
@ -58,10 +58,10 @@ static void get_color(char level, int verbose, const char*& color_begin, const c
color_end = "\033[m";
}
extern "C" void send_log(std::ostringstream&& out);
extern "C" void flush_log();
extern "C" void log_capture_start();
extern "C" void log_capture_stop();
extern void send_log(std::ostringstream&& out);
extern void flush_log();
extern void log_capture_start();
extern void log_capture_stop();
extern std::vector<std::map<string,string>> log_capture_read();
extern string thread_local thread_name;
@ -145,9 +145,10 @@ template<class T> T get_from_env(const char* name,const T& _default) {
template<> std::string get_from_env(const char* name, const std::string& _default);
#define DECLARE_FLAG(type, name) \
extern "C" type name; \
extern "C" std::string doc_ ## name; \
extern "C" void set_ ## name (const type&);
extern type name; \
extern std::string doc_ ## name; \
extern void set_ ## name (const type&);
#ifdef JIT

View File

@ -42,7 +42,7 @@ class TestAtomicTunerClass(unittest.TestCase):
x=jt.random([100,64,128,128])
with jt.log_capture_scope(
# log_silent=1,
log_v=0, log_vprefix="atomic_tuner_pass.cc=100",
log_v=0, log_vprefix="atomic=100,data=100",
) as logs:
y=model(x).numpy()
with jt.log_capture_scope(

View File

@ -46,7 +46,30 @@ class TestHook(unittest.TestCase):
assert hooked
np.testing.assert_allclose(dx.numpy(), [-1.0, -2.0])
def test_register_hook(self):
x = jt.array([0.0, 0.0])
y = x * [1,2]
y.register_hook(lambda g: g*2)
dx = jt.grad(y, x)
np.testing.assert_allclose(dx.data, [2,4])
def test_requires_grads_(self):
class Mod(jt.nn.Module):
def execute(self, x):
return x*2
x = jt.random((100,))
mod = Mod()
mod.requires_grad_(True)
y = mod(x)
y = y*10
dx = jt.grad(y, x)
np.testing.assert_allclose(dx.data, 20)
mod.requires_grad_(False)
y = mod(x)
y = y*10
dx = jt.grad(y, x)
np.testing.assert_allclose(dx.data, 0)
if __name__ == "__main__":
unittest.main()

BIN
python/jittor/utils/data.gz Normal file

Binary file not shown.