mirror of https://github.com/Jittor/Jittor
add jt console support
This commit is contained in:
parent
6adf313681
commit
4471aff46e
|
@ -26,6 +26,7 @@
|
|||
jittor.transform
|
||||
jittor.mpi
|
||||
jittor.linalg
|
||||
jittor.console
|
||||
|
||||
|
||||
.. toctree::
|
||||
|
|
|
@ -0,0 +1,237 @@
|
|||
jittor.console
|
||||
=====================
|
||||
|
||||
这里是Jittor的console api文档,console功能主要面向c/c++, 方便c++用户通过console使用jittor,jittor console 优化了
|
||||
c++数组和jittor内核之间的数据传输,减少了python额外开销,是通过c++使用jittor的高性能接口。
|
||||
|
||||
该功能要求 jittor版本大于1.2.2.17, 编译器支持c++17。
|
||||
|
||||
## 简单教程
|
||||
|
||||
我们提供了一个完整的教程,用户可以通过如下几行命令编译运行:
|
||||
|
||||
```bash
|
||||
# 生成c++ example源代码文件
|
||||
python3.7 -m jittor_utils.config --cxx-example > example.cc
|
||||
# 调用g++编译example, 需要g++支持std=c++17
|
||||
g++ example.cc $(python3.7 -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o example)
|
||||
# 运行example
|
||||
./example
|
||||
```
|
||||
|
||||
运行结果可能如下:
|
||||
```bash
|
||||
hello jt console
|
||||
1
|
||||
hello
|
||||
1 2 3 4
|
||||
jt.Var([[-1 5 4]
|
||||
[ 3 2 1]], dtype=int32)
|
||||
2 3
|
||||
1 25 16
|
||||
9 4 1
|
||||
pred.shape 2 1000
|
||||
```
|
||||
|
||||
用户可以打开 example.cc, 修改成所需的应用,接下来我们会为大家讲解 example.cc 中的细节。
|
||||
|
||||
打开example.cc, 我们可以看到如下代码:
|
||||
|
||||
```cpp
|
||||
#include <pyjt/pyjt_console.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
int main() {
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
这里我们导入了使用 console 所需的头文件 `pyjt/pyjt_console.h`
|
||||
|
||||
接下来是jittor console的实例化, 并且使用python的print输出hello jt console:
|
||||
|
||||
```cpp
|
||||
jittor::Console console;
|
||||
// run python code in console
|
||||
console.run("print('hello jt console', flush=True)");
|
||||
```
|
||||
|
||||
输出结果:
|
||||
|
||||
```
|
||||
hello jt console
|
||||
```
|
||||
|
||||
注意到这里我们在 python print的时候使用了flush keyword,这是为了让python的输出流和c++的输出流保持一致,
|
||||
不会错乱。
|
||||
|
||||
接下来我们调用了 `console.set<T>(name, data)` 和 `console.get<T>(name)` 往 console 里面设置了一个int变量a,并且再从console里面取出来。
|
||||
|
||||
```cpp
|
||||
// set a python value: a = 1
|
||||
console.set<int>("a", 1);
|
||||
// get a python value
|
||||
cout << console.get<int>("a") << endl;
|
||||
```
|
||||
|
||||
输出结果:
|
||||
|
||||
```
|
||||
1
|
||||
```
|
||||
|
||||
同样的方法,我们还设置了 `string` 和 `vector<int>`, 如下所示
|
||||
|
||||
```cpp
|
||||
// set a python string
|
||||
console.set<string>("b", "hello");
|
||||
cout << console.get<string>("b") << endl;
|
||||
|
||||
// set a python array
|
||||
vector<int> x{1,2,3,4};
|
||||
console.set("x", x);
|
||||
auto x2 = console.get<std::vector<int>>("x");
|
||||
for (auto a : x2) cout << a << " "; cout << endl;
|
||||
```
|
||||
|
||||
输出结果:
|
||||
|
||||
```
|
||||
hello
|
||||
1 2 3 4
|
||||
```
|
||||
|
||||
我们还可以往console里面设置jittor变量,这里我们使用了下面几个新的接口:
|
||||
|
||||
1. `jittor::array<T, NDIM>(shape, data)`: 这个接口创建了一个jittor的array,类型是`T`, 维度大小为`NDIM`, 形状为 `shape`, 注意shape的长度需要和`NDIM`保持一致,最后是传入的数据,可以是一个vector,也可以是一个指针。
|
||||
2. `console.set_array(name, arr)`: 往console里面设置该jittor array, 名称为`name`。
|
||||
3. `console.get<T, NDIM>(name)`: 从console里取出一个jittor array,类型为`T`,维度大小为`NDIM`,需要注意的是类型和维度大小必须和console中的变量匹配,否则会抛出异常。
|
||||
4. `arr(i,j)`: 对jittor变量取值。
|
||||
5. `arr.shape[i]`: 获取jittor变量的维度大小。
|
||||
|
||||
在这段代码中,我们首先创建了一个2x3的矩阵, 然后修改了矩阵中的值,随即设置到了python console里面,并且取出输出:
|
||||
|
||||
```cpp
|
||||
// set and get a jittor array
|
||||
jittor::array<int, 2> arr2({2,3}, {6,5,4,3,2,1});
|
||||
arr2(0,0) = -1;
|
||||
console.set_array("arr2", arr2);
|
||||
console.run("print(arr2, flush=True); arr3 = arr2**2;");
|
||||
auto arr3 = console.get_array<int, 2>("arr3");
|
||||
cout << arr3.shape[0] << ' ' << arr3.shape[1] << endl;
|
||||
for (int i=0; i<arr3.shape[0]; i++) {
|
||||
for (int j=0; j<arr3.shape[1]; j++)
|
||||
cout << arr3(i,j) << ' ';
|
||||
cout << endl;
|
||||
}
|
||||
```
|
||||
|
||||
输出结果如下:
|
||||
|
||||
```
|
||||
jt.Var([[-1 5 4]
|
||||
[ 3 2 1]], dtype=int32)
|
||||
2 3
|
||||
1 25 16
|
||||
9 4 1
|
||||
```
|
||||
|
||||
最后,我们演示了从`jittor.models`中导入`resnet`并且将结果从console中取出。
|
||||
|
||||
```cpp
|
||||
jittor::array<float, 4> input({2, 3, 224, 224});
|
||||
memset(input.data.get(), 0, input.nbyte());
|
||||
console.set_array("input", input);
|
||||
console.run(R"(
|
||||
import jittor as jt
|
||||
from jittor.models import resnet
|
||||
|
||||
model = resnet.resnet18()
|
||||
pred = model(input)
|
||||
)");
|
||||
auto pred = console.get_array<float, 2>("pred");
|
||||
cout << "pred.shape " << pred.shape[0] << ' ' << pred.shape[1] << endl;
|
||||
```
|
||||
|
||||
我们输出了取出的变量的形状,结果如下:
|
||||
|
||||
```
|
||||
pred.shape 2 1000
|
||||
```
|
||||
|
||||
## jittor array 接口一览
|
||||
|
||||
`jittor::array` 是 c++和jittor console交互的 array类型,他的定义如下:
|
||||
|
||||
```cpp
|
||||
|
||||
// T: 类型, N: 维度数量
|
||||
template<class T, int N>
|
||||
struct array {
|
||||
|
||||
// N维 形状大小
|
||||
int64 shape[N];
|
||||
// 数据指针
|
||||
unique_ptr<T[]> data;
|
||||
|
||||
// 是否为浮点数
|
||||
bool is_float();
|
||||
// 是否为无符号类型
|
||||
bool is_unsigned();
|
||||
// 数组总大小,为shape数组累乘的结果
|
||||
int64 size();
|
||||
// 数组总比特数
|
||||
int64 nbyte();
|
||||
// 数据类型的字符串表示
|
||||
string dtype();
|
||||
// 维度数量, 同 N
|
||||
int ndim();
|
||||
|
||||
// array 构造函数,shape为形状,数据未被初始化
|
||||
array(const vector<int64>& shape);
|
||||
// array 构造函数,shape为形状,数据从data指针拷贝初始化
|
||||
array(const vector<int64>& shape, const T* data);
|
||||
// array 构造函数,shape为形状,数据从data vector拷贝初始化
|
||||
array(const vector<int64>& shape, const vector<T>& data);
|
||||
|
||||
T& operator()(...);
|
||||
|
||||
};
|
||||
```
|
||||
|
||||
## Console 接口一览
|
||||
|
||||
console接口主要用于设置变量,取出变量,运行脚本, 三部分构成。
|
||||
|
||||
```cpp
|
||||
|
||||
struct Console {
|
||||
|
||||
// 运行代码接口
|
||||
void run(const string& src);
|
||||
|
||||
// 设置变量名称为s, 值为data
|
||||
template<class T>
|
||||
void set(const string& s, const T& data);
|
||||
|
||||
// 获取变量名称为s
|
||||
template<class T>
|
||||
T get(const string& s)
|
||||
|
||||
// 设置 array 变量
|
||||
void set_array(const string& s, const array<T,N>& data);
|
||||
|
||||
// 获取一个jittor array,类型为`T`,维度大小为`NDIM`,需要注意的是类型和维度大小必须和console中的变量匹配,否则会抛出异常。
|
||||
void get_array<T,N>(const string& s);
|
||||
|
||||
};
|
||||
```
|
||||
|
||||
其中 `get`,`set` 支持常见的c++类型有:
|
||||
|
||||
1. int, uint, int64, uint64, float, double
|
||||
2. string
|
||||
3. vector
|
||||
4. map, unordered_map
|
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.16'
|
||||
__version__ = '1.2.2.17'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -845,28 +845,15 @@ check_debug_flags()
|
|||
|
||||
sys.path.append(cache_path)
|
||||
LOG.i(f"Jittor({__version__}) src: {jittor_path}")
|
||||
LOG.i(f"{jit_utils.cc_type} at {jit_utils.cc_path}")
|
||||
LOG.i(f"cache_path: {cache_path}")
|
||||
|
||||
with jit_utils.import_scope(import_flags):
|
||||
jit_utils.try_import_jit_utils_core()
|
||||
|
||||
python_path = sys.executable
|
||||
py3_config_paths = [
|
||||
sys.executable + "-config",
|
||||
os.path.dirname(sys.executable) + f"/python3.{sys.version_info.minor}-config",
|
||||
f"/usr/bin/python3.{sys.version_info.minor}-config",
|
||||
os.path.dirname(sys.executable) + "/python3-config",
|
||||
]
|
||||
if "python_config_path" in os.environ:
|
||||
py3_config_paths.insert(0, os.environ["python_config_path"])
|
||||
py3_config_path = jit_utils.py3_config_path
|
||||
|
||||
for py3_config_path in py3_config_paths:
|
||||
if os.path.isfile(py3_config_path):
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
|
||||
"not found in {py3_config_paths}, please specify "
|
||||
"enviroment variable 'python_config_path'")
|
||||
nvcc_path = env_or_try_find('nvcc_path', '/usr/local/cuda/bin/nvcc')
|
||||
gdb_path = try_find_exe('gdb')
|
||||
addr2line_path = try_find_exe('addr2line')
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor_utils import run_cmd
|
||||
import sys
|
||||
|
||||
class TestConsole(unittest.TestCase):
|
||||
def test_console(self):
|
||||
run_cmd(f"{sys.executable} -m jittor_utils.config --cxx-example > tmp.cc", jt.flags.cache_path)
|
||||
s = run_cmd(f"{jt.flags.cc_path} tmp.cc $({sys.executable} -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o tmp.out && ./tmp.out", jt.flags.cache_path)
|
||||
print(s)
|
||||
assert "jt.Var" in s
|
||||
assert "pred.shape 2 1000" in s
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -241,23 +241,25 @@ def get_version(output):
|
|||
version = "("+v[-1]+")"
|
||||
return version
|
||||
|
||||
def find_exe(name, check_version=True):
|
||||
def find_exe(name, check_version=True, silent=False):
|
||||
output = run_cmd(f'which {name}', err_msg=f'{name} not found')
|
||||
if check_version:
|
||||
version = get_version(name)
|
||||
else:
|
||||
version = ""
|
||||
LOG.i(f"Found {name}{version} at {output}.")
|
||||
if not silent:
|
||||
LOG.i(f"Found {name}{version} at {output}.")
|
||||
return output
|
||||
|
||||
def env_or_find(name, bname):
|
||||
def env_or_find(name, bname, silent=False):
|
||||
if name in os.environ:
|
||||
path = os.environ[name]
|
||||
if path != "":
|
||||
version = get_version(path)
|
||||
LOG.i(f"Found {bname}{version} at {path}")
|
||||
if not silent:
|
||||
LOG.i(f"Found {bname}{version} at {path}")
|
||||
return path
|
||||
return find_exe(bname)
|
||||
return find_exe(bname, silent=silent)
|
||||
|
||||
def get_cc_type(cc_path):
|
||||
bname = os.path.basename(cc_path)
|
||||
|
@ -271,7 +273,25 @@ is_in_ipynb = in_ipynb()
|
|||
cc = None
|
||||
LOG = LogWarper()
|
||||
|
||||
cc_path = env_or_find('cc_path', 'g++')
|
||||
cc_path = env_or_find('cc_path', 'g++', silent=True)
|
||||
os.environ["cc_path"] = cc_path
|
||||
cc_type = get_cc_type(cc_path)
|
||||
cache_path = find_cache_path()
|
||||
|
||||
|
||||
py3_config_paths = [
|
||||
os.path.dirname(sys.executable) + f"/python3.{sys.version_info.minor}-config",
|
||||
sys.executable + "-config",
|
||||
f"/usr/bin/python3.{sys.version_info.minor}-config",
|
||||
os.path.dirname(sys.executable) + "/python3-config",
|
||||
]
|
||||
if "python_config_path" in os.environ:
|
||||
py3_config_paths.insert(0, os.environ["python_config_path"])
|
||||
|
||||
for py3_config_path in py3_config_paths:
|
||||
if os.path.isfile(py3_config_path):
|
||||
break
|
||||
else:
|
||||
raise RuntimeError(f"python3.{sys.version_info.minor}-config "
|
||||
"not found in {py3_config_paths}, please specify "
|
||||
"enviroment variable 'python_config_path'")
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
import os
|
||||
import sys
|
||||
import jittor_utils
|
||||
from jittor_utils import LOG
|
||||
|
||||
|
||||
def search_file(dirs, name):
|
||||
for d in dirs:
|
||||
fname = os.path.join(d, name)
|
||||
if os.path.isfile(fname):
|
||||
return fname
|
||||
LOG.f(f"file {name} not found in {dirs}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
help_msg = f"Usage: {sys.executable} -m jittor_utils.config --include-flags|--link-flags|--cxx-flags|--cxx-example|--help"
|
||||
if len(sys.argv) <= 1:
|
||||
print(help_msg)
|
||||
sys.exit(1)
|
||||
|
||||
s = ""
|
||||
# base should be something like python3.7m python3.8
|
||||
base = jittor_utils.run_cmd(jittor_utils.py3_config_path + " --includes").split()[0]
|
||||
base = "python3" + base.split("python3")[-1]
|
||||
for arg in sys.argv[1:]:
|
||||
if arg == "--include-flags":
|
||||
s += jittor_utils.run_cmd(jittor_utils.py3_config_path + " --includes")
|
||||
s += " -I"+os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "jittor", "src"))
|
||||
s += " "
|
||||
elif arg == "--libs-flags":
|
||||
libbase = "/usr/lib/x86_64-linux-gnu"
|
||||
libpath = libbase + f"/lib{base}.so"
|
||||
assert os.path.isfile(libpath), f"lib not exist: {libpath}"
|
||||
s += f" -L{libbase} -l{base} -ldl "
|
||||
elif arg == "--cxx-flags":
|
||||
s += " --std=c++17 "
|
||||
elif arg == "--cxx-example":
|
||||
cc_src = '''
|
||||
// please compile with: g++ a.cc $(python3 -m jittor_utils.config --include-flags --libs-flags --cxx-flags) -o a.out && ./a.out
|
||||
#include <pyjt/pyjt_console.h>
|
||||
#include <iostream>
|
||||
|
||||
using namespace std;
|
||||
|
||||
int main() {
|
||||
jittor::Console console;
|
||||
// run python code in console
|
||||
console.run("print('hello jt console', flush=True)");
|
||||
|
||||
// set a python value: a = 1
|
||||
console.set<int>("a", 1);
|
||||
// get a python value
|
||||
cout << console.get<int>("a") << endl;
|
||||
|
||||
// set a python string
|
||||
console.set<string>("b", "hello");
|
||||
cout << console.get<string>("b") << endl;
|
||||
|
||||
// set a python array
|
||||
vector<int> x{1,2,3,4};
|
||||
console.set("x", x);
|
||||
auto x2 = console.get<std::vector<int>>("x");
|
||||
for (auto a : x2) cout << a << " "; cout << endl;
|
||||
|
||||
// set and get a jittor array
|
||||
jittor::array<int, 2> arr2({2,3}, {6,5,4,3,2,1});
|
||||
arr2(0,0) = -1;
|
||||
console.set_array("arr2", arr2);
|
||||
console.run("print(arr2, flush=True); arr3 = arr2**2;");
|
||||
auto arr3 = console.get_array<int, 2>("arr3");
|
||||
cout << arr3.shape[0] << ' ' << arr3.shape[1] << endl;
|
||||
for (int i=0; i<arr3.shape[0]; i++) {
|
||||
for (int j=0; j<arr3.shape[1]; j++)
|
||||
cout << arr3(i,j) << ' ';
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
// run resnet18
|
||||
jittor::array<float, 4> input({2, 3, 224, 224});
|
||||
memset(input.data.get(), 0, input.nbyte());
|
||||
console.set_array("input", input);
|
||||
console.run(R"(
|
||||
import jittor as jt
|
||||
from jittor.models import resnet
|
||||
|
||||
model = resnet.resnet18()
|
||||
pred = model(input)
|
||||
)");
|
||||
auto pred = console.get_array<float, 2>("pred");
|
||||
cout << "pred.shape " << pred.shape[0] << ' ' << pred.shape[1] << endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
'''
|
||||
print(cc_src)
|
||||
elif arg == "--help":
|
||||
print(help_msg)
|
||||
sys.exit(0)
|
||||
else:
|
||||
print(help_msg)
|
||||
sys.exit(1)
|
||||
print(s)
|
|
@ -112,7 +112,7 @@ static int not_compile_window = 0;
|
|||
|
||||
void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, int64 tt) {
|
||||
// jit_search_kernel require compile at runtime
|
||||
if (jit_search_kernel || !use_parallel_op_compiler || not_compile_window > 1000)
|
||||
if (jit_search_kernel || !use_parallel_op_compiler || not_compile_window > 100000)
|
||||
return;
|
||||
|
||||
// try not use parallel compile if no op needs compile
|
||||
|
|
|
@ -17,9 +17,31 @@
|
|||
#include "pyjt/numpy.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "var.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "var_holder.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
static auto make_array = get_op_info("array")
|
||||
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
|
||||
|
||||
PyObject* make_pyjt_array(const vector<int64>& shape, const string& dtype, const void* data) {
|
||||
// return nullptr;
|
||||
auto vh = new VarHolder(make_array(data, shape, dtype));
|
||||
return to_py_object<VarHolder*>(vh);
|
||||
}
|
||||
|
||||
void get_pyjt_array(PyObject* obj, vector<int64>& shape, string& dtype, void*& data) {
|
||||
CHECK(Py_TYPE(obj) == &PyjtVarHolder.ht_type) << "Not a jittor array" << Py_TYPE(obj);
|
||||
auto vh = GET_RAW_PTR(VarHolder, obj);
|
||||
if (!vh->var->mem_ptr)
|
||||
vh->sync();
|
||||
ASSERT(vh->var->mem_ptr);
|
||||
shape = vh->shape().to_vector();
|
||||
dtype = vh->dtype().to_cstring();
|
||||
data = vh->var->mem_ptr;
|
||||
}
|
||||
|
||||
ArrayOp::ArrayOp(PyObject* obj) {
|
||||
ArrayArgs args;
|
||||
PyObjHolder holder;
|
||||
|
|
|
@ -14,7 +14,7 @@ struct PyObjHolder {
|
|||
PyObject* obj;
|
||||
inline PyObjHolder() : obj(nullptr) {
|
||||
}
|
||||
void assign(PyObject* obj) {
|
||||
inline void assign(PyObject* obj) {
|
||||
if (!obj) {
|
||||
LOGf << "Python error occur";
|
||||
}
|
||||
|
|
|
@ -0,0 +1,485 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <Python.h>
|
||||
#include <dlfcn.h>
|
||||
#include <vector>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <memory>
|
||||
#include <iostream>
|
||||
|
||||
namespace jittor {
|
||||
|
||||
typedef int8_t int8;
|
||||
typedef int16_t int16;
|
||||
typedef int int32;
|
||||
typedef int64_t int64;
|
||||
typedef uint8_t uint8;
|
||||
typedef uint16_t uint16;
|
||||
typedef uint32_t uint32;
|
||||
typedef uint64_t uint64;
|
||||
typedef float float32;
|
||||
typedef double float64;
|
||||
typedef uint32_t uint;
|
||||
|
||||
using string = std::string;
|
||||
using std::move;
|
||||
template <class T> using vector = std::vector<T>;
|
||||
template <class T> using list = std::list<T>;
|
||||
template <class T> using set = std::set<T>;
|
||||
template <class T> using shared_ptr = std::shared_ptr<T>;
|
||||
template <class T> using unique_ptr = std::unique_ptr<T>;
|
||||
template <class T> using unordered_set = std::unordered_set<T>;
|
||||
template <class Ta, class Tb> using pair = std::pair<Ta,Tb>;
|
||||
template <class Ta, class Tb> using map = std::map<Ta,Tb>;
|
||||
template <class Ta, class Tb> using unordered_map = std::unordered_map<Ta,Tb>;
|
||||
|
||||
#define JT_CHECK(cond) \
|
||||
if (!(cond)) throw std::runtime_error("JT_CHECK failed: " #cond " ");
|
||||
|
||||
struct PyObjHolder {
|
||||
|
||||
PyObject* obj;
|
||||
inline PyObjHolder() : obj(nullptr) {
|
||||
}
|
||||
inline void assign(PyObject* obj) {
|
||||
if (!obj) {
|
||||
PyErr_Print();
|
||||
throw std::runtime_error("Python Error Occurred.");
|
||||
}
|
||||
this->obj = obj;
|
||||
}
|
||||
inline PyObjHolder(PyObject* obj) : obj(obj) {
|
||||
if (!obj) {
|
||||
PyErr_Print();
|
||||
throw std::runtime_error("Python Error Occurred.");
|
||||
}
|
||||
}
|
||||
inline ~PyObjHolder() {
|
||||
if (obj) Py_DECREF(obj);
|
||||
}
|
||||
inline PyObject* release() {
|
||||
auto tmp = obj;
|
||||
obj = nullptr;
|
||||
return tmp;
|
||||
}
|
||||
|
||||
inline void free() {
|
||||
if (obj) Py_DECREF(obj);
|
||||
obj = nullptr;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, PyObjHolder& objp) {
|
||||
PyObjHolder repr_obj(PyObject_Repr(objp.obj));
|
||||
|
||||
if (PyUnicode_CheckExact(repr_obj.obj)) {
|
||||
return os << Py_TYPE(objp.obj)->tp_name << ' ' <<
|
||||
PyUnicode_AsUTF8(repr_obj.obj);
|
||||
} else {
|
||||
return os << "unknown(" << (void*)objp.obj << ")";
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#define DEF_IS(check_type, return_type) \
|
||||
template<class T> \
|
||||
typename std::enable_if<std::is_same<T, check_type>::value, return_type>::type
|
||||
|
||||
#define GET_PY_NONE(code) ((code), Py_INCREF(Py_None), Py_None)
|
||||
|
||||
// string
|
||||
DEF_IS(string, bool) is_type(PyObject* obj) {
|
||||
return PyUnicode_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS(string, PyObject*) to_py_object(const string& a) {
|
||||
return PyUnicode_FromStringAndSize(a.c_str(), a.size());
|
||||
}
|
||||
|
||||
DEF_IS(string, string) from_py_object(PyObject* obj) {
|
||||
Py_ssize_t size;
|
||||
const char* s = PyUnicode_AsUTF8AndSize(obj, &size);
|
||||
JT_CHECK(s);
|
||||
return string(s, size);
|
||||
}
|
||||
|
||||
|
||||
// size_t
|
||||
DEF_IS(size_t, bool) is_type(PyObject* obj) {
|
||||
return PyLong_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS(size_t, PyObject*) to_py_object(const T& a) {
|
||||
return PyLong_FromUnsignedLongLong(a);
|
||||
}
|
||||
|
||||
DEF_IS(size_t, T) from_py_object(PyObject* obj) {
|
||||
return PyLong_AsUnsignedLongLong(obj);
|
||||
}
|
||||
|
||||
DEF_IS(uint32, bool) is_type(PyObject* obj) {
|
||||
return PyLong_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS(uint32, PyObject*) to_py_object(const T& a) {
|
||||
return PyLong_FromUnsignedLong(a);
|
||||
}
|
||||
|
||||
DEF_IS(uint32, T) from_py_object(PyObject* obj) {
|
||||
return PyLong_AsUnsignedLong(obj);
|
||||
}
|
||||
|
||||
// int64
|
||||
DEF_IS(int64, bool) is_type(PyObject* obj) {
|
||||
return PyLong_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS(int64, PyObject*) to_py_object(const T& a) {
|
||||
return PyLong_FromLongLong(a);
|
||||
}
|
||||
|
||||
DEF_IS(int64, T) from_py_object(PyObject* obj) {
|
||||
return PyLong_AsLongLong(obj);
|
||||
}
|
||||
DEF_IS(int32, bool) is_type(PyObject* obj) {
|
||||
return PyLong_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS(int32, PyObject*) to_py_object(const T& a) {
|
||||
return PyLong_FromLong(a);
|
||||
}
|
||||
|
||||
DEF_IS(int32, T) from_py_object(PyObject* obj) {
|
||||
return PyLong_AsLong(obj);
|
||||
}
|
||||
|
||||
// float64
|
||||
DEF_IS(float64, bool) is_type(PyObject* obj) {
|
||||
return PyFloat_CheckExact(obj) || PyLong_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS(float64, PyObject*) to_py_object(const T& a) {
|
||||
return PyFloat_FromDouble(a);
|
||||
}
|
||||
|
||||
DEF_IS(float64, T) from_py_object(PyObject* obj) {
|
||||
if (PyFloat_CheckExact(obj))
|
||||
return PyFloat_AS_DOUBLE(obj);
|
||||
return PyLong_AsDouble(obj);
|
||||
}
|
||||
DEF_IS(float32, bool) is_type(PyObject* obj) {
|
||||
return PyFloat_CheckExact(obj) || PyLong_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS(float32, PyObject*) to_py_object(const T& a) {
|
||||
return PyFloat_FromFloat(a);
|
||||
}
|
||||
|
||||
DEF_IS(float32, T) from_py_object(PyObject* obj) {
|
||||
if (PyFloat_CheckExact(obj))
|
||||
return PyFloat_AS_DOUBLE(obj);
|
||||
return PyFloat_AS_DOUBLE(obj);
|
||||
}
|
||||
|
||||
|
||||
#define CHECK_IS_1(check_type) \
|
||||
template<typename T> struct is_##check_type : public std::false_type {}; \
|
||||
template<typename T> \
|
||||
struct is_##check_type<check_type<T>> : public std::true_type {};
|
||||
|
||||
#define DEF_IS_1(check_type, return_type) \
|
||||
template<class T> \
|
||||
typename std::enable_if<is_##check_type<T>::value, return_type>::type
|
||||
|
||||
CHECK_IS_1(vector);
|
||||
|
||||
DEF_IS_1(vector, bool) is_type(PyObject* obj) {
|
||||
if (!(PyList_CheckExact(obj) || PyTuple_CheckExact(obj)))
|
||||
return false;
|
||||
auto size = Py_SIZE(obj);
|
||||
if (!size)
|
||||
return true;
|
||||
auto arr = PySequence_Fast_ITEMS(obj);
|
||||
return is_type<typename T::value_type>(arr[0]);
|
||||
}
|
||||
|
||||
DEF_IS_1(vector, PyObject*) to_py_object(const T& a) {
|
||||
PyObjHolder list(PyList_New(a.size()));
|
||||
for (uint i=0; i<a.size(); i++) {
|
||||
PyObject* o = to_py_object<typename T::value_type>(a[i]);
|
||||
JT_CHECK(o);
|
||||
// PyList_SET_ITEM borrow ownership, we do not hold this
|
||||
PyList_SET_ITEM(list.obj, i, o);
|
||||
}
|
||||
return list.release();
|
||||
}
|
||||
|
||||
DEF_IS_1(vector, PyObject*) to_py_tuple(const T& a) {
|
||||
PyObjHolder list(PyTuple_New(a.size()));
|
||||
for (uint i=0; i<a.size(); i++) {
|
||||
PyObject* o = to_py_object<typename T::value_type>(a[i]);
|
||||
JT_CHECK(o);
|
||||
// PyTuple_SET_ITEM borrow ownership, we do not hold this
|
||||
PyTuple_SET_ITEM(list.obj, i, o);
|
||||
}
|
||||
return list.release();
|
||||
}
|
||||
|
||||
DEF_IS_1(vector, PyObject*) to_py_object(T&& a) {
|
||||
PyObjHolder list(PyList_New(a.size()));
|
||||
for (uint i=0; i<a.size(); i++) {
|
||||
PyObject* o = to_py_object<typename T::value_type>(std::move(a[i]));
|
||||
JT_CHECK(o);
|
||||
// PyList_SET_ITEM borrow ownership, we do not hold this
|
||||
PyList_SET_ITEM(list.obj, i, o);
|
||||
}
|
||||
return list.release();
|
||||
}
|
||||
|
||||
DEF_IS_1(vector, T) from_py_object(PyObject* obj) {
|
||||
auto size = Py_SIZE(obj);
|
||||
T a(size);
|
||||
auto arr = PySequence_Fast_ITEMS(obj);
|
||||
for (int64 i=0; i<size; i++) {
|
||||
auto oi = arr[i];
|
||||
JT_CHECK(is_type<typename T::value_type>(oi));
|
||||
a[i] = from_py_object<typename T::value_type>(oi);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
|
||||
#define CHECK_IS_2(check_type) \
|
||||
template<typename T> struct is_##check_type : public std::false_type {}; \
|
||||
template<typename Ta, typename Tb> \
|
||||
struct is_##check_type<check_type<Ta, Tb>> : public std::true_type {};
|
||||
|
||||
#define DEF_IS_2(check_type, return_type) \
|
||||
template<class T> \
|
||||
typename std::enable_if<is_##check_type<T>::value, return_type>::type
|
||||
|
||||
CHECK_IS_2(unordered_map);
|
||||
|
||||
DEF_IS_2(unordered_map, bool) is_type(PyObject* obj) {
|
||||
return PyDict_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS_2(unordered_map, PyObject*) to_py_object(const T& a) {
|
||||
PyObjHolder dict(PyDict_New());
|
||||
for (const auto& kv : a) {
|
||||
PyObjHolder key(to_py_object<typename T::key_type>(kv.first));
|
||||
PyObjHolder value(to_py_object<typename T::mapped_type>(kv.second));
|
||||
PyDict_SetItem(dict.obj, key.obj, value.obj);
|
||||
}
|
||||
return dict.release();
|
||||
}
|
||||
|
||||
DEF_IS_2(unordered_map, T) from_py_object(PyObject* obj) {
|
||||
auto size = Py_SIZE(obj);
|
||||
T a;
|
||||
a.reserve(size);
|
||||
PyObject *key, *value;
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next(obj, &pos, &key, &value)) {
|
||||
JT_CHECK(is_type<typename T::key_type>(key)
|
||||
&& is_type<typename T::mapped_type>(value));
|
||||
a.emplace(
|
||||
from_py_object<typename T::key_type>(key),
|
||||
from_py_object<typename T::mapped_type>(value)
|
||||
);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
// copy from unordered_map
|
||||
CHECK_IS_2(map);
|
||||
|
||||
DEF_IS_2(map, bool) is_type(PyObject* obj) {
|
||||
return PyDict_CheckExact(obj);
|
||||
}
|
||||
|
||||
DEF_IS_2(map, PyObject*) to_py_object(const T& a) {
|
||||
PyObjHolder dict(PyDict_New());
|
||||
for (const auto& kv : a) {
|
||||
PyObjHolder key(to_py_object<typename T::key_type>(kv.first));
|
||||
PyObjHolder value(to_py_object<typename T::mapped_type>(kv.second));
|
||||
PyDict_SetItem(dict.obj, key.obj, value.obj);
|
||||
}
|
||||
return dict.release();
|
||||
}
|
||||
|
||||
DEF_IS_2(map, T) from_py_object(PyObject* obj) {
|
||||
T a;
|
||||
PyObject *key, *value;
|
||||
Py_ssize_t pos = 0;
|
||||
while (PyDict_Next(obj, &pos, &key, &value)) {
|
||||
JT_CHECK(is_type<typename T::key_type>(key)
|
||||
&& is_type<typename T::mapped_type>(value));
|
||||
a.emplace(
|
||||
from_py_object<typename T::key_type>(key),
|
||||
from_py_object<typename T::mapped_type>(value)
|
||||
);
|
||||
}
|
||||
return a;
|
||||
}
|
||||
|
||||
template<class T, int N>
|
||||
struct array {
|
||||
|
||||
typedef T _type;
|
||||
static constexpr int _ndim = N;
|
||||
|
||||
int64 shape[N];
|
||||
unique_ptr<T[]> data;
|
||||
|
||||
inline bool is_float() const { return std::is_floating_point<T>::value; }
|
||||
inline bool is_unsigned() const { return std::is_unsigned<T>::value; }
|
||||
inline int64 size() const {
|
||||
int64 s=1;
|
||||
for (auto x : shape) s *= x;
|
||||
return s;
|
||||
}
|
||||
inline int64 nbyte() const { return size()*sizeof(T); }
|
||||
inline string dtype() const {
|
||||
return DTYPE();
|
||||
}
|
||||
inline int ndim() const { return N; }
|
||||
|
||||
inline static string DTYPE() {
|
||||
string dtype(std::is_floating_point<T>::value ? "float" :
|
||||
std::is_unsigned<T>::value ? "uint" : "int");
|
||||
if (sizeof(T)==1) dtype += "8"; else
|
||||
if (sizeof(T)==2) dtype += "16"; else
|
||||
if (sizeof(T)==4) dtype += "32"; else
|
||||
if (sizeof(T)==8) dtype += "64"; else
|
||||
throw std::runtime_error("Not support type");
|
||||
return dtype;
|
||||
}
|
||||
|
||||
inline array(const vector<int64>& shape) {
|
||||
if (shape.size() != N) throw std::runtime_error("Dim not match");
|
||||
for (int i=0; i<N; i++) this->shape[i] = shape[i];
|
||||
data.reset(new T[size()]);
|
||||
}
|
||||
|
||||
inline array(const vector<int64>& shape, const T* data) : array(shape) {
|
||||
memcpy(this->data.get(), data, nbyte());
|
||||
}
|
||||
|
||||
inline array(const vector<int64>& shape, const vector<T>& data) : array(shape, &data[0]) {
|
||||
}
|
||||
|
||||
template<int I, class Ti, typename... Targs>
|
||||
inline int64 get_offset(int64 offset, Ti i, Targs... Fargs) {
|
||||
if constexpr (I+1==N)
|
||||
return offset*shape[I]+i;
|
||||
else
|
||||
return get_offset<I+1>(offset*shape[I]+i, Fargs...);
|
||||
}
|
||||
|
||||
template<typename... Targs>
|
||||
T& operator()(Targs... Fargs) {
|
||||
return data[get_offset<0>(0, Fargs...)];
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct Console {
|
||||
|
||||
PyObjHolder globals, locals;
|
||||
PyObject* (*make_pyjt_array)(const vector<int64>& shape, const string& dtype, const void* data);
|
||||
void (*get_pyjt_array)(PyObject* obj, vector<int64>& shape, string& dtype, void*& data);
|
||||
|
||||
inline Console() {
|
||||
Py_Initialize();
|
||||
globals.assign(PyDict_New());
|
||||
locals.assign(PyDict_New());
|
||||
|
||||
#if PY_VERSION_HEX < 0x03080000
|
||||
PyObjHolder builtins(PyImport_ImportModule("builtins"));
|
||||
PyDict_SetItemString(globals.obj, "__builtins__", builtins.obj);
|
||||
#endif
|
||||
|
||||
run("import jittor as jt");
|
||||
make_pyjt_array = (PyObject* (*)(const vector<int64>& shape, const string& dtype, const void* data))dlsym(RTLD_DEFAULT, "_ZN6jittor15make_pyjt_arrayERKSt6vectorIlSaIlEERKNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEEPKv");
|
||||
get_pyjt_array = (void (*)(PyObject* obj, vector<int64>& shape, string& dtype, void*& data))dlsym(RTLD_DEFAULT, "_ZN6jittor14get_pyjt_arrayEP7_objectRSt6vectorIlSaIlEERNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEEERPv");
|
||||
}
|
||||
|
||||
inline ~Console() {
|
||||
globals.free();
|
||||
locals.free();
|
||||
Py_FinalizeEx();
|
||||
}
|
||||
|
||||
inline void run(const char* src) {
|
||||
PyObjHolder ret(PyRun_String(src, Py_file_input, globals.obj, locals.obj));
|
||||
}
|
||||
|
||||
inline void run(const string& src) { run(src.c_str()); }
|
||||
|
||||
template<class T>
|
||||
inline void set(const char* s, const T& data) {
|
||||
PyObjHolder py_data(to_py_object<T>(data));
|
||||
PyDict_SetItemString(locals.obj, s, py_data.obj);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline void set(const string& s, const T& data) {
|
||||
set(s.c_str(), data);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline T get(const char* s) {
|
||||
auto obj = PyDict_GetItemString(locals.obj, s);
|
||||
if (!obj) obj = PyDict_GetItemString(globals.obj, s);
|
||||
if (!obj) throw std::runtime_error(string("KeyError: ")+s);
|
||||
if (!is_type<T>(obj)) throw std::runtime_error(string("TypeError: key<")+s+"> is "+Py_TYPE(obj)->tp_name);
|
||||
return from_py_object<T>(obj);
|
||||
};
|
||||
|
||||
template<class T>
|
||||
inline T get(const string& s) {
|
||||
return get<T>(s.c_str());
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T, int N>
|
||||
inline void set_array(const string& s, const array<T,N>& data) {
|
||||
PyObjHolder obj(make_pyjt_array(
|
||||
vector<int64>(data.shape, data.shape+N),
|
||||
data.dtype(),
|
||||
data.data.get()));
|
||||
PyDict_SetItemString(locals.obj, s.c_str(), obj.obj);
|
||||
}
|
||||
|
||||
template<class T, int N>
|
||||
inline array<T,N> get_array(const string& s) {
|
||||
auto obj = PyDict_GetItemString(locals.obj, s.c_str());
|
||||
if (!obj) obj = PyDict_GetItemString(globals.obj, s.c_str());
|
||||
if (!obj) throw std::runtime_error(string("KeyError: ")+s);
|
||||
vector<int64> shape;
|
||||
string dtype;
|
||||
void* data;
|
||||
get_pyjt_array(obj, shape, dtype, data);
|
||||
string dtype2 = array<T,N>::DTYPE();
|
||||
if (dtype2 != dtype)
|
||||
throw new std::runtime_error(string("array dtype not match: ")+dtype+"!="+dtype2);
|
||||
if (shape.size() != N)
|
||||
throw new std::runtime_error(string("array ndim not match: ")+std::to_string(shape.size())+"!="+std::to_string(N));
|
||||
return array<T, N>(shape, (T*)data);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
|
@ -304,8 +304,7 @@ int system_popen(const char* cmd) {
|
|||
FILE *ptr = popen(cmd2.c_str(), "r");
|
||||
if (!ptr) return -1;
|
||||
while (fgets(buf, BUFSIZ, ptr) != NULL) {
|
||||
std::cout << buf;
|
||||
std::cout.flush();
|
||||
puts(buf);
|
||||
}
|
||||
return pclose(ptr);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue