// Mirror of https://github.com/Jittor/Jittor (C++, 131 lines, 3.8 KiB).
// ***************************************************************
// Copyright (c) 2021 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
|
|
#pragma once
|
|
#include "pyjt/py_obj_holder.h"
|
|
#include "common.h"
|
|
#include "misc/nano_string.h"
|
|
#include "ops/array_op.h"
|
|
|
|
namespace jittor {
|
|
|
|
struct PyArrayDescr_Proxy {
|
|
PyObject_HEAD
|
|
PyObject* typeobj;
|
|
char kind;
|
|
char type;
|
|
char byteorder;
|
|
char flags;
|
|
int type_num;
|
|
int elsize;
|
|
int alignment;
|
|
char* subarray;
|
|
PyObject *fields;
|
|
PyObject *names;
|
|
};
|
|
|
|
struct PyArray_Proxy {
|
|
PyObject_HEAD
|
|
char* data;
|
|
int nd;
|
|
ssize_t* dimensions;
|
|
ssize_t* strides;
|
|
PyObject *base;
|
|
PyArrayDescr_Proxy *descr;
|
|
int flags;
|
|
};
|
|
|
|
// numpy builtin type numbers (the values found in PyArrayDescr_Proxy's
// type_num). Each value is written out explicitly because it must equal
// numpy's own NPY_TYPES numbering; NPY_END is the exclusive upper bound used
// before indexing the npy2ns table.
enum NPY_TYPES {
    NPY_BOOL        = 0,
    NPY_BYTE        = 1,
    NPY_UBYTE       = 2,
    NPY_SHORT       = 3,
    NPY_USHORT      = 4,
    NPY_INT         = 5,
    NPY_UINT        = 6,
    NPY_LONG        = 7,
    NPY_ULONG       = 8,
    NPY_LONGLONG    = 9,
    NPY_ULONGLONG   = 10,
    NPY_FLOAT       = 11,
    NPY_DOUBLE      = 12,
    NPY_LONGDOUBLE  = 13,
    NPY_CFLOAT      = 14,
    NPY_CDOUBLE     = 15,
    NPY_CLONGDOUBLE = 16,
    NPY_OBJECT      = 17,
    // 18..22 (numpy's string, unicode, void, datetime, timedelta) are
    // intentionally not named here.
    NPY_HALF        = 23,
    NPY_END         = 24,
};
|
|
|
|
EXTERN_LIB NanoString npy2ns[];
|
|
EXTERN_LIB NPY_TYPES ns2npy[];
|
|
|
|
#define NPY_ARRAY_C_CONTIGUOUS 0x0001
|
|
#define NPY_ARRAY_ALIGNED 0x0100
|
|
#define NPY_ARRAY_WRITEABLE 0x0400
|
|
// NPY_ARRAY_C_CONTIGUOUS=1
|
|
inline bool is_c_style(PyArray_Proxy* obj) { return obj->flags & 1; }
|
|
// Maps the dtype of a numpy array to the corresponding Jittor NanoString.
// Fails the CHECK (reporting the numpy type number and type character) when
// the type has no Jittor equivalent — including type numbers outside the
// builtin range [0, NPY_END), such as user-defined dtypes.
inline NanoString get_type_str(PyArray_Proxy* obj) {
    const int type_num = obj->descr->type_num;
    NanoString type = ns_void;
    // npy2ns only covers [0, NPY_END); guard both ends before indexing.
    if (type_num >= 0 && type_num < NPY_END)
        type = npy2ns[type_num];
    // The message must not index npy2ns: on failure type_num may be out of
    // the table's range.
    CHECK(type != ns_void) << "Numpy type not support, type_num:"
        << type_num
        << " type_char:" << obj->descr->type;
    return type;
}
|
|
|
|
// Returns the numpy type number (an NPY_TYPES value) for a Jittor dtype,
// looked up in ns2npy by the dtype's index.
inline int get_typenum(NanoString ns) {
    const NPY_TYPES npy_type = ns2npy[ns.index()];
    return static_cast<int>(npy_type);
}
|
|
|
|
typedef Py_intptr_t npy_intp;
|
|
|
|
EXTERN_LIB unordered_map<string, int> np_typenum_map;
|
|
|
|
EXTERN_LIB void** PyArray_API;
|
|
EXTERN_LIB PyTypeObject *PyArray_Type;
|
|
EXTERN_LIB PyTypeObject *PyNumberArrType_Type;
|
|
EXTERN_LIB PyTypeObject *PyArrayDescr_Type;
|
|
EXTERN_LIB PyObject* (*PyArray_New)(PyTypeObject *, int, npy_intp const *, int, npy_intp const *, void *, int, int, PyObject *);
|
|
EXTERN_LIB PyObject* (*PyArray_FromAny)(PyObject *, PyArrayDescr_Proxy *, int, int, int, PyObject *);
|
|
EXTERN_LIB unsigned int (*PyArray_GetNDArrayCFeatureVersion)();
|
|
EXTERN_LIB int (*PyArray_SetBaseObject)(PyObject *arr, PyObject *obj);
|
|
EXTERN_LIB PyObject* (*PyArray_NewCopy)(PyObject *, int);
|
|
EXTERN_LIB int (*PyArray_CopyInto)(PyObject *, PyObject *);
|
|
EXTERN_LIB void (*PyArray_CastScalarToCtype)(PyObject* scalar, void* ctypeptr, PyArrayDescr_Proxy* outcode);
|
|
|
|
#define PyArray_Copy(obj) PyArray_NewCopy(obj, 0)
|
|
|
|
#define NPY_ARRAY_ALIGNED 0x0100
|
|
#define NPY_ARRAY_WRITEABLE 0x0400
|
|
#define NPY_ARRAY_BEHAVED (NPY_ARRAY_ALIGNED | \
|
|
NPY_ARRAY_WRITEABLE)
|
|
|
|
#define NPY_ARRAY_CARRAY (NPY_ARRAY_C_CONTIGUOUS | \
|
|
NPY_ARRAY_BEHAVED)
|
|
|
|
#define PyArray_SimpleNew(nd, dims, typenum) \
|
|
PyArray_New(PyArray_Type, nd, dims, typenum, NULL, NULL, 0, 0, NULL)
|
|
|
|
#define PyArray_SimpleNewFromData(nd, dims, typenum, data) \
|
|
PyArray_New(&PyArray_Type, nd, dims, typenum, NULL, \
|
|
data, 0, NPY_ARRAY_CARRAY, NULL)
|
|
|
|
#define PyArray_FROM_O(m) PyArray_FromAny(m, NULL, 0, 0, 0, NULL)
|
|
|
|
// Returns the size of the array's data in BYTES: the product of all
// dimensions times the element size (descr->elsize).
// Note: unlike numpy's own PyArray_Size, which counts elements, this
// counts bytes.
inline int64 PyArray_Size(PyArray_Proxy* arr) {
    int64 nbytes = 1;
    const ssize_t* dim = arr->dimensions;
    for (int remaining = arr->nd; remaining > 0; --remaining)
        nbytes *= *dim++;
    return nbytes * arr->descr->elsize;
}
|
|
|
|
union tmp_data_t {
|
|
int32 i32;
|
|
float32 f32;
|
|
int8 i8;
|
|
};
|
|
|
|
EXTERN_LIB tmp_data_t tmp_data;
|
|
|
|
void numpy_init();
|
|
|
|
} // jittor
|