mirror of https://github.com/Jittor/Jittor
var slice check numpy scalar
This commit is contained in:
parent
cf8ba20fe3
commit
5c243bef79
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.5'
|
||||
__version__ = '1.2.2.6'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -18,6 +18,7 @@ unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
|||
int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
||||
PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
||||
int (*PyArray_CopyInto)(PyObject *, PyObject *);
|
||||
void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
|
||||
|
||||
tmp_data_t tmp_data;
|
||||
|
||||
|
@ -36,6 +37,7 @@ void numpy_init() {
|
|||
fill(PyArray_SetBaseObject, 282);
|
||||
fill(PyArray_NewCopy, 85);
|
||||
fill(PyArray_CopyInto, 82);
|
||||
fill(PyArray_CastScalarToCtype, 63);
|
||||
|
||||
ASSERT(PyArray_GetNDArrayCFeatureVersion()>=7);
|
||||
}
|
||||
|
|
|
@ -100,6 +100,7 @@ extern unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
|||
extern int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
||||
extern PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
||||
extern int (*PyArray_CopyInto)(PyObject *, PyObject *);
|
||||
extern void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
|
||||
|
||||
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)
|
||||
|
||||
|
|
|
@ -734,7 +734,13 @@ void load_var_slice(PyObject* obj, T* var_slice, vector<unique_ptr<VarHolder>>&
|
|||
} else
|
||||
if (obj == Py_None) {
|
||||
var_slice->set_none();
|
||||
}else {
|
||||
} else
|
||||
if (PyObject_TypeCheck(obj, PyNumberArrType_Type)) {
|
||||
PyArrayDescr_Proxy array_descr = {.type_num = 5}; // 5: int32
|
||||
int value;
|
||||
PyArray_CastScalarToCtype(obj, &value, &array_descr);
|
||||
var_slice->set_int(value);
|
||||
} else {
|
||||
holders.emplace_back();
|
||||
auto* vh = from_py_object<VarHolder*>(obj, holders.back());
|
||||
auto vv = (Var**)vh;
|
||||
|
|
Loading…
Reference in New Issue