mirror of https://github.com/Jittor/Jittor
data gz and register_hook and requires_grads_
parent 0cf77ea11c
commit ee020b60f7
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.2.3.79'
+__version__ = '1.2.3.80'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int

@@ -827,6 +827,11 @@ class Module:
         self.dfs([], "", callback, callback_leave)
         return ms

+    def requires_grad_(self, requires_grad=True):
+        self._requires_grad = requires_grad
+        self._place_hooker()
+        return self
+
     def __hooked_call__(self, *args, **kw):
         if hasattr(self, "__fhook2__"):
             if len(kw):

@@ -837,7 +842,11 @@ class Module:
             if len(kw):
                 LOG.w("backward hook not support kw")
             args = grad_hooker(args, self.__bihook__)
-        ret = self.__hooked_call__(*args, **kw)
+        if hasattr(self, "_requires_grad") and not self._requires_grad:
+            with jt.no_grad():
+                ret = self.__hooked_call__(*args, **kw)
+        else:
+            ret = self.__hooked_call__(*args, **kw)
         if hasattr(self, "__bohook__"):
             if len(kw):
                 LOG.w("backward hook not support kw")

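Note: the two hunks above add Module.requires_grad_ and make the hooked call honor it; when _requires_grad is False the forward pass runs under jt.no_grad(), so no gradient flows through the module. A minimal usage sketch (the module and variable names are illustrative, not part of the commit):

    import jittor as jt

    class Double(jt.nn.Module):
        def execute(self, x):
            return x * 2

    mod = Double()
    mod.requires_grad_(False)      # forward now runs under jt.no_grad()
    x = jt.random((3,))
    dx = jt.grad(mod(x) * 10, x)   # all zeros: gradients are blocked
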
@@ -1204,6 +1213,29 @@ def grad_hooker(args, hook):
    hooker = GradHooker(hook)
    return hooker(*args)

+def register_hook(v, hook):
+    """ Register a hook on any jittor variable. If the hook returns a
+    value other than None, that value replaces the gradient. Example::
+
+        x = jt.array([0.0, 0.0])
+        y = x * [1,2]
+        y.register_hook(lambda g: g*2)
+        dx = jt.grad(y, x)
+        print(dx)
+        # will be [2, 4]
+
+    """
+    def _hook(grads):
+        g = hook(grads[0])
+        if g is not None:
+            return (g,)
+        return None
+    hooker = GradHooker(_hook)
+    v.swap(hooker(v)[0])
+    return v
+
+Var.register_hook = register_hook
+
 def make_module(func, exec_n_args=1):
     class MakeModule(Module):
         def __init__(self, *args, **kw):

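Note: register_hook works by splicing a GradHooker node between the variable and its consumers (v.swap replaces v with the hooked output), so in the backward pass the hook observes the incoming gradient and, if it returns non-None, replaces it. A sketch beyond the docstring example (values are illustrative):

    import jittor as jt

    x = jt.array([1.0, 2.0])
    y = x * 2
    y.register_hook(lambda g: g * 10)  # scale the gradient flowing into y
    dx = jt.grad(y, x)                 # [20., 20.] instead of [2., 2.]
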
@@ -23,6 +23,7 @@ from . import pyjt_compiler
 from jittor_utils import lock
 from jittor_utils import install_cuda
 from jittor import __version__
+import hashlib

 def find_jittor_path():
     return os.path.dirname(__file__)

@@ -641,7 +642,6 @@ def compile_custom_ops(
     if gen_name_ != "":
         gen_name = gen_name_
     if len(gen_name) > 100:
-        import hashlib
         gen_name = gen_name[:80] + "___hash" + hashlib.md5(gen_name.encode()).hexdigest()

     includes = sorted(list(set(includes)))

@@ -1078,30 +1078,38 @@ os_type = {
     "macos": "macos",
 }

-version_file = os.path.join(jittor_path, "version")
-if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path, "src", "__data__")):
-    with open(version_file, 'r') as f:
-        version = f.read().strip()
-    # key = f"{version}-{cc_type}-{'cuda' if has_cuda else 'cpu'}.o"
-    key = f"{version}-g++-cpu"
-    os_id = os_release["ID"]
-    os_key = os_type.get(os_id, "ubuntu")
-    os_key += '-' + os_arch if os_arch else ''
-    if platform.machine()=='aarch64':
-        os_key += '-aarch64'
-    if platform.machine()=='sw_64':
-        os_key += '-sw_64'
-    import ssl
-    ssl._create_default_https_context = ssl._create_unverified_context
-    if "os_key" in os.environ:
-        os_key = os.environ['os_key']
-    LOG.i("OS type:", os_id, " OS key:", os_key)
-    key += '-' + os_key + '.o'
-    # TODO: open the website
-    extra_obj = os.path.join(cache_path, key)
-    url = os.path.join("https://cg.cs.tsinghua.edu.cn/jittor/assets/build/"+key)
-    jit_utils.download(url, extra_obj)
-    files.append(extra_obj)
+if platform.machine()=='sw_64':
+    import ssl
+    ssl._create_default_https_context = ssl._create_unverified_context
+
+data_gz_path = os.path.join(jittor_path, "utils", "data.gz")
+use_data_gz = os.path.isfile(data_gz_path)
+if os.environ.get("use_data_gz", "1") == "0":
+    use_data_gz = False
+if use_data_gz:
+    import gzip
+    with gzip.open(data_gz_path, 'rb') as f:
+        data = f.read()
+    md5 = hashlib.md5(data).hexdigest()
+    target_md5 = None
+    data_gz_md5_path = os.path.join(cache_path, "data.md5")
+    if os.path.isfile(data_gz_md5_path):
+        with open(data_gz_md5_path, 'r') as f:
+            target_md5 = f.read()
+    data_o_path = os.path.join(cache_path, "data.o")
+    if target_md5 != md5:
+        data_s_path = os.path.join(cache_path, "data.cc")
+        with open(data_s_path, "w") as f:
+            f.write(data.decode("utf8"))
+        dflags = (cc_flags+opt_flags)\
+            .replace("-Wall", "") \
+            .replace("-Werror", "")
+        run_cmd(f"{cc_path} {dflags} -D_P\\(...\\)= {data_s_path} -c -o {data_o_path}")
+        os.remove(data_s_path)
+        with open(data_gz_md5_path, 'w') as f:
+            f.write(md5)
+    files.append(data_o_path)
+    files = [f for f in files if "__data__" not in f]

 compile(cc_path, cc_flags+opt_flags, files, 'jittor_core'+extension_suffix)

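Note: the hunk above removes the prebuilt-object download and instead compiles the bundled data.gz locally, keyed by md5 so data.cc is only regenerated and recompiled when data.gz changes. The cache check in miniature (a standalone sketch with illustrative paths, not code from the commit):

    import gzip, hashlib, os

    def needs_rebuild(gz_path, md5_path):
        # rebuild data.o only when the md5 of the gzipped source differs
        with gzip.open(gz_path, 'rb') as f:
            cur = hashlib.md5(f.read()).hexdigest()
        old = open(md5_path).read() if os.path.isfile(md5_path) else None
        return cur != old
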
@@ -1,5 +1,5 @@

-from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset
+from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset, VarDataset
 from .mnist import MNIST
 from .cifar import CIFAR10, CIFAR100
 from .voc import VOC

@@ -11,6 +11,7 @@
 #include <cuda_runtime.h>
 #include <cublas_v2.h>

+#include "utils/log.h"
 #include "helper_cuda.h"
 #include "fp16_emu.h"
 #include "common.h"

@@ -58,6 +58,7 @@

 // CUDA and CUBLAS functions
 #include <helper_functions.h>
+#include "utils/log.h"
 #include "helper_cuda.h"

 #ifndef min

@@ -8,6 +8,7 @@
 #include <cuda_runtime.h>
 #include <cudnn.h>

+#include "utils/log.h"
 #include "helper_cuda.h"
 #include "fp16_emu.h"
 #include "common.h"

@@ -65,6 +65,7 @@
 #include <assert.h>

 #include <cudnn.h>
+#include "utils/log.h"
 #include "helper_cuda.h"
 #include "fp16_dev.h"
 #include "fp16_emu.h"

@@ -1,4 +1,5 @@
 #include <cudnn.h>
+#include "utils/log.h"
 #include "helper_cuda.h"

 const char *_cudaGetErrorEnum(cudnnStatus_t error) {

@@ -11,6 +11,7 @@

 #include <cuda_runtime.h>
 #include <cublas.h>
+#include "utils/log.h"
 #include "helper_cuda.h"
 #include <curand.h>

@@ -12,6 +12,7 @@

 #include <cuda_runtime.h>
 #include <nccl.h>
+#include "utils/log.h"
 #include "helper_cuda.h"

 namespace jittor {

@@ -9,6 +9,7 @@
 *
 */

+#include "utils/log.h"
 #include "helper_cuda.h"
 #include "fp16_dev.h"

@@ -13,6 +13,7 @@
 // These are CUDA Helper functions for initialization and error checking

 #include <cuda_runtime.h>
+#include "utils/log.h"
 #include "helper_cuda.h"

 #ifdef _CUFFT_H_

@@ -778,7 +778,7 @@ void simple_net(int times = 100) {
     s.wait();
 }

-// extern "C" int mkl_test_entry();
+// extern int mkl_test_entry();

 int mkl_test_entry() {
     try {

@@ -15,25 +15,25 @@ def eye(shape, dtype):
     return jt.array(np.identity(shape[0])).unary(dtype)

 def eye_(var):
-    var.assign(eye(var.shape, var.dtype))
+    return var.assign(eye(var.shape, var.dtype))

 def constant(shape, dtype, value=0.0):
     return jt.array(value).unary(dtype).broadcast(shape)

 def constant_(var, value=0.0):
-    var.assign(constant(var.shape, var.dtype, value))
+    return var.assign(constant(var.shape, var.dtype, value))

 def uniform(shape, dtype, low, high):
     return jt.random(shape, dtype) * (low - high) + high

 def uniform_(var, low, high):
-    var.assign(uniform(var.shape, var.dtype, low, high))
+    return var.assign(uniform(var.shape, var.dtype, low, high))

 def gauss(shape, dtype, mean=0.0, std=1.0):
     return jt.random(shape, dtype, "normal") * std + mean

 def gauss_(var, mean=0.0, std=1.0):
-    var.assign(gauss(var.shape, var.dtype, mean, std))
+    return var.assign(gauss(var.shape, var.dtype, mean, std))

 def invariant_uniform(shape, dtype, mode="fan_in"):
     assert len(shape)>1

@@ -61,7 +61,7 @@ def relu_invariant_gauss(shape, dtype, mode="fan_in"):
     return gauss(shape, dtype, 0, std)

 def relu_invariant_gauss_(var, mode="fan_in"):
-    var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))
+    return var.assign(relu_invariant_gauss(tuple(var.shape), var.dtype, mode))

 def calculate_std(var,mode,nonlinearity,param=0.01):
     mode = mode.lower()

@@ -112,7 +112,7 @@ def xavier_uniform(shape, dtype, gain=1.0):
     return uniform(shape, dtype, -bound, bound)

 def xavier_uniform_(var, gain=1.0):
-    var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain))
+    return var.assign(xavier_uniform(tuple(var.shape), var.dtype, gain))

 def xavier_gauss(shape, dtype, gain=1.0):
     assert len(shape)>1

@@ -125,4 +125,4 @@ def xavier_gauss(shape, dtype, gain=1.0):
     return gauss(shape, dtype, 0, std)

 def xavier_gauss_(var, gain=1.0):
-    var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))
+    return var.assign(xavier_gauss(tuple(var.shape), var.dtype, gain))

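Note: the initializer hunks above all make the same change: each in-place initializer now returns the result of var.assign(...) instead of None, so initialization can be written inline (assuming assign returns the variable, which is what the new return value suggests). A minimal usage sketch:

    import jittor as jt
    from jittor import init

    w = init.gauss_(jt.random((3, 3)), mean=0.0, std=0.02)  # now returns the var
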
@@ -10,7 +10,7 @@
 #include <functional>
 #include "utils/log.h"

-#define JIT_TEST(name) extern "C" void jit_test_ ## name ()
+#define JIT_TEST(name) extern void jit_test_ ## name ()
 void expect_error(std::function<void()> func);

 #define VAR_MEMBER_NAME_AND_OFFSET(name, op) { #name , offsetof(struct op, name) }

@@ -33,7 +33,7 @@ int get_seed();

 void add_set_seed_callback(set_seed_callback callback);

-extern "C"
+extern
 std::default_random_engine* get_random_engine();

 // things need to be clean before python exit

@@ -13,7 +13,7 @@

 namespace jittor {

-DECLARE_FLAG(int, para_opt_level);
+DEFINE_FLAG(int, para_opt_level, 3, "para_opt_level");

 void LoopVarAnalyzePass::run() {
     // loop_vars: opi_xx->shape[j]

@@ -34,9 +34,9 @@ constexpr int32_t basename_index(const char * const path, const int32_t index =

 #define PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0))

-extern "C" uint32_t get_tid();
-extern "C" bool g_supports_color;
-extern "C" void print_prefix(std::ostream* out);
+extern uint32_t get_tid();
+extern bool g_supports_color;
+extern void print_prefix(std::ostream* out);

 constexpr char green[] = "\033[38;5;2m";
 constexpr char red[] = "\033[38;5;1m";

@@ -58,10 +58,10 @@ static void get_color(char level, int verbose, const char*& color_begin, const c
         color_end = "\033[m";
 }

-extern "C" void send_log(std::ostringstream&& out);
-extern "C" void flush_log();
-extern "C" void log_capture_start();
-extern "C" void log_capture_stop();
+extern void send_log(std::ostringstream&& out);
+extern void flush_log();
+extern void log_capture_start();
+extern void log_capture_stop();
 extern std::vector<std::map<string,string>> log_capture_read();
 extern string thread_local thread_name;

@@ -145,9 +145,10 @@ template<class T> T get_from_env(const char* name,const T& _default) {
 template<> std::string get_from_env(const char* name, const std::string& _default);

 #define DECLARE_FLAG(type, name) \
-extern "C" type name; \
-extern "C" std::string doc_ ## name; \
-extern "C" void set_ ## name (const type&);
+extern type name; \
+extern std::string doc_ ## name; \
+extern void set_ ## name (const type&);


 #ifdef JIT

@@ -42,7 +42,7 @@ class TestAtomicTunerClass(unittest.TestCase):
         x=jt.random([100,64,128,128])
         with jt.log_capture_scope(
             # log_silent=1,
-            log_v=0, log_vprefix="atomic_tuner_pass.cc=100",
+            log_v=0, log_vprefix="atomic=100,data=100",
         ) as logs:
             y=model(x).numpy()
         with jt.log_capture_scope(

@@ -46,7 +46,30 @@ class TestHook(unittest.TestCase):
         assert hooked
         np.testing.assert_allclose(dx.numpy(), [-1.0, -2.0])

+    def test_register_hook(self):
+        x = jt.array([0.0, 0.0])
+        y = x * [1,2]
+        y.register_hook(lambda g: g*2)
+        dx = jt.grad(y, x)
+        np.testing.assert_allclose(dx.data, [2,4])
+
+    def test_requires_grads_(self):
+        class Mod(jt.nn.Module):
+            def execute(self, x):
+                return x*2
+        x = jt.random((100,))
+        mod = Mod()
+        mod.requires_grad_(True)
+        y = mod(x)
+        y = y*10
+        dx = jt.grad(y, x)
+        np.testing.assert_allclose(dx.data, 20)
+
+        mod.requires_grad_(False)
+        y = mod(x)
+        y = y*10
+        dx = jt.grad(y, x)
+        np.testing.assert_allclose(dx.data, 0)
+
 if __name__ == "__main__":
     unittest.main()

Binary file not shown.