mirror of https://github.com/Jittor/Jittor
fix cupy bug
This commit is contained in:
parent
c00b0f9865
commit
d5d5fc8ea9
|
@ -22,7 +22,7 @@ with lock.lock_scope():
|
|||
has_cuda = compile_extern.has_cuda = compiler.has_cuda = False
|
||||
if has_cuda:
|
||||
from .compile_extern import cudnn, curand, cublas
|
||||
from . import init_cupy
|
||||
from .init_cupy import numpy2cupy
|
||||
|
||||
import contextlib
|
||||
import numpy as np
|
||||
|
|
|
@ -6,14 +6,40 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
import cupy
|
||||
import os
|
||||
import ctypes
|
||||
|
||||
jt_allocator = ctypes.CDLL(os.path.join(
|
||||
jt.compiler.cache_path,
|
||||
"jittor_core"+jt.compiler.extension_suffix),
|
||||
os.RTLD_NOW | os.RTLD_GLOBAL)
|
||||
malloc = jt_allocator.get_jittor_cuda_malloc()
|
||||
free = jt_allocator.get_jittor_cuda_free()
|
||||
has_cupy = 0
|
||||
try:
|
||||
import cupy
|
||||
has_cupy = 1
|
||||
except:
|
||||
pass
|
||||
if has_cupy:
|
||||
import jittor as jt
|
||||
import os
|
||||
import ctypes
|
||||
|
||||
def cvt(a):
|
||||
a_pointer, read_only_flag = a.__array_interface__['data']
|
||||
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a,0),0)
|
||||
a = cp.ndarray(a.shape,a.dtype,aptr)
|
||||
return a
|
||||
|
||||
def numpy2cupy(snp, data):
|
||||
for key in data:
|
||||
if isinstance(data[key], list):
|
||||
for i in range(len(data[key])):
|
||||
data[key][i]=cvt(data[key][i])
|
||||
elif isinstance(data[key], int):
|
||||
pass
|
||||
else:
|
||||
data[key]=cvt(data[key])
|
||||
|
||||
jt_allocator = ctypes.CDLL(os.path.join(
|
||||
jt.compiler.cache_path,
|
||||
"jittor_core"+jt.compiler.extension_suffix),
|
||||
os.RTLD_NOW | os.RTLD_GLOBAL)
|
||||
malloc = jt_allocator.get_jittor_cuda_malloc()
|
||||
free = jt_allocator.get_jittor_cuda_free()
|
||||
else:
|
||||
def numpy2cupy(snp, data):
|
||||
pass
|
|
@ -1,26 +0,0 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import numpy as np
|
||||
import cupy as cp
|
||||
|
||||
def cvt(a):
|
||||
a_pointer, read_only_flag = a.__array_interface__['data']
|
||||
aptr=cp.cuda.MemoryPointer(cp.cuda.memory.UnownedMemory(a_pointer,a.size*a.itemsize,a,0),0)
|
||||
a = cp.ndarray(a.shape,a.dtype,aptr)
|
||||
return a
|
||||
|
||||
def numpy2cupy(snp, data):
|
||||
for key in data:
|
||||
if isinstance(data[key], list):
|
||||
for i in range(len(data[key])):
|
||||
data[key][i]=cvt(data[key][i])
|
||||
elif isinstance(data[key], int):
|
||||
pass
|
||||
else:
|
||||
data[key]=cvt(data[key])
|
|
@ -13,7 +13,9 @@
|
|||
#include "misc/hash.h"
|
||||
#include "misc/nano_string.h"
|
||||
#include "misc/fast_shared_ptr.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include "misc/cuda_flags.h"
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -610,12 +612,13 @@ DEF_IS(NumpyFunc, T) from_py_object(PyObject* obj) {
|
|||
PyTuple_SET_ITEM(args.obj, 0, np.release());
|
||||
PyTuple_SET_ITEM(args.obj, 1, data.release());
|
||||
|
||||
#ifdef HAS_CUDA
|
||||
if (npstr=="cupy") {
|
||||
PyObject *jt = PyImport_ImportModule("jittor");
|
||||
PyObject *np2cp = PyObject_GetAttrString(jt,"numpy2cupy");
|
||||
PyObject *pFunc = PyObject_GetAttrString(np2cp, "numpy2cupy");
|
||||
PyObjHolder ret1(PyObject_Call(pFunc, args.obj, nullptr));
|
||||
PyObjHolder jt(PyImport_ImportModule("jittor"));
|
||||
PyObjHolder pFunc(PyObject_GetAttrString(jt.obj,"numpy2cupy"));
|
||||
PyObjHolder ret1(PyObject_Call(pFunc.obj, args.obj, nullptr));
|
||||
}
|
||||
#endif
|
||||
|
||||
PyObjHolder ret2(PyObject_Call(obj, args.obj, nullptr));
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue