numpy_code op with cupy

This commit is contained in:
Gword 2020-07-24 21:19:50 +08:00
parent 5ee01ae6fa
commit b80f03b1ee
7 changed files with 154 additions and 44 deletions

View File

@ -21,6 +21,7 @@ with lock.lock_scope():
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
if has_cuda:
from .compile_extern import cudnn, curand, cublas
from . import init_cupy
import contextlib
import numpy as np
@ -636,3 +637,4 @@ Var.double = Var.float64
from . import nn
from .nn import matmul
from . import contrib
from . import numpy2cupy

View File

@ -0,0 +1,19 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import cupy
import os
import ctypes

# Load the compiled jittor core shared library. RTLD_GLOBAL makes its
# exported allocator symbols visible to other extensions (e.g. cupy).
jt_allocator = ctypes.CDLL(
    os.path.join(
        jt.compiler.cache_path,
        "jittor_core" + jt.compiler.extension_suffix),
    os.RTLD_NOW | os.RTLD_GLOBAL)

# The getters return raw C function pointers. ctypes defaults every
# foreign function's restype to a C int, which truncates a 64-bit
# pointer — declare c_void_p before calling so the full address survives.
jt_allocator.get_jittor_cuda_malloc.restype = ctypes.c_void_p
jt_allocator.get_jittor_cuda_free.restype = ctypes.c_void_p

# Raw addresses of jittor's CUDA malloc/free callbacks; presumably handed
# to cupy's external-allocator hooks by the caller — confirm at use site.
malloc = jt_allocator.get_jittor_cuda_malloc()
free = jt_allocator.get_jittor_cuda_free()

View File

@ -0,0 +1,26 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import numpy as np
import cupy as cp
def cvt(a):
    """Wrap the buffer of numpy-like array *a* as a cupy ndarray, zero-copy.

    The returned cupy array views the exact same memory; *a* is kept as
    the owner of the UnownedMemory so the buffer outlives the view.
    NOTE(review): assumes a's data pointer is device-accessible memory
    (arrays here come from jittor's CUDA allocator) — confirm at call site.
    """
    addr = a.__array_interface__['data'][0]
    nbytes = a.size * a.itemsize
    unowned = cp.cuda.memory.UnownedMemory(addr, nbytes, a, 0)
    memptr = cp.cuda.MemoryPointer(unowned, 0)
    return cp.ndarray(a.shape, a.dtype, memptr)
def numpy2cupy(snp, data):
    """Convert every array stored in *data* to a cupy view, in place.

    Values may be a single array, a list of arrays, or a plain int
    (ints are left untouched). *snp* is unused here; it is part of the
    callback signature used when this is invoked.
    """
    for key, val in data.items():
        if isinstance(val, list):
            # Replace each element but keep the same list object,
            # since callers may hold a reference to it.
            for i, item in enumerate(val):
                val[i] = cvt(item)
        elif isinstance(val, int):
            pass  # scalar metadata, nothing to convert
        else:
            data[key] = cvt(val)

View File

