fix cupy bug

This commit is contained in:
Dun Liang 2020-07-31 22:08:05 +08:00
parent c00b0f9865
commit d5d5fc8ea9
4 changed files with 44 additions and 41 deletions

View File

@ -22,7 +22,7 @@ with lock.lock_scope():
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
if has_cuda:
from .compile_extern import cudnn, curand, cublas
from . import init_cupy
from .init_cupy import numpy2cupy
import contextlib
import numpy as np

View File

@ -6,14 +6,40 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import cupy
import os
import ctypes
jt_allocator = ctypes.CDLL(os.path.join(
jt.compiler.cache_path,
"jittor_core"+jt.compiler.extension_suffix),
os.RTLD_NOW | os.RTLD_GLOBAL)
malloc = jt_allocator.get_jittor_cuda_malloc()
free = jt_allocator.get_jittor_cuda_free()
has_cupy = 0
try:
import cupy
has_cupy = 1
except:
pass
if has_cupy:
import jittor as jt
import os
import ctypes
def cvt(a):
a_pointer, read_only_flag = a.__array_interface__['data']
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a,0),0)
a = cp.ndarray(a.shape,a.dtype,aptr)
return a
def numpy2cupy(snp, data):
for key in data:
if isinstance(data[key], list):
for i in range(len(data[key])):
data[key][i]=cvt(data[key][i])
elif isinstance(data[key], int):
pass
else:
data[key]=cvt(data[key])
jt_allocator = ctypes.CDLL(os.path.join(
jt.compiler.cache_path,
"jittor_core"+jt.compiler.extension_suffix),
os.RTLD_NOW | os.RTLD_GLOBAL)
malloc = jt_allocator.get_jittor_cuda_malloc()
free = jt_allocator.get_jittor_cuda_free()
else:
def numpy2cupy(snp, data):
pass

View File

@ -1,26 +0,0 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import numpy as np
import cupy as cp
def cvt(a):
a_pointer, read_only_flag = a.__array_interface__['data']
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a,0),0)
a = cp.ndarray(a.shape,a.dtype,aptr)
return a
def numpy2cupy(snp, data):
for key in data:
if isinstance(data[key], list):
for i in range(len(data[key])):
data[key][i]=cvt(data[key][i])
elif isinstance(data[key], int):
pass
else:
data[key]=cvt(data[key])

View File

@ -13,7 +13,9 @@
#include "misc/hash.h"
#include "misc/nano_string.h"
#include "misc/fast_shared_ptr.h"
#ifdef HAS_CUDA
#include "misc/cuda_flags.h"
#endif
namespace jittor {
@ -610,12 +612,13 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
PyTuple_SET_ITEM(args.obj, 0, np.release());
PyTuple_SET_ITEM(args.obj, 1, data.release());
#ifdef HAS_CUDA
if (npstr=="cupy") {
PyObject *jt = PyImport_ImportModule("jittor");
PyObject *np2cp = PyObject_GetAttrString(jt,"numpy2cupy");
PyObject *pFunc = PyObject_GetAttrString(np2cp, "numpy2cupy");
PyObjHolder ret1(PyObject_Call(pFunc, args.obj, nullptr));
PyObjHolder jt(PyImport_ImportModule("jittor"));
PyObjHolder pFunc(PyObject_GetAttrString(jt.obj,"numpy2cupy"));
PyObjHolder ret1(PyObject_Call(pFunc.obj, args.obj, nullptr));
}
#endif
PyObjHolder ret2(PyObject_Call(obj, args.obj, nullptr));
},