mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/jittor/jittor into lxl
This commit is contained in:
commit
13941112e1
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.1.1'
|
||||
__version__ = '1.2.1.2'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -383,7 +383,7 @@ class BatchNorm(Module):
|
|||
return norm_x * w + b
|
||||
|
||||
class BatchNorm1d(Module):
|
||||
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True, sync=True):
|
||||
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
|
||||
self.sync = sync
|
||||
self.num_features = num_features
|
||||
self.is_train = is_train
|
||||
|
|
|
@ -92,8 +92,12 @@ inline JK& operator<<(JK& jk, const string& s) {
|
|||
auto a = (__jk_int256*)(jk.buffer+jk.size);
|
||||
auto b = (__jk_int256*)(&s[0]);
|
||||
auto len = s.size();
|
||||
for (uint64 i=0; i*32<len; i++)
|
||||
a[i] = b[i];
|
||||
uint64 i=0;
|
||||
for (; i+32<=len; i+=32)
|
||||
a[i/32] = b[i/32];
|
||||
|
||||
for (; i<len; i++)
|
||||
jk.buffer[jk.size+i] = s[i];
|
||||
jk.size += len;
|
||||
return jk;
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ struct SimpleThread {
|
|||
std::thread thread;
|
||||
void run() {
|
||||
thread_name = "C"+S(id);
|
||||
try{
|
||||
try {
|
||||
std::unique_lock<std::mutex> lck(mtx);
|
||||
if (func)
|
||||
func(id);
|
||||
|
@ -75,6 +75,24 @@ struct SimpleThreads {
|
|||
for (int i=0; i<n; i++)
|
||||
threads.emplace_back(i);
|
||||
}
|
||||
void wait_all() {
|
||||
for (auto& t : threads) {
|
||||
auto start = clock();
|
||||
int ok = 0;
|
||||
while (clock()<start+5000) {
|
||||
if (t.mtx.try_lock()) {
|
||||
t.mtx.unlock();
|
||||
ok = 1;
|
||||
break;
|
||||
}
|
||||
using namespace std::chrono_literals;
|
||||
std::this_thread::sleep_for(1ms);
|
||||
}
|
||||
if (!ok) {
|
||||
LOGw << "Compile thread timeout, ignored.";
|
||||
}
|
||||
}
|
||||
}
|
||||
void launch_all(int active_thread, SimpleThread::Func func) {
|
||||
if (active_thread == 1) {
|
||||
func(0);
|
||||
|
@ -289,10 +307,12 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
|
||||
if (segfault_happen) {
|
||||
LOGe << "Segfault happen, main thread exit";
|
||||
threads.wait_all();
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (has_error) {
|
||||
threads.wait_all();
|
||||
LOGf << "Error happend during compilation, see error above.";
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue