fix update queue and conv transpose

This commit is contained in:
Dun Liang 2020-09-16 14:27:13 +08:00
parent f3e99b96bc
commit e3d1fafa87
3 changed files with 14 additions and 4 deletions

View File

@ -82,13 +82,16 @@ class Dataset(object):
def __getitem__(self, index): def __getitem__(self, index):
raise NotImplementedError raise NotImplementedError
def __len__(self): def __batch_len__(self):
assert self.total_len >= 0 assert self.total_len >= 0
assert self.batch_size > 0 assert self.batch_size > 0
if self.drop_last: if self.drop_last:
return self.total_len // self.batch_size return self.total_len // self.batch_size
return (self.total_len-1) // self.batch_size + 1 return (self.total_len-1) // self.batch_size + 1
def __len__(self):
return self.__batch_len__()
def set_attrs(self, **kw): def set_attrs(self, **kw):
''' '''
You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size.
@ -300,6 +303,8 @@ Example::
self.terminate() self.terminate()
def __iter__(self): def __iter__(self):
if self.total_len is None:
self.total_len = len(self)
if self.shuffle == False: if self.shuffle == False:
index_list = get_order_list(self.total_len) index_list = get_order_list(self.total_len)
else: else:
@ -354,7 +359,7 @@ Example::
self.real_len = self.total_len self.real_len = self.total_len
self.real_batch_size = self.batch_size self.real_batch_size = self.batch_size
self.batch_len = len(self) self.batch_len = self.__batch_len__()
if not hasattr(self, "workers") and self.num_workers: if not hasattr(self, "workers") and self.num_workers:
self._init_workers() self._init_workers()

View File

@ -487,6 +487,9 @@ class ConvTranspose(Module):
self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0],
self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1])
self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding)
assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \
self.output_padding[1] < max(self.stride[1], self.dilation[1]), \
"output padding must be smaller than max(stride, dilation)"
self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float") self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float")
if bias: if bias:

View File

@ -128,10 +128,12 @@ void UpdateQueue::push(Var* v, Var* prev) {
queue.emplace_front(); queue.emplace_front();
owner = queue.begin(); owner = queue.begin();
} }
if (owner->size() >= update_queue_auto_flush_delay) auto need_auto_flush = owner->size() >= update_queue_auto_flush_delay;
auto_flush();
owner->emplace_front(UpdateQueue::Item{owner, v}); owner->emplace_front(UpdateQueue::Item{owner, v});
map[v] = owner->begin(); map[v] = owner->begin();
if (need_auto_flush) {
auto_flush();
}
// if total size of update queue is too big, // if total size of update queue is too big,
// force sync all // force sync all
if (map.size() > 100000) if (map.size() > 100000)