From e3d1fafa879f8dfa01a2c38329ff20f1b4893406 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Wed, 16 Sep 2020 14:27:13 +0800 Subject: [PATCH] fix update queue and conv transpose --- python/jittor/dataset/dataset.py | 9 +++++++-- python/jittor/nn.py | 3 +++ src/update_queue.cc | 6 ++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/python/jittor/dataset/dataset.py b/python/jittor/dataset/dataset.py index 616b18db..1c25d18f 100644 --- a/python/jittor/dataset/dataset.py +++ b/python/jittor/dataset/dataset.py @@ -82,13 +82,16 @@ class Dataset(object): def __getitem__(self, index): raise NotImplementedError - def __len__(self): + def __batch_len__(self): assert self.total_len >= 0 assert self.batch_size > 0 if self.drop_last: return self.total_len // self.batch_size return (self.total_len-1) // self.batch_size + 1 + def __len__(self): + return self.__batch_len__() + def set_attrs(self, **kw): ''' You can set attributes of dataset by using set_attrs function, including total_len, batch_size, shuffle, drop_last, num_workers, buffer_size. @@ -300,6 +303,8 @@ Example:: self.terminate() def __iter__(self): + if self.total_len is None: + self.total_len = len(self) if self.shuffle == False: index_list = get_order_list(self.total_len) else: @@ -354,7 +359,7 @@ Example:: self.real_len = self.total_len self.real_batch_size = self.batch_size - self.batch_len = len(self) + self.batch_len = self.__batch_len__() if not hasattr(self, "workers") and self.num_workers: self._init_workers() diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 7360c9b2..f3769ea3 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -487,6 +487,9 @@ class ConvTranspose(Module): self.real_padding = (self.dilation[0] * (self.kernel_size[0] - 1) - self.padding[0], self.dilation[1] * (self.kernel_size[1] - 1) - self.padding[1]) self.output_padding = output_padding if isinstance (output_padding, tuple) else (output_padding, output_padding) + assert self.output_padding[0] < max(self.stride[0], self.dilation[0]) and \ + self.output_padding[1] < max(self.stride[1], self.dilation[1]), \ + "output padding must be smaller than max(stride, dilation)" self.weight = init.invariant_uniform((in_channels, out_channels) + self.kernel_size, dtype="float") if bias: diff --git a/src/update_queue.cc b/src/update_queue.cc index bb3d729f..f6319560 100644 --- a/src/update_queue.cc +++ b/src/update_queue.cc @@ -128,10 +128,12 @@ void UpdateQueue::push(Var* v, Var* prev) { queue.emplace_front(); owner = queue.begin(); } - if (owner->size() >= update_queue_auto_flush_delay) - auto_flush(); + auto need_auto_flush = owner->size() >= update_queue_auto_flush_delay; owner->emplace_front(UpdateQueue::Item{owner, v}); map[v] = owner->begin(); + if (need_auto_flush) { + auto_flush(); + } // if total size of update queue is too big, // force sync all if (map.size() > 100000)