mirror of https://github.com/Jittor/Jittor

polish zero shape

This commit is contained in:
parent 84967c21c4
commit 71bff3118e
@@ -8,7 +8,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.2.24'
+__version__ = '1.2.2.25'
 from . import lock
 with lock.lock_scope():
     ori_int = int
@@ -1009,6 +1009,8 @@ def randperm(n):
     return jt.array(np.random.permutation(idx))
 
 def set_global_seed(seed):
     import random
     random.seed(seed)
     jt.set_seed(seed)
     np.random.seed(seed)
     try:
@@ -16,6 +16,7 @@ CudaDeviceAllocator cuda_device_allocator;
 const char* CudaDeviceAllocator::name() const {return "cuda_device";}
 
 void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) {
+    if (size==0) return (void*)0x10;
     void* ptr;
     try {
         checkCudaErrors(cudaMalloc(&ptr, size));
@@ -32,6 +33,7 @@ void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) {
 }
 
 void CudaDeviceAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    if (size==0) return;
     checkCudaErrors(cudaFree(mem_ptr));
 }
 
@@ -16,12 +16,14 @@ CudaHostAllocator cuda_host_allocator;
 const char* CudaHostAllocator::name() const {return "cuda_host";}
 
 void* CudaHostAllocator::alloc(size_t size, size_t& allocation) {
+    if (size==0) return (void*)0x10;
     void* ptr;
     checkCudaErrors(cudaMallocHost(&ptr, size));
     return ptr;
 }
 
 void CudaHostAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    if (size==0) return;
     checkCudaErrors(cudaFreeHost(mem_ptr));
 }
 
@@ -17,12 +17,14 @@ DEFINE_FLAG(int, use_cuda_managed_allocator, 1, "Enable cuda_managed_allocator")
 const char* CudaManagedAllocator::name() const {return "cuda_managed";}
 
 void* CudaManagedAllocator::alloc(size_t size, size_t& allocation) {
+    if (size==0) return (void*)0x10;
     void* ptr;
     checkCudaErrors(cudaMallocManaged(&ptr, size));
     return ptr;
 }
 
 void CudaManagedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
+    if (size==0) return;
     checkCudaErrors(cudaFree(mem_ptr));
 }
 
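The three allocator hunks above apply the same zero-shape treatment: a zero-byte request short-circuits to a fixed non-null sentinel pointer (0x10) without touching the CUDA runtime, and the matching free skips the runtime call when the recorded size is zero. Below is a minimal host-side sketch of that pattern, with plain malloc/free standing in for the CUDA calls; the DummyAllocator name is illustrative and not part of Jittor.

```cpp
#include <cstddef>
#include <cstdio>
#include <cstdlib>

// Sketch of the zero-size guard from the allocator diffs above: a zero-byte
// request returns a fixed non-null sentinel instead of calling the backend,
// and free() skips the release when the recorded size is zero.
struct DummyAllocator {
    void* alloc(size_t size) {
        if (size == 0) return (void*)0x10;  // sentinel for zero-shape tensors
        return std::malloc(size);           // stand-in for cudaMalloc / cudaMallocHost / cudaMallocManaged
    }
    void free(void* mem_ptr, size_t size) {
        if (size == 0) return;              // the sentinel was never really allocated
        std::free(mem_ptr);                 // stand-in for cudaFree / cudaFreeHost
    }
};

int main() {
    DummyAllocator a;
    void* empty = a.alloc(0);     // zero-shape tensor: no backend call is made
    void* buf   = a.alloc(1024);  // ordinary allocation path
    std::printf("empty=%p buf=%p\n", empty, buf);
    a.free(empty, 0);
    a.free(buf, 1024);
    return 0;
}
```

Returning a distinct non-null pointer keeps a zero-shape tensor distinguishable from a failed allocation while avoiding a round trip into the CUDA runtime for memory that will never be read or written.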
@@ -37,15 +37,15 @@ void ReshapeOp::infer_shape() {
         } else
             y_items *= shape[i];
     }
-    ASSERT(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
+    CHECK(uncertain_dim <= 1) << "max number of -1 is 1, but get" << uncertain_dim << ".";
     int64_t x_items = x->num;
     auto yshape = shape;
     if (x_items < 0) {
         // pass if input is uncertain
     } else if (uncertain_dim == 0) {
-        ASSERTop(x_items,==,y_items) << "reshape shape is invalid for input of size";
+        CHECKop(x_items,==,y_items) << "reshape shape is invalid for input of size";
     } else {
-        ASSERT(x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
+        CHECK(y_items != 0 && x_items % y_items == 0) << "reshape shape is invalid for input of size " << x_items;
         uncertain_dim = x_items / y_items;
         yshape.clear();
         for (auto a : shape)
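The ReshapeOp hunk replaces hard ASSERTs with CHECKs and, more importantly for zero shapes, refuses to infer a -1 dimension when the known dimensions multiply to zero: without the y_items != 0 guard, the x_items / y_items computation that follows would divide by zero for an empty input. Below is a small sketch of that inference step, with std::vector and exceptions standing in for Jittor's shape type and CHECK macros.

```cpp
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Minimal sketch of inferring a single -1 dimension in a reshape, mirroring
// the guarded logic in the hunk above.
std::vector<int64_t> infer_reshape(int64_t x_items, std::vector<int64_t> shape) {
    int uncertain_dim = 0;
    int64_t y_items = 1;
    for (auto s : shape) {
        if (s < 0) uncertain_dim++;
        else y_items *= s;
    }
    if (uncertain_dim > 1)
        throw std::runtime_error("max number of -1 is 1");
    if (uncertain_dim == 0) {
        if (x_items != y_items)
            throw std::runtime_error("reshape shape is invalid for input size");
        return shape;
    }
    // The y_items != 0 guard is the zero-shape fix: without it, an empty input
    // combined with a -1 dimension would divide by zero just below.
    if (y_items == 0 || x_items % y_items != 0)
        throw std::runtime_error("reshape shape is invalid for input size");
    int64_t inferred = x_items / y_items;
    for (auto& s : shape)
        if (s < 0) s = inferred;
    return shape;
}

int main() {
    auto y = infer_reshape(12, {3, -1});  // the -1 is inferred as 4
    std::printf("inferred shape: %lld x %lld\n", (long long)y[0], (long long)y[1]);
    return 0;
}
```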
@@ -85,6 +85,7 @@ void LoopVarAnalyzePass::run() {
     // ugly fix multi different dim element input
     // (caused by force fused array op)
     int max_elm_dim = 0;
+    int64 max_elm_size = 0;
     for (uint i=0; i<vars.size(); i++) {
         // output
         if (vars[i].type == 2) {
@@ -98,6 +99,8 @@ void LoopVarAnalyzePass::run() {
             if (op->type() == OpType::element) {
                 has_element = true;
                 max_elm_dim = std::max(max_elm_dim, op->outputs().front()->shape.size());
+                if (max_elm_dim == op->outputs().front()->shape.size())
+                    max_elm_size = std::max(max_elm_size, std::abs(op->outputs().front()->num));
             }
         }
     }
@@ -116,7 +119,8 @@ void LoopVarAnalyzePass::run() {
         if (has_element && !has_reduce && op->type() != OpType::element)
             continue;
         if (op->type() == OpType::element
-            && op->outputs().front()->shape.size() != max_elm_dim)
+            && (op->outputs().front()->shape.size() != max_elm_dim ||
+                std::abs(op->outputs().front()->num) != max_elm_size))
             continue;
         Var* loop_var;
         if (op->type() == OpType::broadcast || op->name_ex() == "index") {
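The LoopVarAnalyzePass hunks track, alongside the maximum output dimension of the element ops, the maximum absolute element count seen at that dimension, and the loop variable is then taken only from an element op whose output matches both. This keeps a zero-sized output of the same rank from being chosen to drive the fused loop. Below is a hedged sketch of that selection rule, with a plain ElmOut struct standing in for the op-output vars of the real pass.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Hedged sketch of the selection rule from the LoopVarAnalyzePass hunks above:
// the loop variable may only come from an element-op output that matches both
// the maximum dimension and the maximum absolute element count.
struct ElmOut { int dim; int64_t num; };

int pick_loop_var(const std::vector<ElmOut>& outs) {
    int max_elm_dim = 0;
    int64_t max_elm_size = 0;
    for (const auto& o : outs) {
        max_elm_dim = std::max(max_elm_dim, o.dim);
        if (o.dim == max_elm_dim)
            max_elm_size = std::max(max_elm_size, std::abs(o.num));
    }
    for (int i = 0; i < (int)outs.size(); i++) {
        // Skip lower-dimensional or smaller candidates (e.g. a zero-shape output
        // of the same rank), mirroring the added "!= max_elm_size" condition.
        if (outs[i].dim != max_elm_dim || std::abs(outs[i].num) != max_elm_size)
            continue;
        return i;  // first output matching both criteria drives the fused loop
    }
    return -1;  // no element op qualifies
}

int main() {
    // A zero-size 2-D output, a 1-D output, and a full 2-D output.
    std::vector<ElmOut> outs = { {2, 0}, {1, 8}, {2, 24} };
    std::printf("chosen index: %d\n", pick_loop_var(outs));  // prints 2
    return 0;
}
```

std::abs is applied to num here as in the pass itself, presumably because, as the "pass if input is uncertain" branch in the reshape hunk suggests, the element count can be negative while shapes are not yet fully known.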