madd bf16

Dun Liang 2023-09-08 22:43:31 +08:00 committed by lzhengning
parent 2457c5ce4a
commit c485fdc07b
36 changed files with 704 additions and 88 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.8.7'
__version__ = '1.3.9.0'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -429,9 +429,9 @@ def random(shape, dtype="float32", type="uniform"):
[0.46107793 0.62798643 0.03457401]], dtype=float32)
'''
# TODO: move those code to core
if dtype == "float16":
if dtype in ["float16", "bfloat16"]:
# TODO: make curand support fp16
ret = ops.random(shape, "float32", type).float16()
ret = ops.random(shape, "float32", type).cast(dtype)
else:
ret = ops.random(shape, dtype, type)
amp_reg = jt.flags.amp_reg
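A minimal usage sketch of this fallback (using the jt.random signature shown above): curand has no bfloat16 generator, so the numbers are drawn in float32 and cast afterwards.

import jittor as jt
x = jt.random((2, 3), dtype="bfloat16")   # internally: ops.random(..., "float32").cast("bfloat16")
assert x.dtype == "bfloat16"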
@ -1765,7 +1765,7 @@ Arguments of hook are defined as::
return self
def float32(self):
'''convert all parameters to float16'''
'''convert all parameters to float32'''
self._amp_level = 0
for p in self.parameters():
if p.dtype.is_float():
@ -1785,6 +1785,19 @@ Arguments of hook are defined as::
p.assign(p.float16())
return self
def bfloat16(self):
'''convert all parameters to bfloat16'''
# self._amp_level = 3 if flags.th_mode else 4
# amp level better set globally
self._amp_level = -1
if self._amp_level >= 0:
cls = self.__class__
cls.__call__ = cls.__half_call__
for p in self.parameters():
if p.dtype.is_float():
p.assign(p.bfloat16())
return self
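A usage sketch of the new Module.bfloat16() conversion, mirroring test_module_half in the new test file:

import jittor as jt
lin = jt.nn.Linear(10, 10)             # parameters start as float32
lin.bfloat16()                         # every float parameter is reassigned as bfloat16
assert lin.weight.dtype == "bfloat16"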
def __half_call__(self, *args, **kw):
amp_level = getattr(self, "_amp_level", -1)
if amp_level >= 0:
@ -2031,7 +2044,7 @@ def jittor_exit():
atexit.register(jittor_exit)
def vtos(v):
data_str = f"jt.Var({v.data}, dtype={v.dtype})"
data_str = f"jt.Var({v.numpy()}, dtype={v.dtype})"
data_str = data_str.replace("\n", "\n ")
return data_str

View File

@ -216,19 +216,9 @@ jt.Var.__setitem__ = setitem
def _merge_dtypes(dtypes):
s = -1
e = -1
names = ["bool","uint","int","float"]
dbytes = ["8","16","32","64"]
for d in dtypes:
for name in names:
if d.startswith(name):
s = max(s,names.index(name))
for db in dbytes:
if d.endswith(db):
e = max(e,dbytes.index(db))
assert s>=0 and s<4 and e<4
dtype = names[s]+("" if e ==-1 else dbytes[e])
dtype = dtypes[0]
for i in range(1, len(dtypes)):
dtype = jt.binary_dtype_infer("add", dtype, dtypes[i])
return dtype
@jt.flag_scope(amp_reg=4) # _custom_flag
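_merge_dtypes now delegates to the same promotion rule the ops use (binary_dtype_infer is exported to Python via the @pyjt annotation added in the nano_string.h change below). A sketch; the exact results depend on the active amp_reg flags:

import jittor as jt
print(jt.binary_dtype_infer("add", "bfloat16", "bfloat16"))   # bfloat16
print(jt.binary_dtype_infer("add", "bfloat16", "float32"))    # float32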

View File

@ -254,7 +254,7 @@ class JittorBackend(AbstractBackend):
return self.jittor.unsqueeze(x, new_position)
def is_float_type(self, x):
return x.dtype in ["float16", "float32", "float64"]
return x.dtype in ["float16", "bfloat16", "float32", "float64"]
def layers(self):
from jittor.einops.layers import jittor

View File

@ -25,6 +25,7 @@ static inline cudaDataType get_dtype(NanoString dtype) {
if (dtype == ns_float32) return CUDA_R_32F;
if (dtype == ns_float64) return CUDA_R_64F;
if (dtype == ns_float16) return CUDA_R_16F;
if (dtype == ns_bfloat16) return CUDA_R_16BF;
LOGf << "not support type" << dtype;
return CUDA_R_32F;
}

View File

@ -80,6 +80,10 @@ void CublasAccMatmulOp::jit_run() {
if ('@Trans_b'=='T') {
k = bs[0];
}
bool has_fp16_or_bf16 = a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16
|| a->dtype() == ns_bfloat16
|| b->dtype() == ns_bfloat16 || c->dtype() == ns_bfloat16;
// a: [n,m], b: [m,k], c: [n,k]
#if CUDART_VERSION >= 11000
@ -92,8 +96,7 @@ void CublasAccMatmulOp::jit_run() {
} else if (use_tensorcore==1) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
if (has_fp16_or_bf16) {
computeType = CUBLAS_COMPUTE_16F;
}
#else
@ -102,8 +105,7 @@ void CublasAccMatmulOp::jit_run() {
if (use_tensorcore) {
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
if (has_fp16_or_bf16) {
computeType = CUDA_R_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}

View File

@ -123,6 +123,10 @@ void CublasBatchedMatmulOp::jit_run() {
if ('@Trans_b'=='T') {
k = bs[adim-2];
}
bool has_fp16_or_bf16 = a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16
|| a->dtype() == ns_bfloat16
|| b->dtype() == ns_bfloat16 || c->dtype() == ns_bfloat16;
// a: [b,n,m], b: [b,m,k], c: [b,n,k]
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
@ -134,8 +138,7 @@ void CublasBatchedMatmulOp::jit_run() {
} else if (use_tensorcore==1) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
if (has_fp16_or_bf16) {
computeType = use_tensorcore ? CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F;
algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
@ -149,8 +152,7 @@ void CublasBatchedMatmulOp::jit_run() {
if (use_tensorcore) {
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
if (has_fp16_or_bf16) {
computeType = CUDA_R_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}

View File

@ -80,6 +80,10 @@ void CublasMatmulOp::jit_run() {
if ('@Trans_b'=='T') {
k = bs[0];
}
bool has_fp16_or_bf16 = a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16
|| a->dtype() == ns_bfloat16
|| b->dtype() == ns_bfloat16 || c->dtype() == ns_bfloat16;
// a: [n,m], b: [m,k], c: [n,k]
#if CUDART_VERSION >= 11000
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
@ -91,8 +95,7 @@ void CublasMatmulOp::jit_run() {
} else if (use_tensorcore==1) {
computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
if (has_fp16_or_bf16) {
computeType = use_tensorcore ? CUBLAS_COMPUTE_16F : CUBLAS_COMPUTE_32F;
algo = use_tensorcore ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
@ -106,8 +109,7 @@ void CublasMatmulOp::jit_run() {
if (use_tensorcore) {
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
if (a->dtype() == ns_float16
|| b->dtype() == ns_float16 || c->dtype() == ns_float16) {
if (has_fp16_or_bf16) {
computeType = CUDA_R_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
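The user-visible effect of the three cuBLAS changes is simply that bfloat16 operands now select a matching computeType; a sketch mirroring test_matmul from the new tests (assumes a CUDA build):

import jittor as jt
jt.flags.use_cuda = 1
a = jt.random((64, 64)).bfloat16()
b = jt.random((64, 64)).bfloat16()
c = jt.matmul(a, b)                    # dispatches through the has_fp16_or_bf16 branch above
assert c.dtype == "bfloat16"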

View File

@ -7,6 +7,8 @@
#pragma once
#include <cuda_runtime.h>
#include <cudnn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "utils/log.h"
#include "helper_cuda.h"
@ -25,4 +27,10 @@ void set_algorithm_cache_size(int size);
// @pyjt(set_max_workspace_ratio)
void set_max_workspace_ratio(float64 ratio);
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
template <> __inline__ cudnnDataType_t getDataType<__nv_bfloat16>() { return CUDNN_DATA_BFLOAT16; }
} // jittor

View File

@ -112,10 +112,6 @@ unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdFilterAlgo_t> bwdw_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConvBackwardWOp::jit_run() {
auto w = dw;
auto y = dy;

View File

@ -103,10 +103,6 @@ unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
EXTERN_LIB unordered_map<string, cudnnConvolutionBwdDataAlgo_t> bwdx_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConvBackwardXOp::jit_run() {
auto x = dx;
auto y = dy;

View File

@ -107,10 +107,6 @@ unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
EXTERN_LIB unordered_map<string, cudnnConvolutionFwdAlgo_t> fwd_algo_cache;
template <typename T_ELEM> __inline__ cudnnDataType_t getDataType();
template <> __inline__ cudnnDataType_t getDataType<half1>() { return CUDNN_DATA_HALF; }
template <> __inline__ cudnnDataType_t getDataType<float>() { return CUDNN_DATA_FLOAT; }
void CudnnConvOp::jit_run() {
cudnnHandle_t& handle_ = cudnn_handle;
@ -178,9 +174,12 @@ void CudnnConvOp::jit_run() {
if(use_tensorcore){
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}
bool has_fp16_or_bf16 = x->dtype() == ns_float16
|| y->dtype() == ns_float16 || w->dtype() == ns_float16
|| x->dtype() == ns_bfloat16
|| y->dtype() == ns_bfloat16 || w->dtype() == ns_bfloat16;
if (x->dtype() == ns_float16
|| y->dtype() == ns_float16 || w->dtype() == ns_float16) {
if (has_fp16_or_bf16) {
checkCudaErrors( cudnnSetConvolutionMathType(cudnnConvDesc, CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION) );
}

View File

@ -56,6 +56,7 @@ void NcclAllGatherOp::jit_run() {
@if(@strcmp(@Tx,int64)==0, ncclInt64)
@if(@strcmp(@Tx,uint8)==0, ncclUint8)
@if(@strcmp(@Tx,float16)==0, ncclHalf)
@if(@strcmp(@Tx,bfloat16)==0, ncclBfloat16)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();

View File

@ -49,6 +49,7 @@ void NcclBroadcastOp::jit_run() {
@if(@strcmp(@Tx,int64)==0, ncclInt64)
@if(@strcmp(@Tx,uint8)==0, ncclUint8)
@if(@strcmp(@Tx,float16)==0, ncclHalf)
@if(@strcmp(@Tx,bfloat16)==0, ncclBfloat16)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();

View File

@ -49,6 +49,7 @@ void NcclReduceOp::jit_run() {
@if(@strcmp(@Tx,int64)==0, ncclInt64)
@if(@strcmp(@Tx,uint8)==0, ncclUint8)
@if(@strcmp(@Tx,float16)==0, ncclHalf)
@if(@strcmp(@Tx,bfloat16)==0, ncclBfloat16)
)
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();

View File

@ -2200,7 +2200,15 @@ def peek_s(x):
def peek(x):
print(peek_s(x))
class Finfo:
pass
bfloat16_finfo = Finfo()
bfloat16_finfo.min = -1e38
bfloat16_finfo.max = 1e38
def finfo(dtype):
if dtype == "bfloat16":
return bfloat16_finfo
return np.finfo(str(dtype).split('.')[-1])
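Sketch of the new finfo path. The ±1e38 bounds are a conservative stand-in (the true bfloat16 maximum is about 3.39e38), and every other float dtype still goes through numpy:

import jittor as jt
print(jt.finfo("bfloat16").max)   # 1e+38, from the hand-built Finfo above
print(jt.finfo("float32").max)    # ~3.4028235e+38, from np.finfo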
def iinfo(dtype):

View File

@ -727,8 +727,10 @@ def fp32_guard(func):
return func(*args, **kw)
new_args = []
need_cast = False
dtype = None
for a in args:
if isinstance(a, jt.Var) and a.dtype == "float16":
if isinstance(a, jt.Var) and (a.dtype == "float16" or a.dtype == "bfloat16"):
dtype = a.dtype
new_args.append(a.float32())
need_cast = True
else:
@ -736,7 +738,7 @@ def fp32_guard(func):
with jt.flag_scope(amp_level=0):
a = func(*new_args, **kw)
if need_cast and isinstance(a, jt.Var) and a.dtype == "float32":
a = a.float16()
a = a.cast(dtype)
return a
return wrapper
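A sketch of what fp32_guard now does for both 16-bit float types; the decorated function here is hypothetical, and fp32_guard is assumed to be in scope (it is defined in this file):

import jittor as jt

@fp32_guard                               # casts float16/bfloat16 args to float32 first
def stable_log(x):                        # hypothetical example function
    return x.log()                        # runs under amp_level=0 on the float32 copies

y = stable_log(jt.random((3,)).bfloat16())
assert y.dtype == "bfloat16"              # the float32 result is cast back to the input dtype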

View File

@ -53,7 +53,7 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
{for_loop}
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
v1 = max(v1, float(v[i][j]));
v1 = ::max(v1, float(v[i][j]));
}}
__shared__ float vmax;
auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());

View File

@ -6,6 +6,7 @@
// ***************************************************************
#pragma once
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "common.h"
namespace jittor {
@ -104,6 +105,14 @@ template <> struct int_mapper<__half> {
inline static __device__ src from_int(target a) { return __ushort_as_half(a); }
};
template <> struct int_mapper<__nv_bfloat16> {
typedef __nv_bfloat16 src;
typedef unsigned short target;
inline static __device__ target to_int(src a) { return __bfloat16_as_ushort(a); }
inline static __device__ target* to_intp(src* a) { return (target*)a; }
inline static __device__ src from_int(target a) { return __ushort_as_bfloat16(a); }
};
template <> struct int_mapper<double> {
typedef double src;
typedef long long target;
@ -157,6 +166,37 @@ __half cuda_atomic_min(__half* a, __half b) {
return old_f;
}
template<> __device__
__nv_bfloat16 cuda_atomic_max(__nv_bfloat16* a, __nv_bfloat16 b) {
auto old_f = *a;
auto old = int_mapper<__nv_bfloat16>::to_int(old_f);
auto a_i = int_mapper<__nv_bfloat16>::to_intp(a);
while (1) {
auto assume = old;
if (old_f>=b) break;
old = atomicCAS(a_i, assume, int_mapper<__nv_bfloat16>::to_int(b));
old_f = int_mapper<__nv_bfloat16>::from_int(old);
if (assume==old) break;
}
return old_f;
}
template<> __device__
__nv_bfloat16 cuda_atomic_min(__nv_bfloat16* a, __nv_bfloat16 b) {
auto old_f = *a;
auto old = int_mapper<__nv_bfloat16>::to_int(old_f);
auto a_i = int_mapper<__nv_bfloat16>::to_intp(a);
while (1) {
auto assume = old;
if (old_f<=b) break;
old = atomicCAS(a_i, assume, int_mapper<__nv_bfloat16>::to_int(b));
old_f = int_mapper<__nv_bfloat16>::from_int(old);
if (assume==old) break;
}
return old_f;
}
template<typename T>
__device__ inline T shared_reduce_add(T a, T b) {
return a + b;

View File

@ -10,6 +10,7 @@
#include "misc/cuda_flags.h"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "helper_cuda.h"
#endif
#include "mem/allocator.h"
@ -20,10 +21,25 @@ namespace jittor {
#ifdef IS_CUDA
EXTERN_LIB vector<int> check_nan_float16(__half* ptr, int64 num);
EXTERN_LIB vector<int> check_nan_bfloat16(__nv_bfloat16* ptr, int64 num);
EXTERN_LIB vector<int> check_nan_float32(float32* ptr, int64 num);
EXTERN_LIB vector<int> check_nan_float64(float64* ptr, int64 num);
#endif
void dump_var(Var* v, string name) {
std::stringstream ss;
ss << name << v->id << v->dtype() << v->shape << ".bin";
name = ss.str();
LOGe << "dump" << v << "to" << name;
char* buffer = new char[v->size];
cudaMemcpy(buffer, v->mem_ptr, v->size, cudaMemcpyDefault);
std::fstream file(name, std::ios::out | std::ios::binary);
file.write(buffer, v->size);
file.close();
delete[] buffer;
}
bool check_nan(Var* v, Op* op) {
if (!v->dtype().is_float() || v->num == 0) return true;
if (v->input() && (
@ -36,6 +52,9 @@ bool check_nan(Var* v, Op* op) {
if (v->dtype() == ns_float16) {
nan_index = check_nan_float16((__half*)v->mem_ptr, v->num);
}
if (v->dtype() == ns_bfloat16) {
nan_index = check_nan_bfloat16((__nv_bfloat16*)v->mem_ptr, v->num);
}
if (v->dtype() == ns_float32) {
nan_index = check_nan_float32((float32*)v->mem_ptr, v->num);
} else
@ -44,6 +63,32 @@ bool check_nan(Var* v, Op* op) {
}
if (nan_index[0]) {
LOGe << "detect nan count:" << nan_index[0];
/* dump nan var for analysis
python code for parse dump file:
import numpy as np
def load_var(filename):
dtype = "float16"
shape = filename.split('[')[1].split(']')[0]
shape = tuple(int(s) for s in shape.split(',')[:-1])
with open(filename, 'rb') as f:
array = np.fromfile(f, dtype=dtype)
return array.reshape(shape)
in0 = load_var("/tmp/input13736float16[4096,11008,].bin")
in1 = load_var("/tmp/input26930float16[32768,11008,].bin")
out0 = load_var("/tmp/output26938float16[32768,4096,].bin")
*/
if (getenv("DUMP_NAN_INPUT") && getenv("DUMP_NAN_INPUT") == string("1")) {
for (Var* v : op->inputs())
dump_var(v, "/tmp/input");
for (Var* v : op->outputs())
dump_var(v, "/tmp/output");
}
for (int i=0; i<std::min(10, (int)nan_index.size()-1); i++) {
int index = nan_index[i+1];
int icnt = 0;
@ -56,6 +101,12 @@ bool check_nan(Var* v, Op* op) {
cudaMemcpy(&value, ptr+index, sizeof(__half), cudaMemcpyDeviceToHost);
LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
} else
if (input->dtype() == ns_bfloat16) {
auto* ptr = input->ptr<__nv_bfloat16>();
__nv_bfloat16 value;
cudaMemcpy(&value, ptr+index, sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
LOGe << "input" << icnt << "dtype" << input->dtype() << "index" << index << "value" << (float)value;
} else
if (input->dtype() == ns_float32) {
auto* ptr = input->ptr<float32>();
float32 value;

View File

@ -7,6 +7,7 @@
#include "misc/cuda_flags.h"
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "helper_cuda.h"
#include <cassert>
@ -37,6 +38,20 @@ __global__ void _check_nan_float16(__half* __restrict__ ptr, int64 num, int* cnt
}
}
__global__ void _check_nan_bfloat16(__nv_bfloat16* __restrict__ ptr, int64 num, int* cnt) {
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
if (i<num) {
#if JT_CHECK_NAN == 2
if (isnan(float(ptr[i])))
#else
if (isnan(float(ptr[i])) || isinf(float(ptr[i]))
// || abs(__half2float(ptr[i])) > 60000.f
)
#endif
print_nan(float(ptr[i]), i, cnt);
}
}
__global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num, int* cnt) {
int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x;
if (i<num) {
@ -99,6 +114,13 @@ vector<int> check_nan_float16(__half* ptr, int64 num) {
return report_nan();
}
vector<int> check_nan_bfloat16(__nv_bfloat16* ptr, int64 num) {
int block_num = std::max((int64)1, (num-1)/1024+1);
int thread_num = std::min((int64)1024, num);
_check_nan_bfloat16<<<block_num, thread_num>>>(ptr, num, check_nan_get_device_ptr());
return report_nan();
}
#endif
}

View File

@ -182,6 +182,9 @@ static void init_ns() {
dsize_map["float16"] = 1;
is_float_map["float16"] = 1;
is_unsigned["float16"] = 0;
dsize_map["bfloat16"] = 1;
is_float_map["bfloat16"] = 1;
is_unsigned["bfloat16"] = 0;
NanoString::ns_t i=0;
auto func = [&](const char* name, NanoString& ns) {
ns.set(NanoString::_index, i++, NanoString::_index_nbits);
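dsize_map stores the log2 of the byte width, so the value 1 for bfloat16 means 2 bytes. With dsize() now exported via @pyjt (see the nano_string.h change in the next file), this can be checked from Python:

import jittor as jt
assert jt.NanoString("bfloat16").dsize() == 2
assert jt.NanoString("float32").dsize() == 4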

View File

@ -145,7 +145,8 @@ struct NanoString {
// @pyjt(is_float)
inline bool is_float() const { return get(_float); }
inline ns_t is_white() const { return get(_white_list); }
inline ns_t dsize() const { return 1<<get(_dsize, _dsize_nbits); }
// @pyjt(dsize)
inline int dsize() const { return 1<<get(_dsize, _dsize_nbits); }
inline ns_t dsize_() const { return get(_dsize, _dsize_nbits); }
inline ns_t is_dtype() const { return get(_type, _type_nbits)==_dtype; }
inline ns_t is_binary() const { return get(_type, _type_nbits)==_binary; }
@ -197,13 +198,16 @@ constexpr int amp_keep_reduce = 4;
constexpr int amp_keep_white = 8;
constexpr int amp_array_prefer = 16;
inline NanoString float_dtype(int dsize_, bool has_scalar=false) {
inline NanoString float_dtype(int dsize_, bool has_scalar=false, bool has_bf16=false) {
if (!has_scalar) {
if (amp_reg & amp_prefer32) return ns_float32;
if (amp_reg & amp_prefer16) return ns_float16;
if (amp_reg & amp_prefer32)
return ns_float32;
if (amp_reg & amp_prefer16)
return has_bf16 ? ns_bfloat16 : ns_float16;
}
return (dsize_ == 3) ? ns_float64 :
(dsize_ == 2 ) ? ns_float32 : ns_float16;
(dsize_ == 2 ) ? ns_float32 :
has_bf16 ? ns_bfloat16 : ns_float16;
}
inline NanoString int_dtype(int dsize_) {
@ -217,13 +221,15 @@ inline NanoString dtype_infer(NanoString x, NanoString y, bool xscalar=false, b
if (xscalar) dsize_ = y.dsize_();
if (yscalar) dsize_ = x.dsize_();
bool is_float = x.is_float() || y.is_float();
bool has_bf16 = x==ns_bfloat16 || y==ns_bfloat16;
if (is_float)
return float_dtype(dsize_, xscalar||yscalar);
return float_dtype(dsize_, xscalar||yscalar, has_bf16);
else {
return int_dtype(dsize_);
}
}
// @pyjt(binary_dtype_infer)
inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y, bool xscalar=false, bool yscalar=false) {
if (op.is_bool()) return ns_bool;
int dsize_ = std::max(x.dsize_(), y.dsize_());
@ -231,10 +237,11 @@ inline NanoString binary_dtype_infer(NanoString op, NanoString x, NanoString y,
if (yscalar) dsize_ = x.dsize_();
bool is_float = !op.is_int() &&
(x.is_float() || y.is_float() || op.is_float());
bool has_bf16 = x==ns_bfloat16 || y==ns_bfloat16;
if (is_float) {
if (op.is_white() && !(amp_reg & amp_keep_white))
return (dsize_ == 3) ? ns_float64 : ns_float32;
return float_dtype(dsize_, xscalar||yscalar);
return float_dtype(dsize_, xscalar||yscalar, has_bf16);
} else {
if (x.is_bool() && y.is_bool()) return ns_bool;
return int_dtype(dsize_);
@ -247,7 +254,7 @@ inline NanoString unary_dtype_infer(NanoString op, NanoString x) {
if (op.is_float()) {
if (op.is_white() && !(amp_reg & amp_keep_white))
return (dsize_ == 3) ? ns_float64 : ns_float32;
return float_dtype(dsize_);
return float_dtype(dsize_, false, x==ns_bfloat16);
}
if (op.is_int()) return int_dtype(dsize_);
return x;
@ -258,7 +265,7 @@ inline NanoString reduce_dtype_infer(NanoString op, NanoString x) {
int dsize_ = x.dsize_();
if (is_float) {
if (amp_reg & amp_keep_reduce)
return float_dtype(dsize_);
return float_dtype(dsize_, false, x==ns_bfloat16);
return (dsize_ == 3) ? ns_float64 : ns_float32;
} else {
return x;
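The resulting promotion rules, as exercised by test_add, test_scalar and test_white_dtype_infer in the new test file (default amp_reg; other flag combinations change the outcome):

import jittor as jt
a = jt.bfloat16([1, 2, 3])
assert (a + a).dtype == "bfloat16"                        # bf16 with bf16 stays bf16
assert (a * jt.float32([1, 2, 3])).dtype == "float32"     # mixing with fp32 widens
assert (a ** a).dtype == "float32"                        # white-listed ops keep fp32 by default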

View File

@ -257,10 +257,10 @@ EXTERN_LIB int amp_reg;
ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
: x(x) {
// improve float16/bfloat16 mean precision
if (!(amp_reg & 32) && x->dtype() == ns_float16 && (op == ns_mean || op == ns_add)) {
if (!(amp_reg & 32) && (x->dtype() == ns_float16 || x->dtype() == ns_bfloat16) && (op == ns_mean || op == ns_add)) {
auto x_float32 = make_unary(x, ns_float32);
auto mean = make_reduce(x_float32, op, dims, keepdims);
mean = make_unary(mean, ns_float16);
mean = make_unary(mean, x->dtype());
forward(mean);
return;
}
@ -293,10 +293,10 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims)
ReduceOp::ReduceOp(Var* x, NanoString op, uint dims_mask, uint keepdims_mask)
: x(x) {
// improve float16/bfloat16 mean precision
if (!(amp_reg & 32) && x->dtype() == ns_float16 && (op == ns_mean || op == ns_add)) {
if (!(amp_reg & 32) && (x->dtype() == ns_float16 || x->dtype() == ns_bfloat16) && (op == ns_mean || op == ns_add)) {
auto x_float32 = make_unary(x, ns_float32);
auto mean = make_reduce2(x_float32, op, dims_mask, keepdims_mask);
mean = make_unary(mean, ns_float16);
mean = make_unary(mean, x->dtype());
forward(mean);
return;
}
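In practice a bfloat16 sum or mean is accumulated in float32 and only the result is cast back to the input dtype, which is what test_add relies on:

import jittor as jt
import numpy as np
a = jt.bfloat16(np.array([1, 2, 3], dtype="float32"))
s = a.sum()                                   # reduced in float32, result cast back to a's dtype
np.testing.assert_allclose(s.numpy(), [6])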

View File

@ -37,9 +37,13 @@ void SafeClipOp::jit_prepare(JK& jk) {
void SafeClipOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
Tx left_value = (Tx)std::max((float64)
@if(@strcmp(@Tx,float16)==0,-65500,std::numeric_limits<Tx>::lowest()), left);
@if(@strcmp(@Tx,float16)==0,-65500,
@if(@strcmp(@Tx,bfloat16)==0,-1e38,
std::numeric_limits<Tx>::lowest())), left);
Tx right_value = (Tx)std::min((float64)
@if(@strcmp(@Tx,float16)==0,65500,std::numeric_limits<Tx>::max()), right);
@if(@strcmp(@Tx,float16)==0,65500,
@if(@strcmp(@Tx,bfloat16)==0,1e38,
std::numeric_limits<Tx>::max())), right);
auto* __restrict__ yp = y->ptr<Tx>();
index_t num = y->num;
for (index_t i=0; i<num; i++)

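As exercised by test_safe_clip in the new tests, safe_clip now pulls infinite bfloat16 values back inside the ±1e38 bounds chosen above:

import math
import jittor as jt
x = jt.bfloat16(math.inf)
assert not x.isfinite()
assert jt.safe_clip(x).isfinite()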
View File

@ -235,6 +235,26 @@ static unordered_set<string> unary_ops = {
*/
"float16",
/**
Returns a copy of the input var, casted to bfloat16 (brain floating point, a 16-bit float that keeps float32's exponent range).
----------------
* [in] x: the input jt.Var
----------------
Example-1::
>>> x = jt.rand(3) * 10
>>> x
jt.Var([4.093273 2.0086648 8.474352 ], dtype=float32)
>>> x.bfloat16()
jt.Var([4.094 2.008 8.48 ], dtype=bfloat16)
>>> jt.bfloat16(x)
jt.Var([4.094 2.008 8.48 ], dtype=bfloat16)
*/
"bfloat16",
/**
Returns a copy of the input var, casted to float32.

View File

@ -79,7 +79,7 @@ void FloatAtomicFixPass::run() {
return;
}
if (!var->dtype().is_float()) return;
if (var->dtype() == ns_float16)
if (var->dtype() == ns_float16 || var->dtype() == ns_bfloat16)
// float16/bfloat16 use atomicCAS, because there is no native atomicMax for them
return;
LOGvvvv << "find var" << var << "op" << op;

View File

@ -232,7 +232,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
// CUDA accepts any float dtype; CPU currently only supports float32
if (use_cuda) {
if (bop->z->dtype() != ns_float32 && bop->z->dtype() != ns_float16)
if (!bop->z->dtype().is_float())
continue;
} else {
if (bop->z->dtype() != ns_float32)

View File

@ -36,7 +36,8 @@ NPY_TYPES ns2npy[] = {
NPY_BYTE, NPY_SHORT, NPY_INT, NPY_LONGLONG,
NPY_UBYTE, NPY_USHORT, NPY_UINT, NPY_ULONGLONG,
#endif
NPY_HALF, NPY_FLOAT, NPY_DOUBLE
NPY_HALF, NPY_FLOAT, NPY_DOUBLE,
NPY_USHORT // bfloat16: numpy has no bf16 dtype, fake it as uint16
};
void** PyArray_API;

View File

@ -333,11 +333,21 @@ DEF_IS(ArrayArgs, PyObject*) to_py_object(const T& a) {
PyObjHolder obj(PyArray_SimpleNew(
a.shape.size(),
dims,
get_typenum(a.dtype)
get_typenum(a.dtype == ns_bfloat16 ? ns_float32 : a.dtype)
));
auto arr = (PyArray_Proxy*)(obj.obj);
int64 size = PyArray_Size(arr);
memcpy((void*)arr->data, (void*)a.ptr, size);
if (a.dtype == ns_bfloat16) {
// widen bfloat16 to float32: the 16 stored bits become the high half of the float32 pattern
auto ptr = (uint16*)a.ptr;
auto ptr2 = (uint32*)arr->data;
int64 num = size/4;
for (int64 i=0; i<num; i++) {
ptr2[i] = ptr[i]<<16;
}
} else {
memcpy((void*)arr->data, (void*)a.ptr, size);
}
return obj.release();
}
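The same widening can be reproduced with numpy, which is also how the fake uint16 storage round-trips (a standalone sketch, not part of the commit):

import numpy as np
bits = np.array([0x3FC0], dtype=np.uint16)                 # bfloat16 bit pattern of 1.5
f32 = (bits.astype(np.uint32) << 16).view(np.float32)      # same shift the C++ loop performs
print(f32)                                                 # [1.5]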

View File

@ -11,10 +11,12 @@
#include <driver_types.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
namespace jittor {
typedef __half float16;
typedef __nv_bfloat16 bfloat16;
#if CUDA_ARCH >= 800
inline __device__ float16 max(float16 a, float16 b) { return __hmax(a, b); }
@ -29,6 +31,21 @@ inline __device__ float16 min(float16 a, float16 b) { return float(a)<float(b)?a
inline __device__ float16 pow(float16 a, float16 b) { return ::pow(float32(a), float32(b)); }
#if CUDA_ARCH >= 800
inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return __hmax(a, b); }
inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return __hmin(a, b); }
#elif CUDA_ARCH >= 610
inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return a<b?b:a; }
inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return a<b?a:b; }
#else
inline __device__ bfloat16 max(bfloat16 a, bfloat16 b) { return float(a)<float(b)?b:a; }
inline __device__ bfloat16 min(bfloat16 a, bfloat16 b) { return float(a)<float(b)?a:b; }
#endif
inline __device__ bfloat16 pow(bfloat16 a, bfloat16 b) { return ::pow(float32(a), float32(b)); }
template<int nbyte, class T>
__device__ inline
typename std::enable_if<nbyte<=0,void>::type
@ -209,6 +226,27 @@ bool operator<(float16 x, float16 y) { return float32(x)<float32(y); }
bool operator>(float16 x, float16 y) { return float32(x)>float32(y); }
bool operator==(float16 x, float16 y) { return float32(x)==float32(y); }
struct bfloat16 {
uint16 x;
inline bfloat16(float32 f) {
unsigned x = *((int*)(void*)(&f));
this->x = x>>16;
}
inline operator float() const {
int temp = x<<16;
return reinterpret_cast<float&>(temp);
}
};
bool operator<(bfloat16 x, bfloat16 y) { return float32(x)<float32(y); }
bool operator>(bfloat16 x, bfloat16 y) { return float32(x)>float32(y); }
bool operator==(bfloat16 x, bfloat16 y) { return float32(x)==float32(y); }
}
#endif

View File

@ -21,14 +21,18 @@ struct FP16OpType : OpByType {
FP16OpType() {
types = {
"float16",
"bfloat16",
};
}
string expand_op(const vector<string>& args) {
bool found_fp16 = 0;
bool found_bf16 = 0;
for (int i=1; i<args.size(); i+=2) {
if (types.count(args[i]))
found_fp16 = 1;
if (args[i] == "bfloat16")
found_bf16 = 1;
}
if (!found_fp16) return "";
static unordered_map<string,string> cuda_map = {
@ -65,8 +69,8 @@ struct FP16OpType : OpByType {
{"maximum", "::max($1($2), $1($4))"},
{"minimum", "::min($1($2), $1($4))"},
{"mod", "$1(($2)-::hfloor(($2)/($4))*($4))"},
{"init_maximum", "-32768.0f"},
{"init_minimum", "32768.0f"},
{"init_maximum", "@if(@strcmp($1,float16)==0,-65000.0f,-1e38)"},
{"init_minimum", "@if(@strcmp($1,float16)==0,65000.0f,1e38)"},
{"equal", "(($2)==($4))"},
};
@ -151,7 +155,10 @@ struct FP16OpType : OpByType {
if (args[1] == "float32" && !both_map.count(args.at(0))) {
ret = common_op_type_cuda_map[args.at(0)];
}
if (args[1] == "float16" || args[1] == "float32") {
if (args[1] == "float16" ||
args[1] == "bfloat16" ||
args[1] == "float32")
{
for (int i=3; i<args.size(); i+=2) {
if (args[i] != args[1]) {
ret = replace(ret, "$"+S(i-1),
@ -159,15 +166,17 @@ struct FP16OpType : OpByType {
}
}
} else {
string target = found_bf16 ? "bfloat16" : "float16";
for (int i=3; i<args.size(); i+=2) {
if (args[i] != "float16") {
if (args[i] != target) {
ret = replace(ret, "$"+S(i-1),
"float16($"+S(i-1)+")");
target+"($"+S(i-1)+")");
}
}
}
}
return format(ret, args);
auto result = format(ret, args);
return result;
}
void post_pass(OpCompiler* oc) {

View File

@ -211,9 +211,15 @@ ArrayArgs VarHolder::fetch_sync() {
}
inline static void cast_item_data(ItemData& data) {
auto* fp16 = (float16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(fp16[0]);
if (data.dtype == ns_float16) {
auto* fp16 = (float16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(fp16[0]);
} else if (data.dtype == ns_bfloat16) {
auto* bf16 = (bfloat16*)&data;
auto* fp32 = (float32*)&data;
fp32[0] = float32(bf16[0]);
}
data.dtype = ns_float32;
}
@ -235,7 +241,7 @@ ItemData VarHolder::item() {
{
std::memcpy(&data.data, var->mem_ptr, dsize);
}
if (data.dtype == ns_float16)
if (data.dtype == ns_float16 || data.dtype == ns_bfloat16)
cast_item_data(data);
return data;
}
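Sketch of the user-visible behaviour: item() returns a Python scalar, and 16-bit floats are widened to float32 first (the value here is chosen to be exactly representable in bfloat16):

import jittor as jt
v = jt.bfloat16([2.5])
assert v.item() == 2.5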

View File

@ -0,0 +1,384 @@
# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import os
def transpose0231(x):
s0, s1, s2, s3 = x.shape
asize = 16
bsize = 16
ILP = 2
return jt.code([s0, s2, s3, s1], x.dtype, [x],
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
cuda_src=f"""
__global__ void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{
__shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}];
int t3 = threadIdx.x % {bsize};
int t1 = threadIdx.x / {bsize};
int b3 = blockIdx.x;
int b2 = blockIdx.y;
int b0 = blockIdx.z;
int x3 = 1;
int x2 = s3;
int x1 = s2*x2;
int x0 = s1*x1;
int y3 = 1;
int y2 = s1;
int y1 = s3*y2;
int y0 = s2*y1;
in0_type tmp[{ILP}];
for (int i=0; i<(s1-1)/{asize*ILP}+1; i++)
{{
int _b3 = b3 * {bsize*ILP} + t3*{ILP};
if (_b3 < s3) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
vload<sizeof(in0_type)*{ILP}>(
tmp,
&x[b0*x0+(t1*{ILP}+j+i*{asize*ILP})*x1+b2*x2+_b3*x3]
);
#pragma unroll
for (int k=0; k<{ILP}; k++)
t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k];
}}
}}
__syncthreads();
int t3_ = threadIdx.x % {asize};
int t1_ = threadIdx.x / {asize};
_b3 = b3 * {bsize*ILP} + t1_*{ILP};
if (_b3 < s3) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
#pragma unroll
for (int k=0; k<{ILP}; k++) {{
tmp[k] =
t[(t3*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j];
}}
vload<sizeof(in0_type)*{ILP}>(
&y[b0*y0+b2*y1+(_b3+j)*y2+((t3*{ILP})+i*{asize*ILP})*y3],
tmp
);
}}
}}
__syncthreads();
}}
}}
int s0, s1, s2, s3;
in0->shape.unpack(s0, s1, s2, s3);
kernel<<<{{(s3-1)/{bsize*ILP}+1, s2, s0 }}, {bsize*asize}>>>
(in0_p, out0_p, s0, s1, s2, s3);
""")
def transpose0231_2(x):
s0, s1, s2, s3 = x.shape
asize = 16
bsize = 8
ILP = 2
return jt.code([s0, s2, s3, s1], x.dtype, [x],
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
cuda_src=f"""
__global__ __launch_bounds__({asize*bsize}) void kernel(in0_type* __restrict__ x, in0_type* __restrict__ y, int s0, int s1, int s2, int s3) {{
__shared__ in0_type t[{asize*ILP}*{bsize*ILP+1}];
int t3 = threadIdx.x % {bsize};
int t1 = threadIdx.x / {bsize};
int b3 = blockIdx.x;
int b1 = blockIdx.y;
int b2 = 0;
int b0 = blockIdx.z;
int x3 = 1;
int x2 = s3;
int x1 = s2*x2;
int x0 = s1*x1;
int y3 = 1;
int y2 = s1;
int y1 = s3*y2;
int y0 = s2*y1;
in0_type tmp[{ILP}];
{{
int _b3 = b3 * {bsize*ILP} + t3*{ILP};
if (_b3 < s3) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
if (t1*{ILP}+j+b1*{asize*ILP} >= s1)
continue;
vload<sizeof(in0_type)*{ILP}>(
tmp,
&x[b0*x0+(t1*{ILP}+j+b1*{asize*ILP})*x1+b2*x2+_b3*x3]
);
#pragma unroll
for (int k=0; k<{ILP}; k++)
t[(t1*{ILP}+j)*{bsize*ILP+1}+t3*{ILP}+k] = tmp[k];
}}
}}
__syncthreads();
int t3_ = threadIdx.x % {asize};
int t1_ = threadIdx.x / {asize};
_b3 = b3 * {bsize*ILP} + t1_*{ILP};
int yy3 = (t3_*{ILP})+b1*{asize*ILP};
if (_b3 < s3 && yy3 < s1) {{
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
#pragma unroll
for (int k=0; k<{ILP}; k++) {{
tmp[k] =
t[(t3_*{ILP}+k)*{bsize*ILP+1}+t1_*{ILP}+j];
}}
vload<sizeof(in0_type)*{ILP}>(
&y[b0*y0+b2*y1+(_b3+j)*y2+yy3*y3],
tmp
);
// printf("%d %d %d %d %d\\n", b0*y0+b2*y1+(_b3+j)*y2+yy3*y3,
// b0, b2, (_b3+j), yy3);
}}
}}
__syncthreads();
}}
}}
int s0, s1, s2, s3;
in0->shape.unpack(s0, s1, s2, s3);
kernel<<<{{(s3-1)/{bsize*ILP}+1, (s1-1)/{asize*ILP}+1, s0 }}, {bsize*asize}>>>
(in0_p, out0_p, s0, s1, s2, s3);
""")
def check_share():
return
a = jt.rand((30, 32, 4, 2000)).float32()
jt.code(a.shape, a.dtype, [a],
cuda_header="#include <type/fp16_compute.h>\n#include <cassert>",
cuda_src="""
__global__ void kernel(in0_type* __restrict__ a, in0_type* __restrict__ b) {
__shared__ float x[32*33];
for (int i=0; i<3; i++) {
((float2*)&x[i])[0] = ((float2*)&a[i])[0];
((float2*)&b[i])[0] = ((float2*)&x[i+1])[0];
}
}
kernel<<<1024,16*16>>>(in0_p, out0_p);
""").sync()
jt.sync_all(True)
# print(a[0]+1)
print("pass test")
class TestBF16(unittest.TestCase):
def test_array(self):
a = np.array([1,2,3], dtype="float")
b = jt.array(a).bfloat16()
np.testing.assert_allclose(a, b.float().numpy())
def test_add(self):
a = np.array([1,2,3], dtype="float32")
b = jt.bfloat16(a)
c = b+b
assert c.dtype == "bfloat16"
np.testing.assert_allclose(c.numpy(), a+a)
d = c.sum()
np.testing.assert_allclose(d.numpy(), [12])
c = c+1
print(c)
def test_matmul(self):
a = jt.random((100,100)).bfloat16()
b = jt.random((100,100)).bfloat16()
c = jt.matmul(a, b)
c.sync()
print(c)
assert c.dtype == "bfloat16"
def test_bmm(self):
a = jt.random((10,3,4)).bfloat16()
b = jt.random((10,4,5)).bfloat16()
c = jt.matmul(a, b)
c.sync()
def test_matmul_grad(self):
a = jt.random((100,100)).bfloat16()
b = jt.random((100,100)).bfloat16()
c = jt.matmul(a, b)
c.sync()
da, db = jt.grad(c, [a,b])
jt.sync_all()
assert da.dtype == "bfloat16"
assert db.dtype == "bfloat16"
def test_conv(self):
a = jt.random((3,4,5,5)).bfloat16()
b = jt.random((4,4,3,3)).bfloat16()
c = jt.nn.conv(a, b)
c.sync()
def test_max(self):
a = jt.random((100,)).bfloat16()
b = jt.random((100,)).bfloat16()
c = a.maximum(b)
c.sync()
def test_reduce_dtype_infer(self):
return
# this test cannot pass now
with jt.flag_scope(amp_reg=1):
a = jt.random((3,4,5,5)).bfloat16()
b = a.sum()
b.sync()
assert b.dtype == "float32", b.dtype
with jt.flag_scope(amp_reg=2):
a = jt.random((3,4,5,5)).bfloat16()
b = a.sum()
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=0):
a = jt.random((3,4,5,5)).bfloat16()
b = a.sum()
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=2+4):
a = jt.random((3,4,5,5)).bfloat16()
b = a.sum()
b.sync()
assert b.dtype == "bfloat16", b.dtype
def test_white_dtype_infer(self):
with jt.flag_scope(amp_reg=1):
a = jt.random((3,4,5,5)).bfloat16()
b = a**a
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=2):
a = jt.random((3,4,5,5)).bfloat16()
b = a**a
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=0):
a = jt.random((3,4,5,5)).bfloat16()
b = a**a
b.sync()
assert b.dtype == "float32"
with jt.flag_scope(amp_reg=2+8):
a = jt.random((3,4,5,5)).bfloat16()
b = a**a
b.sync()
assert b.dtype == "bfloat16", b.dtype
def test_module_half(self):
a = jt.nn.Linear(10,10)
assert a.weight.dtype == "float32"
a.bfloat16()
assert a.weight.dtype == "bfloat16"
def test_scalar(self):
a = jt.bfloat16([1,2,3])
assert (a*1).dtype == "bfloat16"
assert (a*jt.bfloat16([1,2,3])).dtype == "bfloat16"
assert (a*jt.float32([1,2,3])).dtype == "float32"
assert (a*jt.float32([1,2,3]).sum()).dtype == "bfloat16"
assert jt.int([0,1,0]).ternary(a, jt.float32(1)).dtype == "bfloat16"
def test_amp_level3(self):
with jt.flag_scope(amp_level = 3):
a = jt.bfloat16([1,2,3])
assert (a.sum()).dtype == "bfloat16"
assert (a.mean()).dtype == "bfloat16"
assert (a.log()).dtype == "bfloat16"
assert (a.exp()).dtype == "bfloat16"
def test_safe_clip(self):
import math
assert not jt.bfloat16(math.inf).isfinite()
assert jt.safe_clip(jt.bfloat16(math.inf)).isfinite()
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
class TestBF16CUDA(TestBF16):
def setUp(self):
jt.flags.use_cuda = 1
def tearDown(self):
jt.flags.use_cuda = 0
def test_add_correct(self):
na = np.random.rand(10000)
nb = np.random.rand(10000)
a = jt.array(na).bfloat16()
b = jt.array(nb).bfloat16()
c = a + b
nc = c.numpy()
np.testing.assert_allclose(nc, na+nb, atol=1e-2)
def test_matmul_correct(self):
na = np.random.rand(64,64)
nb = np.random.rand(64,64)
a = jt.array(na).bfloat16()
b = jt.array(nb).bfloat16()
c = jt.matmul(a, b)
nc = c.numpy()
nc2 = np.matmul(na, nb)
np.testing.assert_allclose(nc, nc2, rtol=1e-2)
def test_softmax(self):
a = jt.rand((120, 2000, 2000)).bfloat16()
# a = jt.rand((1, 2000, 2000)).float32()
jt.sync_all()
with jt.profile_scope(10, 100):
a.log_softmax(-1).sync()
def test_transpose(self):
check_share()
# return
a = jt.rand((30, 32, 4, 2000)).float32()
# a = jt.rand((1, 1024, 1, 2000)).float32()
diff = transpose0231(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
# return
jt.sync_all()
# with jt.profile_scope(100, 11000):
with jt.profile_scope(100, 11000):
# a.log_softmax(-1).sync()
transpose0231(a).sync()
a.transpose((0,2,3,1)).sync()
# a.transpose((0,2,1,3)).sync()
a.fuse_transpose((0,2,1,3)).sync()
(a+1).sync()
jt.sync_all(True)
diff = transpose0231(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
np.testing.assert_allclose(transpose0231(a).data, a.transpose((0,2,3,1)).data)
def test_transpose2(self):
# check_share()
# return
# a = jt.rand((30, 32, 4, 2000)).float32()
# a = jt.rand((1, 10000, 1, 2000)).float32()
a = jt.rand((1, 10000, 1, 2048)).float32()
print("transpose")
transpose0231_2(a).sync()
print("add")
(a+1).sync()
return
# a = jt.arange(32*16).reshape((1, 32, 1, 16))
diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
# return
jt.sync_all()
# with jt.profile_scope(100, 11000):
with jt.profile_scope(100, 1100):
# a.log_softmax(-1).sync()
transpose0231_2(a).sync()
a.transpose((0,2,3,1)).sync()
# a.transpose((0,2,1,3)).sync()
a.fuse_transpose((0,2,1,3)).sync()
(a+1).sync()
jt.sync_all(True)
diff = transpose0231_2(a).data != a.transpose((0,2,3,1)).data
print(np.where(diff))
np.testing.assert_allclose(transpose0231_2(a).data, a.transpose((0,2,3,1)).data)
if __name__ == "__main__":
unittest.main()

View File

@ -344,7 +344,8 @@ def find_cache_path():
pyv = "py"+platform.python_version()
# cpu version
cpuv = get_cpu_version()
dirs = [".cache", "jittor", jtv, ccv, pyv, osv, cpuv]
jittor_path_key = get_str_hash(__file__)[:4]
dirs = [".cache", "jittor", jtv, ccv, pyv, osv, cpuv, jittor_path_key]
dirs = list(map(short, dirs))
cache_name = "default"
try:

View File

@ -21,14 +21,7 @@ def load_tensor(contents, dtype, numel, key, location):
loaded_storages[key] = contents.read_var(name, dtype)
def get_dtype_size(dtype):
dtype = dtype.__str__()
if dtype == "float32" or dtype == "int32":
return 4
if dtype == "float64" or dtype == "int64":
return 8
if dtype == "float16" or dtype == "int16":
return 2
return 1
return jt.NanoString(dtype).dsize()
def persistent_load(saved_id):
global contents
@ -47,6 +40,8 @@ def persistent_load(saved_id):
def _dtype_to_storage_type_map():
return {
np.float16: 'HalfStorage',
# just fake np.uint16 as bfloat16
np.uint16: 'BFloat16Storage',
np.float32: 'FloatStorage',
np.float64: 'DoubleStorage',
np.int64: 'LongStorage',

View File

@ -6,6 +6,7 @@ import sys
import torch
class HalfStorage: pass
class BFloat16Storage: pass
class FloatStorage: pass
class LongStorage: pass
class IntStorage: pass
@ -13,6 +14,7 @@ class ShortStorage: pass
class CharStorage: pass
class BoolStorage: pass
HalfStorage.__module__ = "torch"
BFloat16Storage.__module__ = "torch"
FloatStorage.__module__ = "torch"
LongStorage.__module__ = "torch"
IntStorage.__module__ = "torch"
@ -22,7 +24,7 @@ BoolStorage.__module__ = "torch"
def _rebuild_tensor_v2(*args): pass
_rebuild_tensor_v2.__module__ = "torch._utils"
targets = [HalfStorage, FloatStorage, LongStorage, IntStorage, ShortStorage, CharStorage, BoolStorage, _rebuild_tensor_v2]
targets = [HalfStorage, BFloat16Storage, FloatStorage, LongStorage, IntStorage, ShortStorage, CharStorage, BoolStorage, _rebuild_tensor_v2]
def swap_targets(targets):
original_targets = []
@ -57,6 +59,7 @@ class TensorWrapper:
dtype_map = {
"float16": HalfStorage,
"bfloat16": BFloat16Storage,
"float32": FloatStorage,
"int64": LongStorage,
"int32": IntStorage,