polish zero shape

Dun Liang 2021-01-31 21:32:25 +08:00
parent 84967c21c4
commit 71bff3118e
7 changed files with 17 additions and 5 deletions

View File

@@ -8,7 +8,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.2.24'
+__version__ = '1.2.2.25'
 from . import lock
 with lock.lock_scope():
     ori_int = int

View File

@@ -1009,6 +1009,8 @@ def randperm(n):
     return jt.array(np.random.permutation(idx))
 
 def set_global_seed(seed):
+    import random
+    random.seed(seed)
     jt.set_seed(seed)
     np.random.seed(seed)
     try:
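The hunk above makes set_global_seed also seed Python's built-in random module, so one call seeds every RNG source the framework touches. Purely as an illustration of that idea (not Jittor code; the names global_engine and set_global_seed here are placeholders), a minimal C++ sketch of fanning one seed out to several engines:

#include <cstdlib>
#include <random>

// Hypothetical process-wide engine standing in for a framework RNG.
static std::mt19937_64 global_engine;

// Seed every random source from a single value so runs are reproducible.
void set_global_seed(unsigned long long seed) {
    std::srand(static_cast<unsigned>(seed)); // C library rand()
    global_engine.seed(seed);                // C++ <random> engine
    // A real framework would also forward the seed to its device RNGs here.
}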

View File

@@ -16,6 +16,7 @@ CudaDeviceAllocator cuda_device_allocator;
 const char* CudaDeviceAllocator::name() const {return "cuda_device";}
 
 void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) {
+    if (size==0) return (void*)0x10;
     void* ptr;
     try {
         checkCudaErrors(cudaMalloc(&ptr, size));
@@ -32,6 +33,7 @@ void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) {
 }
 
 void CudaDeviceAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    if (size==0) return;
     checkCudaErrors(cudaFree(mem_ptr));
 }

View File

@@ -16,12 +16,14 @@ CudaHostAllocator cuda_host_allocator;
 const char* CudaHostAllocator::name() const {return "cuda_host";}
 
 void* CudaHostAllocator::alloc(size_t size, size_t& allocation) {
+    if (size==0) return (void*)0x10;
     void* ptr;
     checkCudaErrors(cudaMallocHost(&ptr, size));
     return ptr;
 }
 
 void CudaHostAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    if (size==0) return;
     checkCudaErrors(cudaFreeHost(mem_ptr));
 }

View File

@@ -17,12 +17,14 @@ DEFINE_FLAG(int, use_cuda_managed_allocator, 1, "Enable cuda_managed_allocator")
 const char* CudaManagedAllocator::name() const {return "cuda_managed";}
 
 void* CudaManagedAllocator::alloc(size_t size, size_t& allocation) {
+    if (size==0) return (void*)0x10;
     void* ptr;
     checkCudaErrors(cudaMallocManaged(&ptr, size));
     return ptr;
 }
 
 void CudaManagedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    if (size==0) return;
     checkCudaErrors(cudaFree(mem_ptr));
 }
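All three CUDA allocators now short-circuit zero-byte requests: alloc returns a non-null sentinel pointer (0x10) without calling the CUDA API, and free ignores zero-size frees, so zero-shape tensors never trigger a cudaMalloc(0) or a cudaFree on a pointer CUDA never produced. A standalone sketch of the pattern under those assumptions (the names device_alloc/device_free and kZeroSizeSentinel are illustrative, not the Jittor classes):

#include <cuda_runtime.h>

// Sentinel handed out for zero-byte requests; never dereferenced, never freed via CUDA.
static void* const kZeroSizeSentinel = (void*)0x10;

void* device_alloc(size_t size) {
    if (size == 0) return kZeroSizeSentinel;   // skip cudaMalloc(0)
    void* ptr = nullptr;
    if (cudaMalloc(&ptr, size) != cudaSuccess) return nullptr;
    return ptr;
}

void device_free(void* ptr, size_t size) {
    if (size == 0) return;                     // sentinel: nothing to release
    cudaFree(ptr);
}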

View File

@@ -37,15 +37,15 @@ void ReshapeOp::infer_shape() {
         } else
             y_items *= shape[i];
     }
-    ASSERT(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
+    CHECK(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
     int64_t x_items = x->num;
     auto yshape = shape;
     if (x_items < 0) {
         // pass if input is uncertain
     } else if (uncertain_dim == 0) {
-        ASSERTop(x_items,==,y_items) << "reshape shape is invalid for input of size";
+        CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size";
     } else {
-        ASSERT(x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
+        CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
         uncertain_dim = x_items / y_items;
         yshape.clear();
         for (auto a : shape)
View File

@@ -85,6 +85,7 @@ void LoopVarAnalyzePass::run() {
     // ugly fix multi different dim element input
     // (caused by force fused array op)
     int max_elm_dim = 0;
+    int64 max_elm_size = 0;
     for (uint i=0; i<vars.size(); i++) {
         // output
         if (vars[i].type == 2) {
@@ -98,6 +99,8 @@ void LoopVarAnalyzePass::run() {
             if (op->type() == OpType::element) {
                 has_element = true;
                 max_elm_dim = std::max(max_elm_dim, op->outputs().front()->shape.size());
+                if (max_elm_dim == op->outputs().front()->shape.size())
+                    max_elm_size = std::max(max_elm_size, std::abs(op->outputs().front()->num));
             }
         }
     }
@@ -116,7 +119,8 @@ void LoopVarAnalyzePass::run() {
         if (has_element && !has_reduce && op->type() != OpType::element)
             continue;
         if (op->type() == OpType::element
-            && op->outputs().front()->shape.size() != max_elm_dim)
+            && (op->outputs().front()->shape.size() != max_elm_dim ||
+                std::abs(op->outputs().front()->num) != max_elm_size))
            continue;
         Var* loop_var;
         if (op->type() == OpType::broadcast || op->name_ex() == "index") {
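With zero-shape vars in the picture, element ops inside a fused group can disagree not only on rank but also on element count, so the pass now records both the maximal rank (max_elm_dim) and the element count seen at that rank (max_elm_size), and skips element ops that match neither when choosing the loop variable. A simplified standalone sketch of that selection rule (FakeVar, FakeOp, and pick_candidates are illustrative stand-ins, not the pass's real types):

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Simplified stand-ins for the pass's inputs.
struct FakeVar { std::vector<int64_t> shape; int64_t num; };
struct FakeOp  { bool is_element; FakeVar* out; };

// Keep only element ops whose output matches both the largest rank
// and the largest element count seen at that rank.
std::vector<FakeOp*> pick_candidates(const std::vector<FakeOp*>& ops) {
    size_t max_dim = 0;
    int64_t max_size = 0;
    for (auto* op : ops) {
        if (!op->is_element) continue;
        size_t d = op->out->shape.size();
        if (d > max_dim) { max_dim = d; max_size = 0; }
        if (d == max_dim)
            max_size = std::max(max_size, std::abs(op->out->num));
    }
    std::vector<FakeOp*> picked;
    for (auto* op : ops) {
        if (!op->is_element) continue;
        if (op->out->shape.size() != max_dim) continue;
        if (std::abs(op->out->num) != max_size) continue;
        picked.push_back(op);
    }
    return picked;
}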