# Mirror of https://github.com/Jittor/Jittor (Python, 110 lines, 2.8 KiB)
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import os
import unittest

import numpy as np

import jittor as jt
from .test_core import expect_error

# C++ header for the file-based custom-op test. It declares ``CustomOp``, an
# Op owning a single output Var that is built from a shape and a dtype
# (``ns_float`` by default) and registered under the name "custom".
# ``DECLARE_jit_run`` is presumably a macro from jittor's "op.h" that declares
# ``jit_prepare``/``jit_run``; both are defined in ``src`` below.
header = """
#pragma once
#include "op.h"

namespace jittor {

struct CustomOp : Op {
    Var* output;
    CustomOp(NanoVector shape, NanoString dtype=ns_float);

    const char* name() const override { return "custom"; }
    DECLARE_jit_run;
};

} // jittor
"""

# C++ implementation matching ``header``.
# - Outside JIT (``#ifndef JIT``): the constructor creates the output Var, and
#   jit_prepare registers the output dtype as the JIT define ``T``.
# - JIT with ``JIT_cpu``: jit_run fills the flattened output with 0, 1, ..., n-1.
# - Any other JIT build: jit_run is an empty stub that leaves the output
#   unwritten, so the value check in the test presumably assumes CPU execution.
src = """
#include "var.h"
#include "custom_op.h"

namespace jittor {
#ifndef JIT
CustomOp::CustomOp(NanoVector shape, NanoString dtype) {
    output = create_output(shape, dtype);
}

void CustomOp::jit_prepare() {
    add_jit_define("T", output->dtype());
}

#else // JIT
#ifdef JIT_cpu
void CustomOp::jit_run() {
    index_t num = output->num;
    auto* __restrict__ x = output->ptr<T>();
    for (index_t i=0; i<num; i++)
        x[i] = (T)i;
}
#else
void CustomOp::jit_run() {
}
#endif // JIT_cpu
#endif // JIT

} // jittor
"""

class TestCustomOp(unittest.TestCase):
    """Build user-defined C++ ops at runtime and check what they compute.

    Two entry points are exercised:

    * ``jt.compile_custom_ops`` on header/source files written to
      ``jt.flags.cache_path``;
    * ``jt.compile_custom_op`` on inline header/source strings plus an op name.
    """

    def test_compile_custom_ops(self):
        """Compile ``CustomOp`` from files; its CPU kernel writes 0..n-1."""
        tmp_path = jt.flags.cache_path
        # os.path.join rather than "+ '/'" so the separator is always correct.
        hname = os.path.join(tmp_path, "custom_op.h")
        ccname = os.path.join(tmp_path, "custom_op.cc")
        with open(hname, "w") as f:
            f.write(header)
        with open(ccname, "w") as f:
            f.write(src)
        cops = jt.compile_custom_ops([hname, ccname])
        a = cops.custom([3, 4, 5], 'float')
        na = a.data
        # unittest assertions, unlike bare ``assert``, still run under -O.
        self.assertEqual(a.shape, [3, 4, 5])
        self.assertEqual(a.dtype, 'float')
        np.testing.assert_array_equal(na.flatten(), np.arange(3 * 4 * 5))

    def test_compile_custom_op(self):
        """Compile ``MyOp`` from inline strings; its kernel writes 0..-(n-1)."""
        my_header = """
        struct MyOp : Op {
            Var* output;
            MyOp(NanoVector shape, NanoString dtype=ns_float);

            const char* name() const override { return "my"; }
            DECLARE_jit_run;
        };
        """
        # Unlike the file-based fixture, this jit_run has no JIT_cpu guard;
        # NOTE(review): it presumably assumes CPU execution.
        my_src = """
        #ifndef JIT
        MyOp::MyOp(NanoVector shape, NanoString dtype) {
            output = create_output(shape, dtype);
        }

        void MyOp::jit_prepare() {
            add_jit_define("T", output->dtype());
        }

        #else // JIT
        void MyOp::jit_run() {
            index_t num = output->num;
            auto* __restrict__ x = output->ptr<T>();
            for (index_t i=0; i<num; i++)
                x[i] = (T)-i;
        }
        #endif // JIT
        """
        my_op = jt.compile_custom_op(my_header, my_src, "my")
        a = my_op([3, 4, 5], 'float')
        na = a.data
        self.assertEqual(a.shape, [3, 4, 5])
        self.assertEqual(a.dtype, 'float')
        # The kernel stores -i, so negating the result recovers 0..n-1.
        np.testing.assert_array_equal(-na.flatten(), np.arange(3 * 4 * 5))

if __name__ == "__main__":
    # Executed directly as a script rather than imported by a runner:
    # hand control to unittest's command-line entry point for this module.
    unittest.main()