@ -8,13 +8,20 @@
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import numpy
import cupy
import ctypes
import sys
class TestCodeOp(unittest.TestCase):
def test(self):
def forward_code(np, data):
a = data["inputs"][0]
b = data["outputs"][0]
if (jt.flags.use_cuda==0):
assert isinstance(a,numpy.ndarray)
else:
assert isinstance(a,cupy.core.core.ndarray)
np.add(a,a,out=b)
def backward_code(np, data):
@ -22,18 +29,24 @@ class TestCodeOp(unittest.TestCase):
out = data["outputs"][0]
np.copyto(out, dout*2.0)
a = jt.random((5,1))
b = jt.numpy_code(
a.shape,
a.dtype,
[a],
forward_code,
[backward_code],
)
assert np.allclose(b.data,(a+a).data)
da = jt.grad(b,a)
one=np.ones(a.shape)
assert np.allclose(da.data,one*2.0)
def check():
a = jt.random((5,1))
b = jt.numpy_code(
a.shape,
a.dtype,
[a],
forward_code,
[backward_code],
)
assert numpy.allclose(b.data,(a+a).data)
da = jt.grad(b,a)
one=numpy.ones(a.shape)
assert numpy.allclose(da.data,one*2.0)
jt.flags.use_cuda = 0
check()
jt.flags.use_cuda = 1
check()
def test_multi_input(self):
def forward_code(np, data):
@ -56,25 +69,31 @@ class TestCodeOp(unittest.TestCase):
else:
np.negative(dout, out)
a = jt.random((5,1))
b = jt.random((5,1))
c, d = jt.numpy_code(
[a.shape, a.shape],
[a.dtype, a.dtype],
[a, b],
forward_code,
[backward_code1,backward_code2],
)
assert np.allclose(c.data,(a+b).data)
assert np.allclose(d.data,(a-b).data)
dca, dcb = jt.grad(c,[a,b])
dda, ddb = jt.grad(d,[a,b])
one=np.ones(a.shape)
mone=one*-1.0
assert np.allclose(dca.data,one)
assert np.allclose(dcb.data,one)
assert np.allclose(dda.data,one)
assert np.allclose(ddb.data,mone)
def check():
a = jt.random((5,1))
b = jt.random((5,1))
c, d = jt.numpy_code(
[a.shape, a.shape],
[a.dtype, a.dtype],
[a, b],
forward_code,
[backward_code1,backward_code2],
)
assert numpy.allclose(c.data,(a+b).data)
assert numpy.allclose(d.data,(a-b).data)
dca, dcb = jt.grad(c,[a,b])
dda, ddb = jt.grad(d,[a,b])
one=numpy.ones(a.shape)
mone=one*-1.0
assert numpy.allclose(dca.data,one)
assert numpy.allclose(dcb.data,one)
assert numpy.allclose(dda.data,one)
assert numpy.allclose(ddb.data,mone)
jt.flags.use_cuda = 0
check()
jt.flags.use_cuda = 1
check()
@unittest.skipIf(True, "Memory leak testing is not in progress, Skip")
def test_memory_leak(self):
@ -108,16 +127,16 @@ class TestCodeOp(unittest.TestCase):
forward_code,
[backward_code1,backward_code2],
)
assert np.allclose(c.data,(a+b).data)
assert np.allclose(d.data,(a-b).data)
assert numpy.allclose(c.data,(a+b).data)
assert numpy.allclose(d.data,(a-b).data)
dca, dcb = jt.grad(c,[a,b])
dda, ddb = jt.grad(d,[a,b])
one=np.ones(a.shape)
one=numpy.ones(a.shape)
mone=one*-1.0
assert np.allclose(dca.data,one)
assert np.allclose(dcb.data,one)
assert np.allclose(dda.data,one)
assert np.allclose(ddb.data,mone)
assert numpy.allclose(dca.data,one)
assert numpy.allclose(dcb.data,one)
assert numpy.allclose(dda.data,one)
assert numpy.allclose(ddb.data,mone)
if __name__ == "__main__":
unittest.main()

View File

@ -445,5 +445,28 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
LOGvv << "cudaDeviceSynchronize times:" << sync_times << "/" <<queue.size();
#endif
}
// Bookkeeping for buffers handed out through the C allocator interface
// below: jittor's allocator needs the original size and allocation id
// back at free time, so remember them per pointer.
// NOTE(review): no locking here — assumes malloc/free are invoked from a
// single thread (cupy's default allocator path); confirm before use
// from multiple streams/threads.
unordered_map<void*, size_t> allocation_map;
unordered_map<void*, size_t> size_map;

// C entry point exposed to the Python side so cupy can allocate through
// jittor's allocator. The first (unused) argument and device_id match the
// external allocator callback signature expected by the caller.
extern "C" void* jittor_cuda_malloc(void*, size_t size, int device_id) {
    size_t allocation;
    void* ptr = exe.allocator->alloc(size, allocation);
    allocation_map[ptr] = allocation;
    size_map[ptr] = size;
    return ptr;
}

extern "C" void jittor_cuda_free(void*, void* ptr, int device_id) {
    exe.allocator->free(ptr, size_map[ptr], allocation_map[ptr]);
    // Erase the bookkeeping entries: without this both maps grow without
    // bound across alloc/free cycles (the original leaked entries).
    allocation_map.erase(ptr);
    size_map.erase(ptr);
}
// Exported accessors: let the Python side (ctypes) fetch the raw
// addresses of the allocator callbacks defined above.
extern "C" void* get_jittor_cuda_malloc() {
    return reinterpret_cast<void*>(&jittor_cuda_malloc);
}

extern "C" void* get_jittor_cuda_free() {
    return reinterpret_cast<void*>(&jittor_cuda_free);
}
} // jittor

View File

@ -21,6 +21,8 @@ static auto make_numpy_code = get_op_info("numpy_code")
NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs, NumpyFunc&& forward, vector<NumpyFunc>&& sbackward)
: _inputs(inputs), forward(move(forward))
{
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
_outputs.push_back(create_output(shape, dtype));
CHECKop(_inputs.size(),<=,10);
ASSERT(_outputs[0]->num >= 0);
@ -32,6 +34,8 @@ NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inpu
NumpyCodeOp::NumpyCodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtypes, vector<Var*>&& inputs, NumpyFunc&& forward, vector<NumpyFunc>&& sbackward)
: _inputs(inputs), forward(move(forward))
{
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
CHECKop(shapes.size(),==,dtypes.size()) << "Number of outputs' shapes and dtypes should be the same";
_outputs.resize(shapes.size());
CHECKop(_inputs.size(),<=,10);
@ -49,6 +53,8 @@ NumpyCodeOp::NumpyCodeOp(vector<NanoVector>&& shapes, vector<NanoString>&& dtype
NumpyCodeOp::NumpyCodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs, NumpyFunc forward, NumpyResult&& results)
: _inputs(inputs), forward(forward), _results(move(results))
{
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
_outputs.push_back(create_output(shape, dtype));
CHECKop(_inputs.size(),<=,10);
ASSERT(_outputs[0]->num >= 0);

View File

@ -1,5 +1,8 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. Authors:
// Dun Liang <randonlang@gmail.com>.
// Guowei Yang <471184555@qq.com>
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
@ -10,6 +13,7 @@
#include "misc/hash.h"
#include "misc/nano_string.h"
#include "misc/fast_shared_ptr.h"
#include "misc/cuda_flags.h"
namespace jittor {
@ -559,8 +563,6 @@ DEF_IS_1(fast_shared_ptr, T) from_py_object(PyObject* obj) {
return from_py_object<typename T::value_type>(obj);
}
DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
// PyObject_Call
Py_INCREF(obj);
@ -568,7 +570,12 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
// callback
[obj](typename T::R* result) {
// import numpy
PyObjHolder np(PyImport_ImportModule("numpy"));
string npstr="numpy";
#ifdef HAS_CUDA
if (use_cuda) npstr="cupy";
#endif
PyObjHolder np(PyImport_ImportModule(npstr.data()));
// data = {}
PyObjHolder data(to_py_object(result->varrays));
PyObjHolder data2(to_py_object(result->ints));
@ -580,7 +587,15 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
PyObjHolder args(PyTuple_New(2));
PyTuple_SET_ITEM(args.obj, 0, np.release());
PyTuple_SET_ITEM(args.obj, 1, data.release());
PyObjHolder ret(PyObject_Call(obj, args.obj, nullptr));
if (npstr=="cupy") {
PyObject *jt = PyImport_ImportModule("jittor");
PyObject *np2cp = PyObject_GetAttrString(jt,"numpy2cupy");
PyObject *pFunc = PyObject_GetAttrString(np2cp, "numpy2cupy");
PyObjHolder ret1(PyObject_Call(pFunc, args.obj, nullptr));
}
PyObjHolder ret2(PyObject_Call(obj, args.obj, nullptr));
},
// deleter
[obj]() { Py_DECREF(obj); },