worker seed and auto 64 to 32

Dun Liang 2021-05-09 18:02:17 +08:00
parent 5447f23c42
commit c47db5d189
7 changed files with 92 additions and 3 deletions

View File

@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.2.69'
+__version__ = '1.2.2.70'
 from . import lock
 with lock.lock_scope():
     ori_int = int
@@ -280,6 +280,10 @@ def array(data, dtype=None):
         return ops.array(np.array(data, dtype))
     return ops.array(data)
+
+def array64(data, dtype=None):
+    with jt.flag_scope(auto_convert_64_to_32=0):
+        return array(data, dtype)
 
 def grad(loss, targets):
     if type(targets) == core.Var:
         return core.grad(loss, [targets])[0]
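
The new array64 helper simply wraps jt.array in a flag_scope that disables the new auto_convert_64_to_32 flag. A minimal usage sketch, based on the behaviour exercised by the tests in this commit (variable names are illustrative):

    import jittor as jt
    import numpy as np

    a = np.random.rand(10)            # numpy produces float64 by default
    print(jt.array(a).dtype)          # "float32": the flag downcasts 64-bit input
    print(jt.array64(a).dtype)        # "float64": array64 keeps the original width

    # Equivalent to array64, using the flag scope directly:
    with jt.flag_scope(auto_convert_64_to_32=0):
        print(jt.array(a).dtype)      # "float64"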

View File

@@ -164,6 +164,9 @@ class Dataset(object):
         import jittor_utils
         jittor_utils.cc.init_subprocess()
         jt.jt_init_subprocess()
+        seed = jt.get_seed()
+        wseed = (seed ^ worker_id) ^ 1234
+        jt.set_seed(wseed)
         # parallel_op_compiler still problematic,
         # it is not work on ubuntu 16.04. but worked on ubuntu 20.04
         # it seems like the static value of parallel compiler
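
Each worker now derives its own seed from the process-wide seed returned by the new jt.get_seed(), so workers stay deterministic without producing identical streams. A small pure-Python sketch of the derivation above (the helper name is illustrative):

    # XOR the global seed with the worker index, then with a fixed constant,
    # so each worker gets a distinct but reproducible seed.
    def derive_worker_seed(global_seed, worker_id):
        return (global_seed ^ worker_id) ^ 1234

    print([derive_worker_seed(0, w) for w in range(4)])  # [1234, 1235, 1232, 1233]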

View File

@@ -160,6 +160,20 @@ class TestArray(unittest.TestCase):
         expect_error(lambda : jt.array("asdasd"))
         expect_error(lambda : jt.array(jt))
 
+    def test_64_bit(self):
+        a = np.random.rand(10)
+        b = jt.array(a)
+        assert b.dtype == "float32"
+
+        with jt.flag_scope(auto_convert_64_to_32=0):
+            a = np.random.rand(10)
+            b = jt.array(a)
+            assert b.dtype == "float64"
+
+        a = np.random.rand(10)
+        b = jt.array64(a)
+        assert b.dtype == "float64"
 
 if __name__ == "__main__":
     unittest.main()

View File

@@ -106,5 +106,62 @@ class TestDataset2(unittest.TestCase):
     def test_dataset_use_jittor_cuda(self):
         self.test_dataset_use_jittor()
 
+class TestDatasetSeed(unittest.TestCase):
+    def test_np(self):
+        class YourDataset(Dataset):
+            def __init__(self):
+                super().__init__()
+                self.set_attrs(total_len=16)
+
+            def __getitem__(self, k):
+                return np.random.rand(2)
+
+        dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
+        dd = []
+        for d in dataset:
+            dd.append(d.numpy())
+        for i in range(len(d)):
+            for j in range(i+1, len(d)):
+                assert not np.allclose(dd[i], dd[j])
+
+    def test_py_native(self):
+        import random
+        class YourDataset(Dataset):
+            def __init__(self):
+                super().__init__()
+                self.set_attrs(total_len=16)
+
+            def __getitem__(self, k):
+                return random.randint(0,1000)
+
+        jt.set_global_seed(0)
+        dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
+        dd = []
+        for d in dataset:
+            dd.append(d.numpy())
+        for i in range(len(d)):
+            for j in range(i+1, len(d)):
+                assert not np.allclose(dd[i], dd[j])
+
+    def test_jtrand(self):
+        import random
+        class YourDataset(Dataset):
+            def __init__(self):
+                super().__init__()
+                self.set_attrs(total_len=160)
+
+            def __getitem__(self, k):
+                return jt.rand(2)
+
+        jt.set_global_seed(0)
+        dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
+        dd = []
+        for d in dataset:
+            dd.append(d.numpy())
+        for i in range(len(d)):
+            for j in range(i+1, len(d)):
+                assert not np.allclose(dd[i], dd[j])
 
 if __name__ == "__main__":
     unittest.main()
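
These tests assert that batches drawn by different workers are not identical. For context, this is the failure mode the per-worker reseeding guards against: forked dataset workers inherit the parent's RNG state, so without reseeding they all draw the same "random" numbers. A generic, non-Jittor sketch of that effect (assumes a fork-capable platform; names are illustrative):

    import numpy as np
    import multiprocessing as mp

    def worker(q):
        # Each forked child starts from the parent's RNG state.
        q.put(np.random.rand(2).tolist())

    if __name__ == "__main__":
        np.random.seed(0)
        ctx = mp.get_context("fork")
        q = ctx.Queue()
        procs = [ctx.Process(target=worker, args=(q,)) for _ in range(4)]
        for p in procs:
            p.start()
        results = [q.get() for _ in procs]
        for p in procs:
            p.join()
        print(results)  # four identical pairs: every child inherited the same state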

View File

@@ -73,6 +73,10 @@ void set_seed(int seed) {
         cb(seed);
 }
 
+int get_seed() {
+    return current_seed;
+}
+
 void add_set_seed_callback(set_seed_callback callback) {
     callbacks.push_back(callback);
     callback(current_seed);
@@ -90,6 +94,7 @@ void jt_init_subprocess() {
     exe.last_is_cuda = false;
     no_cuda_error_when_free = 1;
     #endif
+    callbacks.clear();
 }
 }
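
With the @pyjt bindings in the header that follows, the seed becomes readable from Python as well as writable, which is what the new dataset worker code relies on. A minimal round-trip sketch:

    import jittor as jt

    jt.set_seed(42)
    assert jt.get_seed() == 42   # getter added by this commit

The added callbacks.clear() in jt_init_subprocess() drops the seed callbacks registered in the parent, presumably so they are not re-invoked with stale state when the forked worker sets its own seed.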

View File

@@ -15,6 +15,8 @@ typedef void (*set_seed_callback)(int);
 void init();
 // @pyjt(set_seed, seed)
 void set_seed(int seed);
+// @pyjt(get_seed)
+int get_seed();
 void add_set_seed_callback(set_seed_callback callback);

View File

@@ -22,6 +22,9 @@
 namespace jittor {
 
+DEFINE_FLAG(int, auto_convert_64_to_32, 1, "auto convert 64bit numpy array into 32bit jittor array");
+
 static auto make_array = get_op_info("array")
     .get_constructor<VarPtr, const void*, NanoVector, NanoString>();
@@ -81,7 +84,8 @@ ArrayOp::ArrayOp(PyObject* obj) {
         args.ptr = arr->data;
         // use 32-bit by default
-        if (holder.obj && args.dtype.dsize() == 8) {
+        if ((auto_convert_64_to_32 || holder.obj)
+            && args.dtype.dsize() == 8 && args.ptr) {
             auto num = PyArray_Size(arr)/8;
             if (args.dtype.is_int()) {
                 auto* __restrict__ i64 = (int64*)args.ptr;
@@ -96,7 +100,6 @@ ArrayOp::ArrayOp(PyObject* obj) {
                     f32[i] = (float32)f64[i];
                 args.dtype = ns_float32;
             }
         }
     } else {
         LOGf << "type <" >> Py_TYPE(obj)->tp_name >> "> not support for jittor array";
@@ -147,6 +150,7 @@ ArrayOp::ArrayOp(PyObject* obj) {
             NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_WRITEABLE, // flags
             NULL // obj
         ));
+        // TODO: fix not c style auto convert
         ASSERT(0==PyArray_CopyInto(holder.obj, obj));
     }
 }
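
At the Python boundary, the effect of the flag and of the in-place loops above is just a dtype narrowing: 8-byte integer buffers become int32 and 8-byte float buffers become float32, everything else is left untouched. A Python-level illustration of that rule (not the actual C++ path; the function name is illustrative):

    import numpy as np

    def convert_64_to_32(a: np.ndarray) -> np.ndarray:
        # Mirror of the narrowing rule: only 8-byte element types are touched.
        if a.dtype.itemsize != 8:
            return a
        if np.issubdtype(a.dtype, np.integer):
            return a.astype(np.int32)
        if np.issubdtype(a.dtype, np.floating):
            return a.astype(np.float32)
        return a

    print(convert_64_to_32(np.arange(4, dtype=np.int64)).dtype)  # int32
    print(convert_64_to_32(np.random.rand(4)).dtype)             # float32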