mirror of https://github.com/Jittor/Jittor
185 lines
4.6 KiB
C++
185 lines
4.6 KiB
C++
// ***************************************************************
|
|
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
|
// This file is subject to the terms and conditions defined in
|
|
// file 'LICENSE.txt', which is part of this source code package.
|
|
// ***************************************************************
|
|
#pragma once
|
|
#include "common.h"
|
|
|
|
namespace jittor {
|
|
|
|
|
|
// X-macro listing every built-in NanoString name.
// FOR_ALL_NS(m) expands to m(name) once per entry, in the order written
// below, so callers can generate one declaration/definition/table row per
// name. Entries are kept in four visually separated groups (scalar type
// names first, then operator names); the order is part of the contract
// because generated tables follow it.
// Note: the final entry deliberately has no trailing line continuation, so
// the macro body ends on that line regardless of what follows it.
#define FOR_ALL_NS(m) \
    \
    m(void) \
    m(bool) \
    m(int8) \
    m(int16) \
    m(int32) \
    m(int64) \
    m(uint8) \
    m(uint16) \
    m(uint32) \
    m(uint64) \
    m(float32) \
    m(float64) \
    \
    m(pow) \
    m(maximum) \
    m(minimum) \
    m(add) \
    m(subtract) \
    m(multiply) \
    m(divide) \
    m(floor_divide) \
    m(mod) \
    m(less) \
    m(less_equal) \
    m(greater) \
    m(greater_equal) \
    m(equal) \
    m(not_equal) \
    m(left_shift) \
    m(right_shift) \
    m(logical_and) \
    m(logical_or) \
    m(logical_xor) \
    m(bitwise_and) \
    m(bitwise_or) \
    m(bitwise_xor) \
    m(mean) \
    \
    m(abs) \
    m(negative) \
    m(logical_not) \
    m(bitwise_not) \
    m(log) \
    m(exp) \
    m(sqrt) \
    m(round) \
    m(floor) \
    m(ceil) \
    m(cast) \
    \
    m(sin) \
    m(asin) \
    m(sinh) \
    m(asinh) \
    m(tan) \
    m(atan) \
    m(tanh) \
    m(atanh) \
    m(cos) \
    m(acos) \
    m(cosh) \
    m(acosh) \
    m(sigmoid)
|
|
|
|
struct NanoString;
|
|
#define DECLEAR_NS(T) extern NanoString ns_##T;
|
|
FOR_ALL_NS(DECLEAR_NS);
|
|
|
|
// @pyjt(NanoString)
|
|
struct NanoString {
|
|
typedef uint16 ns_t;
|
|
enum Flags {
|
|
// bit0~7: index
|
|
_index=0, _index_nbits=8,
|
|
_n=_index_nbits,
|
|
|
|
// bit0-1: type
|
|
_type=_n, _type_nbits=2,
|
|
_other=0, _dtype=1, _unary=2, _binary=3,
|
|
// bit2: is bool
|
|
_bool=_n+2,
|
|
// bit3: is int
|
|
_int=_n+3,
|
|
// bit4: is unsigned
|
|
_unsigned=_n+4,
|
|
// bit5: is float
|
|
_float=_n+5,
|
|
// bit6-7: dsize(1,2,4,8 byte)
|
|
_dsize=_n+6, _dsize_nbits=2,
|
|
};
|
|
ns_t data=0;
|
|
|
|
static unordered_map<string, NanoString> __string_to_ns;
|
|
static vector<const char*> __ns_to_string;
|
|
|
|
inline void set(Flags f, ns_t a=1, ns_t nbits=1) {
|
|
ns_t mask = (((1u<<nbits)-1)<<f);
|
|
data = (data & ~mask) | ((a<<f)&mask);
|
|
}
|
|
|
|
inline ns_t get(Flags f, ns_t nbits=1) const {
|
|
return (data>>f) & ((1u<<nbits)-1);
|
|
}
|
|
inline ns_t index() const { return get(_index, _index_nbits); }
|
|
inline ns_t type() const { return get(_type, _type_nbits); }
|
|
inline ns_t is_bool() const { return get(_bool); }
|
|
inline ns_t is_int() const { return get(_int); }
|
|
inline ns_t is_unsigned() const { return get(_unsigned); }
|
|
inline ns_t is_float() const { return get(_float); }
|
|
inline ns_t dsize() const { return 1<<get(_dsize, _dsize_nbits); }
|
|
inline ns_t is_dtype() const { return get(_type, _type_nbits)==_dtype; }
|
|
inline ns_t is_binary() const { return get(_type, _type_nbits)==_binary; }
|
|
inline ns_t is_unary() const { return get(_type, _type_nbits)==_unary; }
|
|
|
|
inline NanoString() {}
|
|
// @pyjt(__init__)
|
|
inline NanoString(const char* s) {
|
|
auto iter = __string_to_ns.find(s);
|
|
ASSERT(iter != __string_to_ns.end()) << s;
|
|
data = iter->second.data;
|
|
}
|
|
// @pyjt(__init__)
|
|
inline NanoString(const NanoString& other) : data(other.data) {}
|
|
inline NanoString(const string& s) : NanoString(s.c_str()) {}
|
|
// @pyjt(__repr__)
|
|
inline const char* to_cstring() const
|
|
{ return __ns_to_string[index()]; }
|
|
};
|
|
|
|
// force_type = 1 for int, 2 for float
|
|
inline
|
|
NanoString dtype_infer(NanoString v1, NanoString v2, int force_type=0) {
|
|
bool is_float = v1.is_float() || v2.is_float();
|
|
int dsize = std::max(v1.dsize(), v2.dsize());
|
|
if (force_type == 1)
|
|
is_float = false;
|
|
else if (force_type == 2)
|
|
is_float = true;
|
|
if (is_float) {
|
|
if (dsize==4) return ns_float32;
|
|
return ns_float64;
|
|
} else {
|
|
if (dsize==8) return ns_int64;
|
|
if (dsize==4) return ns_int32;
|
|
if (dsize==2) return ns_int16;
|
|
return ns_int8;
|
|
}
|
|
}
|
|
|
|
// @pyjt(NanoString.__eq__)
|
|
inline bool eq(const NanoString& a, const NanoString& b) {
|
|
return a.data == b.data;
|
|
}
|
|
|
|
// @pyjt(NanoString.__ne__)
|
|
inline bool ne(const NanoString& a, const NanoString& b) {
|
|
return a.data != b.data;
|
|
}
|
|
|
|
// Two NanoStrings compare equal exactly when their packed data is equal.
inline bool operator==(const NanoString& a, const NanoString& b) {
    const bool identical = (a.data == b.data);
    return identical;
}
|
|
// Negation of the equality test: true when the packed data differs.
inline bool operator!=(const NanoString& a, const NanoString& b) {
    return !(a.data == b.data);
}
|
|
|
|
inline std::ostream& operator<<(std::ostream& os, const NanoString& v) {
|
|
return os << v.to_cstring();
|
|
}
|
|
|
|
}
|