polish flag scope for profile_mark

This commit is contained in:
Dun Liang 2022-11-20 23:39:12 +08:00
parent ef55bd378f
commit f7ba3cab31
2 changed files with 45 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.5.36'
__version__ = '1.3.5.37'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -121,7 +121,13 @@ class flag_scope(_call_no_record_scope):
flags_bk = self.flags_bk = {}
try:
for k,v in self.jt_flags.items():
flags_bk[k] = getattr(flags, k)
origin = getattr(flags, k)
flags_bk[k] = origin
# merge dict attrs
if isinstance(origin, dict):
for ok, ov in origin.items():
if ok not in v:
v[ok] = ov
setattr(flags, k, v)
except:
self.__exit__()

View File

@ -71,6 +71,43 @@ vload(T* __restrict__ a, T* __restrict__ b) {
}
}
template<int nbyte, class T>
__device__ inline
typename std::enable_if<nbyte<=0,void>::type
vfill(T* __restrict__ a) {}
template<int nbyte, class T>
__device__ inline
typename std::enable_if<0<nbyte,void>::type
vfill(T* __restrict__ a) {
if (nbyte<=0) return;
if (nbyte>=16) {
auto* __restrict__ aa = (int4* __restrict__)a;
aa[0] = {0};
return vfill<nbyte-16>(aa+1);
}
if (nbyte>=8) {
auto* __restrict__ aa = (int2* __restrict__)a;
aa[0] = {0};
return vfill<nbyte-8>(aa+1);
}
if (nbyte>=4) {
auto* __restrict__ aa = (int* __restrict__)a;
aa[0] = 0;
return vfill<nbyte-4>(aa+1);
}
if (nbyte>=2) {
auto* __restrict__ aa = (int16_t* __restrict__)a;
aa[0] = 0;
return vfill<nbyte-2>(aa+1);
}
if (nbyte>=1) {
auto* __restrict__ aa = (int8_t* __restrict__)a;
aa[0] = 0;
return vfill<nbyte-1>(aa+1);
}
}
}