mirror of https://github.com/Jittor/Jittor
Merge branch 'develop' into lxl
This commit is contained in:
commit
5dfef0af13
|
@ -25,3 +25,4 @@ python/jittor.egg-info
|
|||
dist/
|
||||
!doc/source/*
|
||||
core
|
||||
__data__
|
||||
|
|
|
@ -223,7 +223,7 @@ sudo apt install python3.7 python3.7-dev
|
|||
|
||||
The whole framework is compiled Just-in-time. Let's install jittor via pip
|
||||
|
||||
整个框架是及时编译的。 让我们通过pip安装jittor
|
||||
整个框架是即时编译的。 让我们通过pip安装jittor
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Jittor/jittor.git
|
||||
|
|
|
@ -56,17 +56,6 @@ struct NonZeroOp
|
|||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
struct ConvertOp
|
||||
{
|
||||
const T div;
|
||||
const T dim_size;
|
||||
ConvertOp(T _div,T dim_size): div(_div),dim_size(dim_size){}
|
||||
__host__ __device__ __forceinline__ T operator()(const T& val) const {
|
||||
return (val/div) % dim_size;
|
||||
}
|
||||
};
|
||||
|
||||
__global__ static void where_kernel(
|
||||
int n,
|
||||
To* input
|
||||
|
@ -90,30 +79,25 @@ void CubWhereOp::jit_run(){
|
|||
int N = cond->num;
|
||||
size_t temp_storage_bytes=0;
|
||||
size_t num_nonzeros_allocation;
|
||||
auto num_nonzeros = exe.allocator->alloc(sizeof(int), num_nonzeros_allocation);
|
||||
cub::TransformInputIterator<bool, NonZeroOp<Ti>, Ti*> itr(cond->ptr<Ti>(), NonZeroOp<Ti>());
|
||||
cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int *)num_nonzeros, N);
|
||||
|
||||
auto num_nonzeros = exe.allocator->alloc(sizeof(To), num_nonzeros_allocation);
|
||||
|
||||
size_t temp_storage_allocation;
|
||||
auto temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
|
||||
|
||||
cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, itr, (int *)num_nonzeros, N);
|
||||
exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);
|
||||
|
||||
int num_nonzeros_h;
|
||||
checkCudaErrors(cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(int), cudaMemcpyDeviceToHost));
|
||||
void* temp_storage;
|
||||
|
||||
To* out_temp = outs[0]->ptr<To>();
|
||||
|
||||
@for(i, 0, NDIM, outs[@i]->set_shape({num_nonzeros_h});)
|
||||
|
||||
cub::CountingInputIterator<To> counting_itr(0);
|
||||
cub::TransformInputIterator<bool, NonZeroOp<Ti>, Ti*> itr(cond->ptr<Ti>(), NonZeroOp<Ti>());
|
||||
temp_storage_bytes = 0;
|
||||
cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
|
||||
checkCudaErrors(cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr, out_temp, (To*)num_nonzeros, N));
|
||||
temp_storage = exe.allocator->alloc(temp_storage_bytes, temp_storage_allocation);
|
||||
cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (int*)num_nonzeros, N);
|
||||
checkCudaErrors(cub::DeviceSelect::Flagged(temp_storage, temp_storage_bytes, counting_itr, itr,out_temp, (To*)num_nonzeros, N));
|
||||
exe.allocator->free(temp_storage, temp_storage_bytes, temp_storage_allocation);
|
||||
|
||||
To num_nonzeros_h;
|
||||
cudaMemcpy(&num_nonzeros_h, num_nonzeros, sizeof(To), cudaMemcpyDeviceToHost);
|
||||
@for(i, 0, NDIM, outs[@i]->set_shape({num_nonzeros_h});)
|
||||
|
||||
if (num_nonzeros_h > 0 && NDIM > 1) {
|
||||
int thread_num = std::min(1024, num_nonzeros_h);
|
||||
int block_num = std::max(1, num_nonzeros_h/1024);
|
||||
|
|
|
@ -15,8 +15,12 @@ namespace jittor {
|
|||
|
||||
extern cudnnHandle_t cudnn_handle;
|
||||
extern int max_cache_size;
|
||||
extern float max_workspace_ratio;
|
||||
|
||||
// @pyjt(set_algorithm_cache_size)
|
||||
void set_algorithm_cache_size(int size);
|
||||
|
||||
// @pyjt(set_max_workspace_ratio)
|
||||
void set_max_workspace_ratio(float64 ratio);
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -198,7 +198,7 @@ void CudnnConvBackwardWOp::jit_run() {
|
|||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
|
|
|
@ -199,7 +199,7 @@ void CudnnConvBackwardXOp::jit_run() {
|
|||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
|
|
|
@ -203,7 +203,7 @@ void CudnnConvOp::jit_run() {
|
|||
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algos[i], &sz);
|
||||
// continue if use too much workspace
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (sz > mem_info.total_cuda_ram * max_workspace_ratio) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
|
|
|
@ -10,11 +10,16 @@ namespace jittor {
|
|||
|
||||
cudnnHandle_t cudnn_handle;
|
||||
int max_cache_size = 100;
|
||||
float max_workspace_ratio = 0.25;
|
||||
|
||||
void set_algorithm_cache_size(int size) {
|
||||
max_cache_size = size;
|
||||
}
|
||||
|
||||
void set_max_workspace_ratio(float64 ratio) {
|
||||
max_workspace_ratio = ratio;
|
||||
}
|
||||
|
||||
struct cudnn_initer {
|
||||
|
||||
inline cudnn_initer() {
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.1.2'
|
||||
__version__ = '1.2.1.3'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -864,8 +864,6 @@ def size(v, dim=None):
|
|||
return v.shape[dim]
|
||||
Var.size = size
|
||||
|
||||
def item(v):
|
||||
return v.data.item()
|
||||
|
||||
def to_int(v):
|
||||
dtype = str(v.dtype)
|
||||
|
@ -882,7 +880,6 @@ def to_bool(v):
|
|||
assert dtype.startswith("int") or dtype=="bool"
|
||||
return ori_bool(v.item())
|
||||
|
||||
Var.item = item
|
||||
Var.__int__ = to_int
|
||||
Var.__float__ = to_float
|
||||
Var.__bool__ = to_bool
|
||||
|
|
|
@ -78,20 +78,21 @@ def setup_mkl():
|
|||
|
||||
|
||||
def install_cub(root_folder):
|
||||
url = "https://github.com/NVlabs/cub/archive/v1.8.0.tar.gz"
|
||||
filename = "cub-1.8.0.tgz"
|
||||
url = "https://github.com/NVIDIA/cub/archive/1.11.0-rc1.tar.gz"
|
||||
filename = "cub-1.11.0-rc1.tgz"
|
||||
md5 = "f395687060bed7eaeb5fa8a689276ede"
|
||||
fullname = os.path.join(root_folder, filename)
|
||||
dirname = os.path.join(root_folder, filename.replace(".tgz",""))
|
||||
|
||||
if not os.path.isfile(os.path.join(dirname, "examples", "test")):
|
||||
LOG.i("Downloading cub...")
|
||||
download_url_to_local(url, filename, root_folder, "9203ea2499b56782601fddf8a12e9b08")
|
||||
download_url_to_local(url, filename, root_folder, md5)
|
||||
import tarfile
|
||||
|
||||
with tarfile.open(fullname, "r") as tar:
|
||||
tar.extractall(root_folder)
|
||||
assert 0 == os.system(f"cd {dirname}/examples && "
|
||||
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -o test")
|
||||
f"{nvcc_path} device/example_device_radix_sort.cu -O2 -I.. -std=c++14 -o test")
|
||||
if core.get_device_count():
|
||||
assert 0 == os.system(f"cd {dirname}/examples && ./test")
|
||||
return dirname
|
||||
|
|
|
@ -896,8 +896,8 @@ make_cache_dir(os.path.join(cache_path, "obj_files"))
|
|||
make_cache_dir(os.path.join(cache_path, "gen"))
|
||||
|
||||
# build cache_compile
|
||||
cc_flags += pybind_include
|
||||
cc_flags += f" -I{jittor_path}/src "
|
||||
cc_flags += pybind_include
|
||||
check_cache_compile()
|
||||
LOG.v(f"Get cache_compile: {jit_utils.cc}")
|
||||
|
||||
|
@ -981,10 +981,11 @@ assert libname is not None, "openmp library not found"
|
|||
ctypes.CDLL(libname, os.RTLD_NOW | os.RTLD_GLOBAL)
|
||||
|
||||
version_file = os.path.join(jittor_path, "version")
|
||||
if os.path.isfile(version_file):
|
||||
if os.path.isfile(version_file) and not os.path.isdir(os.path.join(jittor_path, "src", "__data__")):
|
||||
with open(version_file, 'r') as f:
|
||||
version = f.read().strip()
|
||||
key = f"{version}-{cc_type}-{'cuda' if has_cuda else 'cpu'}.o"
|
||||
# key = f"{version}-{cc_type}-{'cuda' if has_cuda else 'cpu'}.o"
|
||||
key = f"{version}-g++-cpu.o"
|
||||
# TODO: open the website
|
||||
extra_obj = os.path.join(cache_path, key)
|
||||
url = os.path.join("https://cg.cs.tsinghua.edu.cn/jittor/assets/build/"+key)
|
||||
|
|
|
@ -267,6 +267,9 @@ Example::
|
|||
LOG.i('\n'.join(msg))
|
||||
|
||||
def _stop_all_workers(self):
|
||||
# stop workers
|
||||
for w in self.workers:
|
||||
w.buffer.stop()
|
||||
# wait until all workers idle
|
||||
if self.num_idle.value < self.num_workers:
|
||||
with self.gid.get_lock():
|
||||
|
|
|
@ -143,7 +143,7 @@ class ResNet(nn.Module):
|
|||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.avgpool(x)
|
||||
x = jt.reshape(x, (x.shape[0], (- 1)))
|
||||
x = jt.reshape(x, (x.shape[0], -1))
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
|
|
|
@ -151,7 +151,7 @@ jt.Var.__imatmul__ = lambda a,b: a.assign(matmul(a,b))
|
|||
def get_init_var_rand(shape, dtype):
|
||||
return jt.array(np.random.normal(0.0, 1.0, shape).astype(np.float32))
|
||||
|
||||
def relu(x): return jt.maximum(x, 0)
|
||||
def relu(x): return jt.ternary((x>0.0), x, jt.broadcast_var(0.0, x))
|
||||
def leaky_relu(x, scale=0.01): return jt.ternary(x>0, x, x*scale)
|
||||
def relu6(x): return jt.minimum(jt.maximum(x, 0), 6)
|
||||
def sign(x):
|
||||
|
@ -352,80 +352,39 @@ class BatchNorm(Module):
|
|||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
self.affine = affine
|
||||
if affine:
|
||||
self.weight = init.constant((num_features,), "float32", 1.0)
|
||||
self.bias = init.constant((num_features,), "float32", 0.0)
|
||||
self.weight = init.constant((num_features,), "float32", 1.0) if affine else 1.0
|
||||
self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0
|
||||
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
|
||||
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
|
||||
|
||||
def execute(self, x):
|
||||
dims = [0]+list(range(2,x.ndim))
|
||||
if self.is_train:
|
||||
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
|
||||
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
|
||||
xmean = jt.mean(x, dims=dims)
|
||||
x2mean = jt.mean(x*x, dims=dims)
|
||||
if self.sync and jt.in_mpi:
|
||||
xmean = xmean.mpi_all_reduce("mean")
|
||||
x2mean = x2mean.mpi_all_reduce("mean")
|
||||
|
||||
xvar = x2mean-xmean*xmean
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
||||
w = self.weight / jt.sqrt(xvar+self.eps)
|
||||
b = self.bias - xmean * w
|
||||
norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
||||
|
||||
self.running_mean.update(self.running_mean +
|
||||
(xmean.reshape((-1,)) - self.running_mean) * self.momentum)
|
||||
self.running_var.update(self.running_var +
|
||||
(xvar.reshape((-1,))-self.running_var)*self.momentum)
|
||||
else:
|
||||
running_mean = self.running_mean.broadcast(x, [0,2,3])
|
||||
running_var = self.running_var.broadcast(x, [0,2,3])
|
||||
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
|
||||
if not self.affine:
|
||||
return norm_x
|
||||
w = self.weight.broadcast(x, [0,2,3])
|
||||
b = self.bias.broadcast(x, [0,2,3])
|
||||
return norm_x * w + b
|
||||
|
||||
class BatchNorm1d(Module):
|
||||
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, is_train=True, sync=True):
|
||||
self.sync = sync
|
||||
self.num_features = num_features
|
||||
self.is_train = is_train
|
||||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
self.affine = affine
|
||||
if affine:
|
||||
self.weight = init.constant((num_features,), "float32", 1.0)
|
||||
self.bias = init.constant((num_features,), "float32", 0.0)
|
||||
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
|
||||
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
|
||||
|
||||
def execute(self, x):
|
||||
if len(x.shape) == 3:
|
||||
dims = [0, 2]
|
||||
else:
|
||||
dims = [0]
|
||||
if self.is_train:
|
||||
xmean = jt.mean(x, dims=dims, keepdims=1)
|
||||
x2mean = jt.mean(x*x, dims=dims, keepdims=1)
|
||||
|
||||
if self.sync and jt.in_mpi:
|
||||
xmean = xmean.mpi_all_reduce("mean")
|
||||
x2mean = x2mean.mpi_all_reduce("mean")
|
||||
|
||||
xvar = x2mean-xmean*xmean
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
self.running_mean.update(self.running_mean +
|
||||
(xmean.sum(dims)-self.running_mean)*self.momentum)
|
||||
self.running_var.update(self.running_var +
|
||||
(xvar.sum(dims)-self.running_var)*self.momentum)
|
||||
else:
|
||||
running_mean = self.running_mean.broadcast(x, dims)
|
||||
running_var = self.running_var.broadcast(x, dims)
|
||||
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
|
||||
if not self.affine:
|
||||
w = self.weight / jt.sqrt(self.running_var+self.eps)
|
||||
b = self.bias - self.running_mean * w
|
||||
norm_x = x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
||||
return norm_x
|
||||
w = self.weight.broadcast(x, dims)
|
||||
b = self.bias.broadcast(x, dims)
|
||||
return norm_x * w + b
|
||||
|
||||
class InstanceNorm2d(Module):
|
||||
BatchNorm2d = BatchNorm1d = BatchNorm
|
||||
|
||||
class InstanceNorm(Module):
|
||||
def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, is_train=True, sync=True):
|
||||
self.sync = sync
|
||||
self.num_features = num_features
|
||||
|
@ -434,47 +393,42 @@ class InstanceNorm2d(Module):
|
|||
self.momentum = momentum
|
||||
|
||||
self.affine = affine
|
||||
if self.affine:
|
||||
self.weight = init.constant((num_features,), "float32", 1.0)
|
||||
self.bias = init.constant((num_features,), "float32", 0.0)
|
||||
self.weight = init.constant((num_features,), "float32", 1.0) if affine else 1.0
|
||||
self.bias = init.constant((num_features,), "float32", 0.0) if affine else 0.0
|
||||
|
||||
def execute(self, x):
|
||||
xmean = jt.mean(x, dims=[2,3], keepdims=1)
|
||||
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
|
||||
if self.sync and jt.in_mpi:
|
||||
xmean = xmean.mpi_all_reduce("mean")
|
||||
x2mean = x2mean.mpi_all_reduce("mean")
|
||||
dims = list(range(2,x.ndim))
|
||||
xmean = jt.mean(x, dims=dims)
|
||||
x2mean = jt.mean(x*x, dims=dims)
|
||||
|
||||
xvar = jt.maximum(x2mean-xmean*xmean, 0)
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
if not self.affine:
|
||||
return norm_x
|
||||
w = self.weight.broadcast(x, [0,2,3])
|
||||
b = self.bias.broadcast(x, [0,2,3])
|
||||
return norm_x * w + b
|
||||
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
||||
w = self.weight / jt.sqrt(xvar+self.eps)
|
||||
b = self.bias - xmean * w
|
||||
return x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
||||
|
||||
InstanceNorm2d = InstanceNorm1d = InstanceNorm
|
||||
|
||||
class LayerNorm(Module):
|
||||
def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
|
||||
super(LayerNorm, self).__init__()
|
||||
if isinstance(normalized_shape, int):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = tuple(normalized_shape)
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = init.constant(normalized_shape, "float32", 1.0)
|
||||
self.bias = init.constant(normalized_shape, "float32", 0.0)
|
||||
self.weight = init.constant(normalized_shape, "float32", 1.0) if elementwise_affine else 1.0
|
||||
self.bias = init.constant(normalized_shape, "float32", 0.0) if elementwise_affine else 0.0
|
||||
|
||||
def execute(self,x):
|
||||
def execute(self, x):
|
||||
dims = [-i for i in range(len(self.normalized_shape), 0, -1)]
|
||||
mean = jt.mean(x,dims=dims,keepdims=1)
|
||||
numerator = x-mean
|
||||
variance = jt.mean(numerator.sqr(),dims=dims,keepdims=1)
|
||||
denominator = jt.sqrt(variance+self.eps)
|
||||
norm_x = numerator/denominator
|
||||
if self.elementwise_affine:
|
||||
norm_x = norm_x * self.weight+self.bias
|
||||
return norm_x
|
||||
xmean = jt.mean(x, dims=dims)
|
||||
x2mean = jt.mean(x*x, dims=dims)
|
||||
|
||||
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
||||
w = self.weight / jt.sqrt(xvar+self.eps)
|
||||
b = self.bias - xmean * w
|
||||
return x * w.broadcast(x, dims) + b.broadcast(x, dims)
|
||||
|
||||
LayerNorm2d = LayerNorm1d = LayerNorm
|
||||
|
||||
class GroupNorm(Module):
|
||||
def __init__(self, num_groups, num_channels, eps=1e-05, affine=True, is_train=True):
|
||||
|
@ -483,28 +437,32 @@ class GroupNorm(Module):
|
|||
self.eps = eps
|
||||
|
||||
self.affine = affine
|
||||
if self.affine:
|
||||
self.weight = init.constant((num_channels,), "float32", 1.0)
|
||||
self.bias = init.constant((num_channels,), "float32", 0.0)
|
||||
self.weight = init.constant((num_channels,), "float32", 1.0) if affine else 1.0
|
||||
self.bias = init.constant((num_channels,), "float32", 0.0) if affine else 0.0
|
||||
|
||||
def execute(self, x):
|
||||
N = x.shape[0]
|
||||
C = self.num_channels
|
||||
output_shape = (N,-1)
|
||||
# TODO: 3d group norm
|
||||
# TODO: 3d group norm
|
||||
if x.ndim==4:
|
||||
output_shape = x.shape
|
||||
assert C % self.num_groups == 0
|
||||
x = x.reshape((N, self.num_groups, int(C/self.num_groups), -1))
|
||||
xmean = jt.mean(x, dims=[2,3], keepdims=1)
|
||||
x2mean = jt.mean(x*x, dims=[2,3], keepdims=1)
|
||||
xvar = jt.maximum(x2mean-xmean*xmean, 0)
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
if not self.affine:
|
||||
return norm_x.reshape(output_shape)
|
||||
w = self.weight.reshape((1,self.num_groups,C//self.num_groups,1))
|
||||
b = self.bias.reshape((1,self.num_groups,C//self.num_groups,1))
|
||||
return (norm_x * w + b).reshape(output_shape)
|
||||
x = x.reshape((N, self.num_groups, C//self.num_groups, -1))
|
||||
xmean = jt.mean(x, dims=[2,3]).reshape((N, self.num_groups, 1))
|
||||
x2mean = jt.mean(x*x, dims=[2,3]).reshape((N, self.num_groups, 1))
|
||||
xvar = (x2mean-xmean*xmean).maximum(0.0)
|
||||
|
||||
if self.affine:
|
||||
w = self.weight.reshape((1, self.num_groups, -1))
|
||||
b = self.bias.reshape((1, self.num_groups, -1))
|
||||
else:
|
||||
w = 1
|
||||
b = 0
|
||||
w = w / jt.sqrt(xvar+self.eps)
|
||||
b = b - xmean * w
|
||||
x = x * w.broadcast(x, [3]) + b.broadcast(x, [3])
|
||||
return x.reshape(output_shape)
|
||||
|
||||
Relu = jt.make_module(relu)
|
||||
ReLU = Relu
|
||||
|
|
|
@ -40,6 +40,7 @@ class Pool(Module):
|
|||
count = f"int count = {self.kernel_size*self.kernel_size};"
|
||||
else:
|
||||
count = "int count = (k2_ - k2) * (k3_ - k3);"
|
||||
count += "float32 rcount = 1.0f / count;"
|
||||
else:
|
||||
count = ""
|
||||
forward_body = f'''{{
|
||||
|
@ -168,7 +169,9 @@ class AdaptiveAvgPool2d(Module):
|
|||
oh = x.shape[2] if self.output_size[0] is None else self.output_size[0]
|
||||
ow = x.shape[3] if self.output_size[1] is None else self.output_size[1]
|
||||
else:
|
||||
raise TypeError(f"AdaptiveAvgPool2d only support int, typle or list input. Not support {type(self.output_size)} yet.")
|
||||
raise TypeError(f"AdaptiveAvgPool2d only support int, tuple or list input. Not support {type(self.output_size)} yet.")
|
||||
if oh == 1 and ow == 1:
|
||||
return x.reduce("mean", [2,3], keepdims=True)
|
||||
N,C,H,W = x.shape
|
||||
self.sh = math.floor(H / oh)
|
||||
self.sw = math.floor(W / ow)
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as numpy
|
||||
|
||||
class TestMergeLoopVarPass(unittest.TestCase):
|
||||
def test(self):
|
||||
a = jt.ones([10,10,10,10])
|
||||
a.sync()
|
||||
with jt.profile_scope() as rep:
|
||||
b = a.sum([2,3])
|
||||
b.sync()
|
||||
with open(rep[1][1]) as f:
|
||||
src = f.read()
|
||||
assert "range01" in src
|
||||
assert "range23" in src
|
||||
|
||||
def test2(self):
|
||||
a = jt.ones([10,10,10,10])
|
||||
a.sync()
|
||||
with jt.profile_scope() as rep:
|
||||
b = a + 1
|
||||
b.sync()
|
||||
with open(rep[1][1]) as f:
|
||||
src = f.read()
|
||||
assert "range0123" in src
|
||||
|
||||
def test3(self):
|
||||
a = jt.ones([10,10,10,10])
|
||||
x = jt.ones([1,10,1,1])
|
||||
a.sync(), x.sync()
|
||||
with jt.profile_scope() as rep:
|
||||
b = a + x
|
||||
b.sync()
|
||||
with open(rep[1][1]) as f:
|
||||
src = f.read()
|
||||
assert "range23" in src
|
||||
|
||||
def test4(self):
|
||||
# don't optimize reindex like op yet
|
||||
a = jt.ones([10,10,10,10])
|
||||
a.sync()
|
||||
with jt.profile_scope() as rep:
|
||||
b = a.reindex_reduce("add", [10,10], ["i0","i1"])
|
||||
b.sync()
|
||||
with open(rep[1][1]) as f:
|
||||
src = f.read()
|
||||
assert "range23" not in src
|
||||
|
||||
def test5(self):
|
||||
a = jt.ones([10,10,10,10])
|
||||
a.sync()
|
||||
with jt.profile_scope() as rep:
|
||||
b = a.sum([1])
|
||||
b.sync()
|
||||
with open(rep[1][1]) as f:
|
||||
src = f.read()
|
||||
assert "range01" not in src
|
||||
assert "range23" in src
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
class TestMergeLoopVarPassCuda(TestMergeLoopVarPass):
|
||||
def setUp(self):
|
||||
jt.flags.use_cuda = 1
|
||||
def tearDown(self):
|
||||
jt.flags.use_cuda = 0
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -120,7 +120,7 @@ class TestOpCompiler(unittest.TestCase):
|
|||
OP1
|
||||
1+2
|
||||
std::max(T(1), T(2))
|
||||
((1)+T(2)*(T(1)/T(count)))''')
|
||||
((1)+T(2)*(T(rcount)))''')
|
||||
expect_error(lambda: jit_precompile(vars, "@{a"))
|
||||
expect_error(lambda: jit_precompile(vars, "@for(a"))
|
||||
expect_error(lambda: jit_precompile(vars, "@for(i,l,r)"))
|
||||
|
|
|
@ -97,7 +97,7 @@ class TestParallelPass3(unittest.TestCase):
|
|||
def check(ndim, depth, tdim):
|
||||
a = jt.random([16]*ndim)
|
||||
a.sync()
|
||||
compile_options = {"parallel":1}
|
||||
compile_options = {"parallel":1, "merge_loop_var": self.merge_loop_var}
|
||||
if depth is not None:
|
||||
compile_options["max_parallel_depth"] = depth
|
||||
with jt.profile_scope(compile_options=compile_options) as rep:
|
||||
|
@ -110,6 +110,7 @@ class TestParallelPass3(unittest.TestCase):
|
|||
for i in range(tdim):
|
||||
assert f"tnum{i}" in src
|
||||
assert f"tnum{tdim}" not in src
|
||||
self.merge_loop_var = 0
|
||||
check(1, None, 0)
|
||||
check(2, None, 1)
|
||||
check(3, None, 2)
|
||||
|
@ -134,7 +135,7 @@ class TestParallelPass3(unittest.TestCase):
|
|||
a = jt.random(shape)
|
||||
a.sync()
|
||||
config = {
|
||||
"parallel":1, "max_parallel_depth":depth
|
||||
"parallel":1, "max_parallel_depth":depth, "merge_loop_var": self.merge_loop_var
|
||||
}
|
||||
for k in args:
|
||||
config[k] = args[k]
|
||||
|
@ -164,6 +165,7 @@ class TestParallelPass3(unittest.TestCase):
|
|||
assert np.allclose(a.data.sum(rdim), b), (b.sum(), a.data.sum())
|
||||
|
||||
def test_reduce(self):
|
||||
self.merge_loop_var = 0
|
||||
check = lambda *a, **kw: self.reduce_check(*a, **kw)
|
||||
check(1, 2, 1, 0, 1)
|
||||
check(2, 1, 1, 1, 0)
|
||||
|
@ -185,6 +187,29 @@ class TestParallelPass3(unittest.TestCase):
|
|||
check(4, 2, 2, [2,3], 0)
|
||||
check(4, 2, 2, [0,3], 1)
|
||||
|
||||
def test_reduce_with_merge_loop_var(self):
|
||||
self.merge_loop_var = 1
|
||||
check = lambda *a, **kw: self.reduce_check(*a, **kw)
|
||||
check(1, 2, 1, 0, 1)
|
||||
check(2, 1, 1, 1, 0)
|
||||
check(2, 1, 1, 0, 1)
|
||||
check(2, 1, 1, 0, 1, [0,0])
|
||||
check(2, 1, 1, 0, 0, [0,1])
|
||||
check(2, 1, 1, 0, 0, [0,1], [0,64])
|
||||
check(2, 1, 1, [0,1], 1, [0,1])
|
||||
check(3, 1, 1, [1,2], 0)
|
||||
check(3, 1, 1, [0,1], 1)
|
||||
check(3, 1, 1, [0,1], 0, [0,0,2])
|
||||
check(3, 2, 1, [2], 0)
|
||||
if jt.flags.use_cuda:
|
||||
# loop is not merged so parallel depth 2
|
||||
check(3, 2, 2, [1], 1)
|
||||
else:
|
||||
check(3, 2, 1, [1], 0)
|
||||
check(3, 2, 2, [1], 1, merge=0)
|
||||
check(4, 2, 1, [2,3], 0)
|
||||
check(4, 2, 2, [0,3], 1)
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
def test_reduce_cuda(self):
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
|
|
|
@ -96,7 +96,7 @@ class TestResnet(unittest.TestCase):
|
|||
-jt.flags.stat_allocator_total_free_byte
|
||||
# assert mem_used < 4e9, mem_used
|
||||
# TODO: why bigger?
|
||||
assert mem_used < 5.5e9, mem_used
|
||||
assert mem_used < 5.6e9, mem_used
|
||||
# example log:
|
||||
# Train Epoch: 0 [0/100 (0%)] Loss: 2.352903 Acc: 0.110000
|
||||
# Train Epoch: 0 [1/100 (1%)] Loss: 2.840830 Acc: 0.080000
|
||||
|
@ -115,9 +115,9 @@ class TestResnet(unittest.TestCase):
|
|||
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
||||
|
||||
if jt.in_mpi:
|
||||
assert jt.core.number_of_lived_vars() < 7500, jt.core.number_of_lived_vars()
|
||||
assert jt.core.number_of_lived_vars() < 7800, jt.core.number_of_lived_vars()
|
||||
else:
|
||||
assert jt.core.number_of_lived_vars() < 6500, jt.core.number_of_lived_vars()
|
||||
assert jt.core.number_of_lived_vars() < 6700, jt.core.number_of_lived_vars()
|
||||
|
||||
jt.sync_all(True)
|
||||
assert np.mean(loss_list[-50:])<0.5
|
||||
|
|
|
@ -43,7 +43,7 @@ class TestWhereOp(unittest.TestCase):
|
|||
x = a.reindex_var(self.where(a>0.1))
|
||||
x = x.reindex_var(self.where(x<0.9))
|
||||
na = a.data
|
||||
assert (na[np.logical_and(na>0.1, na<0.9)]==x.data).all()
|
||||
assert np.allclose(na[np.logical_and(na>0.1, na<0.9)], x.data)
|
||||
|
||||
def test_reduce_dep(self):
|
||||
a = jt.random([100,100])
|
||||
|
|
|
@ -22,49 +22,36 @@ from jittor.compiler import run_cmd
|
|||
from jittor_utils import translator
|
||||
import sys
|
||||
|
||||
jittor_path = os.path.realpath(os.path.join(jt.flags.jittor_path, "..", ".."))
|
||||
|
||||
polish_path = os.path.join(jittor_path, "..", "jittor-polish")
|
||||
polish_path = os.path.realpath(polish_path)
|
||||
build_path = polish_path + "/build"
|
||||
LOG.i("Polish path:", polish_path)
|
||||
if not os.path.isdir(polish_path):
|
||||
# create jittor-polish repo
|
||||
os.mkdir(polish_path)
|
||||
jittor_path = jt.flags.jittor_path
|
||||
root_path = os.path.realpath(os.path.join(jt.flags.jittor_path, "..", ".."))
|
||||
data_path = os.path.join(jittor_path, "src", "__data__")
|
||||
build_path = os.path.join(data_path, "build")
|
||||
if not os.path.isdir(build_path):
|
||||
os.mkdir(build_path)
|
||||
run_cmd("git init . && git remote add origin git@github.com:Jittor/Jittor.git", polish_path)
|
||||
status = run_cmd("git status", data_path)
|
||||
print(status)
|
||||
if "working tree clean" not in status:
|
||||
LOG.f("__data__ has untracked files")
|
||||
|
||||
# copy jittor src into it
|
||||
names = "extern notebook python script src README.md README.src.md README.cn.md LICENSE.txt setup.py .gitignore".split()
|
||||
for name in names:
|
||||
run_cmd(f"rsync -a {jittor_path}/{name} {polish_path}/")
|
||||
|
||||
git_version = run_cmd("git rev-parse HEAD", jittor_path)
|
||||
git_version = run_cmd("git rev-parse HEAD", data_path)
|
||||
LOG.i("git_version", git_version)
|
||||
run_cmd(f"git rev-parse HEAD > {polish_path}/python/jittor/version", jittor_path)
|
||||
|
||||
run_cmd(f"git rev-parse HEAD > {jittor_path}/version", data_path)
|
||||
|
||||
# remove files
|
||||
files = jt.compiler.files
|
||||
file_to_delete = [ name for name in files
|
||||
if name.startswith("src") and \
|
||||
len(name.split("/"))==2 and name.endswith("node.cc")
|
||||
data_files = [ name for name in files
|
||||
if "__data__" in name
|
||||
]
|
||||
LOG.i("file_to_delete", file_to_delete)
|
||||
run_cmd(f"rm {' '.join(file_to_delete)}", polish_path)
|
||||
LOG.i("data_files", data_files)
|
||||
|
||||
# commit jittor-polish
|
||||
run_cmd(f"git add .", polish_path)
|
||||
status = run_cmd(f"git status", polish_path)
|
||||
if "new file" not in status:
|
||||
LOG.i("Nothing change, exit...")
|
||||
else:
|
||||
run_cmd(f"git commit -a -m 'version {git_version}'", polish_path)
|
||||
|
||||
# compile delete files
|
||||
# compile data files
|
||||
from pathlib import Path
|
||||
home = str(Path.home())
|
||||
for cc_type in ["g++", "clang"]:
|
||||
for device in ["cpu", "cuda"]:
|
||||
# for cc_type in ["g++", "clang"]:
|
||||
# for device in ["cpu", "cuda"]:
|
||||
for cc_type in ["g++"]:
|
||||
for device in ["cpu"]:
|
||||
key = f"{git_version}-{cc_type}-{device}"
|
||||
env = f"cache_name=build/{cc_type}/{device} cc_path="
|
||||
cname = "g++" if cc_type=="g++" else "clang-8"
|
||||
|
@ -84,7 +71,7 @@ for cc_type in ["g++", "clang"]:
|
|||
|
||||
obj_path = home + f"/.cache/jittor/build/{cc_type}/{device}/{cname}/obj_files"
|
||||
obj_files = []
|
||||
for name in file_to_delete:
|
||||
for name in data_files:
|
||||
name = name.split("/")[-1]
|
||||
fname = f"{obj_path}/{name}.o"
|
||||
assert os.path.isfile(fname), fname
|
||||
|
@ -94,14 +81,17 @@ for cc_type in ["g++", "clang"]:
|
|||
# compress source
|
||||
# tar -cvzf build/jittor.tgz . --exclude build --exclude .git --exclude .ipynb_checkpoints --exclude __pycache__
|
||||
# mkdir -p jittor && tar -xvf ./jittor.tgz -C jittor
|
||||
assert os.system(f"cd {polish_path} && tar --exclude=build --exclude=.git --exclude=.ipynb_checkpoints --exclude=__pycache__ -cvzf build/jittor.tgz . ")==0
|
||||
assert os.system(f"cd {root_path} && tar --exclude=build --exclude=.git --exclude=.ipynb_checkpoints --exclude=__pycache__ --exclude=__data__ --exclude=my --exclude=dist --exclude=.vscode --exclude=.github -cvzf {build_path}/jittor.tgz * ")==0
|
||||
|
||||
# rsync to build-server
|
||||
jittor_web_base_dir = "Documents/jittor-blog/assets/"
|
||||
jittor_web_build_dir = jittor_web_base_dir + "build/"
|
||||
assert os.system(f"rsync -avPu {polish_path}/build/ jittor-web:{jittor_web_build_dir}")==0
|
||||
jittor_web_build_dir = jittor_web_base_dir
|
||||
assert os.system(f"rsync -avPu {build_path} jittor-web:{jittor_web_build_dir}")==0
|
||||
assert os.system(f"ssh jittor-web Documents/jittor-blog.git/hooks/post-update")==0
|
||||
|
||||
|
||||
# sys.exit(0)
|
||||
|
||||
# push to github
|
||||
# assert os.system(f"cd {polish_path} && git push -f origin master")==0
|
||||
|
||||
|
|
|
@ -1 +1 @@
|
|||
f9e290160bead0d5892754da56b9ad63bc316320
|
||||
84596508776983dce645fc4ef77c7f35700549d5
|
||||
|
|
|
@ -41,29 +41,46 @@ void load_fused_op(FusedOp& fused_op, vector<int>& fuse_ops, vector<Op*>& ops, i
|
|||
op->tflag = ntt;
|
||||
fused_op.ops.push_back(op);
|
||||
}
|
||||
for (Op* op : fused_op.ops) {
|
||||
uint fid1 = op->custom_data;
|
||||
uint oid = 0;
|
||||
for (Var* v : op->outputs()) {
|
||||
oid++;
|
||||
if (v->tflag != tt) {
|
||||
// this var node not belong to current execution
|
||||
// this will happend in multiple outputs fuseable op
|
||||
// v->custom_data = 0 represents this var cannot be fused
|
||||
v->custom_data = 0;
|
||||
continue;
|
||||
}
|
||||
for (auto o : v->outputs_with_index()) {
|
||||
Op* op2 = o.op;
|
||||
uint iid = o.index;
|
||||
if (op2->tflag != ntt) continue;
|
||||
uint fid2 = op2->custom_data;
|
||||
fused_op.edges.emplace_back(fid1, oid-1, fid2, iid);
|
||||
}
|
||||
}
|
||||
}
|
||||
LOGvvv << "Prepare fused_op" << fused_op.ops;
|
||||
fused_op.update_ops();
|
||||
for (Op* op : fused_op.ops) {
|
||||
uint fid1 = op->custom_data;
|
||||
int iid = 0;
|
||||
for (Var* v : op->inputs()) {
|
||||
iid++;
|
||||
int iop_id;
|
||||
int iv_id;
|
||||
if (v->_inputs.size() && v->input()->tflag == ntt) {
|
||||
auto e = v->_inputs.front();
|
||||
iop_id = e.node->custom_data;
|
||||
iv_id = e.back->index;
|
||||
} else {
|
||||
iv_id = v->custom_data >> 2;
|
||||
// add iv_id, prevent iv_id jit key overflow
|
||||
iop_id = fused_op.ops.size() + iv_id;
|
||||
}
|
||||
fused_op.edges.emplace_back(iop_id, iv_id, fid1, iid-1);
|
||||
}
|
||||
// TODO: can we remove this?
|
||||
// uint oid = 0;
|
||||
// for (Var* v : op->outputs()) {
|
||||
// oid++;
|
||||
// if (v->tflag != tt) {
|
||||
// // this var node not belong to current execution
|
||||
// // this will happend in multiple outputs fuseable op
|
||||
// // v->custom_data = 0 represents this var cannot be fused
|
||||
// v->custom_data = 0;
|
||||
// continue;
|
||||
// }
|
||||
// // for (auto o : v->outputs_with_index()) {
|
||||
// // Op* op2 = o.op;
|
||||
// // uint iid = o.index;
|
||||
// // if (op2->tflag != ntt) continue;
|
||||
// // uint fid2 = op2->custom_data;
|
||||
// // fused_op.edges.emplace_back(fid1, oid-1, fid2, iid);
|
||||
// // }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||
|
|
196
src/fuser.cc
196
src/fuser.cc
|
@ -1,196 +0,0 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include "fuser.h"
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "graph.h"
|
||||
#include "fused_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#define PREVENT_LARGE_FUSED_OP 16
|
||||
|
||||
void count_fuse(int64_t tt, int start_var_num, const vector<Op*>& ops, const vector<Var*>& vars, vector<int> &father, vector<int> &var_fused) {
|
||||
vector<int> dis(ops.size(), -1);
|
||||
|
||||
auto find_fa = [&](int i) -> int {
|
||||
int j=i;
|
||||
while (father[j] != j) j = father[j];
|
||||
while (i != j) {
|
||||
int tmp = father[i];
|
||||
father[i] = j;
|
||||
i = tmp;
|
||||
}
|
||||
return j;
|
||||
};
|
||||
|
||||
auto can_fuse = [&](Var* v, Op* op1, Op* op2, int fuse_type) -> bool {
|
||||
if (v->flags.get(NodeFlags::_stop_fuse))
|
||||
return false;
|
||||
if (fuse_type == 1) {
|
||||
// if v is output, do not fuse
|
||||
if (v->custom_data < start_var_num)
|
||||
return false;
|
||||
// op2 ---> v ---> op1
|
||||
if (op1->type() == OpType::other || op2->type() == OpType::other)
|
||||
return false;
|
||||
if (v->flags.get(NodeFlags::_force_fuse))
|
||||
return true;
|
||||
// Do not fuse op after reduce(has reduce)
|
||||
// TODO: better fuse strategy
|
||||
if (op2->type() == OpType::reduce)
|
||||
return false;
|
||||
// Do not fuse op before broadcast
|
||||
// TODO: better fuse strategy
|
||||
if (op1->type() == OpType::broadcast)
|
||||
return false;
|
||||
return op2->type() == OpType::element ||
|
||||
op2->type() == OpType::broadcast;
|
||||
} else if (fuse_type == 0) {
|
||||
#ifdef PREVENT_LARGE_FUSED_OP
|
||||
// This statement prevent fuse large ops
|
||||
if (v->outputs().size()>=PREVENT_LARGE_FUSED_OP) return false;
|
||||
#endif
|
||||
|
||||
// v ---> op1
|
||||
// |
|
||||
// +----> op2 ( prev of op1 )
|
||||
if (op1->type() == OpType::other || op2->type() == OpType::other)
|
||||
return false;
|
||||
// Do not fuse op after reduce(has reduce)
|
||||
// TODO: better fuse strategy
|
||||
if (op2->type() == OpType::broadcast || op1->type() == OpType::broadcast)
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto for_each_edge = [&](Op* op, int forward, auto&& func){
|
||||
auto e=op->_inputs.begin();
|
||||
for (Var* v : op->inputs()) {
|
||||
if ((forward && (*e).back!=std::prev(v->_outputs.end())) ||
|
||||
(!forward && (*e).back!=v->_outputs.begin())){
|
||||
Op* next_op = forward ? std::next((*e).back)->node->op() : std::prev((*e).back)->node->op();
|
||||
if (next_op && next_op->tflag==tt
|
||||
&& next_op->custom_data != op->custom_data
|
||||
&& can_fuse(v, next_op, op, 0))
|
||||
func(v, next_op, 0);
|
||||
}
|
||||
e = std::next(e);
|
||||
}
|
||||
|
||||
if (forward) {
|
||||
for (Var* sv : op->outputs())
|
||||
if (sv && sv->tflag == tt)
|
||||
for (Op* next_op: sv->outputs())
|
||||
if (next_op && next_op->tflag==tt) func(sv, next_op, 1);
|
||||
} else {
|
||||
for (Var* sv : op->inputs())
|
||||
if (sv && sv->tflag == tt) func(sv, sv->input(), 1);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
vector<int> queue;
|
||||
vector<int> deps;
|
||||
deps.reserve(ops.size());
|
||||
queue.reserve(ops.size());
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
deps.push_back(0);
|
||||
Op* op = ops[i];
|
||||
|
||||
for_each_edge(op, 1, [&](Var* v, Op* next_op, int real_edge) {
|
||||
deps[i]++;
|
||||
});
|
||||
|
||||
if (!deps[i]) {
|
||||
queue.push_back(i);
|
||||
dis[i]=0;
|
||||
}
|
||||
}
|
||||
|
||||
uint head=0;
|
||||
while (head<queue.size()) {
|
||||
int op_id=queue[head++];
|
||||
Op* op = ops[op_id];
|
||||
|
||||
for_each_edge(op, 1, [&](Var* v, Op* next_op, int real_edge) {
|
||||
int next_id = next_op->custom_data;
|
||||
if (dis[next_id] == dis[op_id]){
|
||||
int next_fa = find_fa(next_id);
|
||||
father[next_fa] = op_id;
|
||||
}
|
||||
});
|
||||
|
||||
for_each_edge(op, 0, [&](Var* v, Op* next_op, int real_edge) {
|
||||
int next_id = next_op->custom_data;
|
||||
int lon=0;
|
||||
if (real_edge && !can_fuse(v, op, next_op, 1)) lon=1;
|
||||
if (dis[op_id]+lon>dis[next_id])
|
||||
dis[next_id]=dis[op_id]+lon;
|
||||
if (!--deps[next_id]) queue.push_back(next_id);
|
||||
});
|
||||
}
|
||||
|
||||
if (V_ON(1000)) {
|
||||
for (uint i=0; i<ops.size(); i++)
|
||||
LOGvvvv << ops[i] << dis[i] << deps[i];
|
||||
}
|
||||
for (uint i=0; i<vars.size(); i++) {
|
||||
Var* v = vars[i];
|
||||
if (!v || v->tflag!=tt) {
|
||||
var_fused[i]=1;
|
||||
continue;
|
||||
}
|
||||
// sf: input op's father id
|
||||
int sf = -1;
|
||||
// vf: is input op can be fused with all output op
|
||||
int vf = 1;
|
||||
// all outputs are reduce
|
||||
int all_reduce = 1;
|
||||
Op* iop = v->input();
|
||||
// if (iop && iop->tflag==tt)
|
||||
sf = find_fa(iop->custom_data);
|
||||
|
||||
for (Op* sop : v->outputs())
|
||||
if (sop->tflag==tt) {
|
||||
if (vf && !can_fuse(v,sop,iop,1))
|
||||
vf = 0;
|
||||
if (sop->type()!=OpType::reduce)
|
||||
all_reduce = 0;
|
||||
// in two different fused op
|
||||
if (find_fa(sop->custom_data)!=sf) {
|
||||
var_fused[i]=1;
|
||||
}
|
||||
}
|
||||
if (vf==0)
|
||||
// cannot fused
|
||||
var_fused[i]=1;
|
||||
else if (var_fused[i]) {
|
||||
if (iop->type()==OpType::broadcast ||
|
||||
all_reduce ||
|
||||
v->flags.get(NodeFlags::_force_fuse))
|
||||
// strong fused
|
||||
var_fused[i] = 3;
|
||||
else
|
||||
// weak fused
|
||||
var_fused[i] = 2;
|
||||
// var_fused[i] = 3;
|
||||
}
|
||||
}
|
||||
// output vars can not be fused
|
||||
for (int i=0; i<start_var_num; i++)
|
||||
var_fused[i] = 1;
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -280,4 +280,20 @@ inline bool operator!=(const NanoVector& a, const NanoVector& b) {
|
|||
return ne(a, b);
|
||||
}
|
||||
|
||||
// Lexicographic comparison of two NanoVectors: `data` decides first, `offset`
// breaks ties.
inline bool operator<(const NanoVector& a, const NanoVector& b) {
    if (a.data != b.data)
        return a.data < b.data;
    return a.offset < b.offset;
}
|
||||
|
||||
} // jittor
|
||||
|
||||
|
||||
namespace std {
|
||||
template<> struct hash<jittor::NanoVector> {
|
||||
inline std::size_t operator()(jittor::NanoVector const& s) const noexcept {
|
||||
std::size_t h1 = std::hash<jittor::int64>{}(s.data);
|
||||
std::size_t h2 = std::hash<jittor::int64>{}(s.offset);
|
||||
return h1 ^ (h2 << 1);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ JIT_TEST(ring_buffer_benchmark) {
|
|||
LOGi << tt << tt*1.0/n;
|
||||
LOGi << s << (n*(n-1)/2);
|
||||
ASSERTop(s,==,(n*(n-1)/2));
|
||||
ASSERTop(tt*1.0/n,<=,50);
|
||||
ASSERTop(tt*1.0/n,<=,100);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -88,6 +88,8 @@ struct RingBuffer {
|
|||
static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess);
|
||||
static void free_ring_buffer(RingBuffer* rb);
|
||||
|
||||
// Reset l, r and is_stop to 0.
inline void clear() { l = r = is_stop = 0; }
|
||||
|
||||
inline void wait() {
|
||||
if (is_stop) {
|
||||
throw std::runtime_error("stop");
|
||||
|
|
|
@ -287,8 +287,12 @@ std::ostream& operator<<(std::ostream& os, const Op* op) {
|
|||
os << ')';
|
||||
#ifdef NODE_MEMCHECK
|
||||
os << '<' << op->__id() << '>';
|
||||
print_node_trace(op, os);
|
||||
#endif
|
||||
if (trace_py_var) {
|
||||
os << '{';
|
||||
print_node_trace(op, os);
|
||||
os << '}';
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
|
|
|
@ -75,8 +75,19 @@ ArrayOp::ArrayOp(ArrayArgs&& args) {
|
|||
}
|
||||
|
||||
void ArrayOp::jit_prepare(JK& jk) {
|
||||
if (output->flags.get(NodeFlags::_force_fuse))
|
||||
if (output->flags.get(NodeFlags::_force_fuse)) {
|
||||
jk << _CS("[T:") << output->dtype() << ']';
|
||||
|
||||
// fill or find cbuffer for const var pass
|
||||
if (output->dtype().dsize() == 4) {
|
||||
auto x = abs(ptr<int32>()[0]);
|
||||
auto y = abs(ptr<float32>()[0]);
|
||||
auto z = ptr<uint32>()[0];
|
||||
if ((x<=2) || (y==1.0f || y==2.0f))
|
||||
jk << _CS("[o:") << z << ']';
|
||||
}
|
||||
// end of fill cbuffer
|
||||
}
|
||||
}
|
||||
|
||||
void ArrayOp::run() {
|
||||
|
|
|
@ -126,8 +126,8 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
}
|
||||
if (ns == ns_maximum || ns == ns_minimum) {
|
||||
auto zeros = make_number(0, dout);
|
||||
auto cond = make_binary(x, y, ns_greater_equal);
|
||||
if ((ns == ns_maximum) == (v_index==0))
|
||||
auto cond = make_binary(y, z, ns_equal);
|
||||
if (v_index==1)
|
||||
return make_ternary(cond, dout, zeros);
|
||||
else
|
||||
return make_ternary(cond, zeros, dout);
|
||||
|
|
|
@ -38,7 +38,7 @@ namespace jittor {
|
|||
#define bitwise_and(T,a,b) ((a)&(b))
|
||||
#define bitwise_or(T,a,b) ((a)|(b))
|
||||
#define bitwise_xor(T,a,b) ((a)^(b))
|
||||
#define mean(T,a,b) ((a)+T(b)*(T(1)/T(count)))
|
||||
#define mean(T,a,b) ((a)+T(b)*(T(rcount)))
|
||||
|
||||
#ifdef JIT_cuda
|
||||
#define init_maximum(T) ::numeric_min<T>()
|
||||
|
|
|
@ -72,9 +72,9 @@ void IndexOp::jit_run() {
|
|||
@for(d, 0, XDIM, for (index_t i@d=0; i@d < x0shape@d; i@d++)) {
|
||||
auto xid = @for(d, 0, XDIM, + i@d * x0stride@d);
|
||||
@if(DIM==XDIM,
|
||||
@for(i,0,XDIM, x@i@@p[xid] = i@i;)
|
||||
@for(i,0,XDIM, T x@i@@id = i@i; x@i@@p[xid] = x@i@@id;)
|
||||
,
|
||||
x0p[xid] = i@DIM;
|
||||
T x@DIM@@id = i@DIM; x0p[xid] = x@DIM@@id;
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -158,6 +158,7 @@ void ReduceOp::jit_run() {
|
|||
index_t xstride@{DIM-1} = 1;
|
||||
@for(i, DIM-2, -1, -1, auto xstride@i = xstride@{i+1} * xshape@{i+1};)
|
||||
Ty count = Ty(x->num) / Ty(y->num);
|
||||
Ty rcount = Ty(y->num) / Ty(x->num);
|
||||
@for(d, 0, DIM,@if(REDUCE>>d&1,, for (index_t xi@d=0; xi@d < xshape@d; xi@d++))) {
|
||||
auto yid = 0 @for(d, 0, DIM,@if(REDUCE>>d&1,, + xi@d * ystride@d));
|
||||
yp[yid] = @expand_macro(init_@OP, Ty);
|
||||
|
@ -170,7 +171,7 @@ void ReduceOp::jit_run() {
|
|||
yp[yid] = @expand_macro(@OP, Ty, yp[yid], xp[xid]);
|
||||
}
|
||||
}
|
||||
(void)count, (void)yshape0, (void)ystride0;
|
||||
(void)count, (void)rcount, (void)yshape0, (void)ystride0;
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
|
|
|
@ -64,8 +64,11 @@ void TernaryOp::jit_run() {
|
|||
auto* __restrict__ yp = y->ptr<Ty>();
|
||||
auto* __restrict__ zp = z->ptr<Tz>();
|
||||
index_t num = z->num;
|
||||
for (index_t i=0; i<num; i++)
|
||||
zp[i] = condp[i] ? xp[i] : yp[i];
|
||||
for (index_t i=0; i<num; i++) {
|
||||
Tz xd_ = xp[i];
|
||||
Tz yd_ = yp[i];
|
||||
zp[i] = condp[i] ? xd_ : yd_;
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
|
|
|
@ -981,7 +981,7 @@ bool match(
|
|||
return false;
|
||||
if (!ze.first && s->children.size() != t->children.size())
|
||||
return false;
|
||||
int n = s->children.size();
|
||||
int n = s->is(_op) ? s->children.size() : 1;
|
||||
int m = t->children.size();
|
||||
unique_ptr<Expr> zep;
|
||||
if (ze.first) {
|
||||
|
@ -1013,8 +1013,16 @@ bool match(
|
|||
}
|
||||
return true;
|
||||
};
|
||||
if (s->is_not(_op) && t->str == "*" && s->str == "0") {
|
||||
// 0 match 0*a*b
|
||||
for (int i=0; i<m; i++) {
|
||||
if (check_match(0,i))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
int asso_wildcard_id = -1, asso_wildcard_tid=0;
|
||||
if (s->is(_asso_op)) {
|
||||
if (t->is(_asso_op)) {
|
||||
// wildcard associative id
|
||||
// a*b+c <----
|
||||
for (int i=m-1; i>=0; i--) {
|
||||
|
@ -1041,11 +1049,10 @@ bool match(
|
|||
}
|
||||
}
|
||||
}
|
||||
LOGvvvv << "asso_wildcard_id" << asso_wildcard_id <<
|
||||
t->children.at(asso_wildcard_id);
|
||||
LOGvvvv << "asso_wildcard_id" << asso_wildcard_id;
|
||||
// asso_wildcard_id = -1;
|
||||
}
|
||||
if (s->is(_comm_op)) {
|
||||
if (t->is(_comm_op)) {
|
||||
// is commutative op, children can be matched in any order
|
||||
vector<bool> is_matched(m);
|
||||
for (int i=0; i<n; i++) {
|
||||
|
@ -1056,7 +1063,11 @@ bool match(
|
|||
if (check_match(i, j)) {
|
||||
is_matched[j] = true;
|
||||
matched = true;
|
||||
break;
|
||||
// if i is zero elem
|
||||
if (ze.first && i == n-1)
|
||||
continue;
|
||||
else
|
||||
break;
|
||||
}
|
||||
}
|
||||
for (int _=0; _<results.size(); _++)
|
||||
|
@ -1073,6 +1084,7 @@ bool match(
|
|||
if (!check_match(i, j)) {
|
||||
return false;
|
||||
}
|
||||
is_matched[j] = true;
|
||||
if (bk)
|
||||
res = make_op(s->str, move(bk), move(res));
|
||||
continue;
|
||||
|
@ -1083,8 +1095,15 @@ bool match(
|
|||
}
|
||||
}
|
||||
for (int j=0; j<is_matched.size(); j++)
|
||||
if (j!=asso_wildcard_id && !is_matched[j])
|
||||
return false;
|
||||
//
|
||||
// if (j!=asso_wildcard_id && !is_matched[j])
|
||||
// return false;
|
||||
if (!is_matched[j]) {
|
||||
if (!ze.first)
|
||||
return false;
|
||||
if (!check_match(n-1, j))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
// not a commutative op, match in the same order
|
||||
|
|
|
@ -919,6 +919,7 @@ void KernelIR::check_unused() {
|
|||
}
|
||||
|
||||
void KernelIR::find_used(KernelIR* def, vector<KernelIR*>& used) {
|
||||
if (has_attr("raw")) return;
|
||||
const char* ss[] = {"code", "rvalue", "rvalue2"};
|
||||
for (const char* s : ss) {
|
||||
auto& code = get_attr(s);
|
||||
|
|
|
@ -1,318 +0,0 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors:
|
||||
// Guowei Yang <471184555@qq.com>
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <sstream>
|
||||
#include <omp.h>
|
||||
#include "var.h"
|
||||
#include "opt/expr.h"
|
||||
#include "opt/pass_manager.h"
|
||||
#include "opt/pass/atomic_tuner_pass.h"
|
||||
#include "opt/pass/loop_var_analyze_pass.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
/*
|
||||
move a statements and its relied statements from inner loop to outer loop:
|
||||
|
||||
for ... // outer loop
|
||||
for ...
|
||||
for ... // inner loop
|
||||
statement_not_rely
|
||||
statement_x
|
||||
statement_y // def
|
||||
|
||||
-->
|
||||
|
||||
statement_x
|
||||
for ... // outer loop
|
||||
for ...
|
||||
for ... // inner loop
|
||||
statement_not_rely
|
||||
|
||||
statement_y // def
|
||||
|
||||
*/
|
||||
// Move `def`, and transitively every definition referenced by the "rvalue"
// of an already moved statement, to the front of `outer_loop->before`.
// Only definitions whose father is `inner_loop` itself are moved.
static void move_rely(KernelIR* inner_loop, KernelIR* outer_loop, KernelIR* def){
    // moved statements whose rvalue still has to be scanned
    vector<KernelIR*> worklist{def};
    map<KernelIR*, int> moved;
    moved[def]=1;
    outer_loop->push_front(def->move_out(), &outer_loop->before);
    int next = 0;
    // the worklist grows while it is scanned, so iterate by index
    while (next < worklist.size()) {
        KernelIR* cur = worklist[next++];
        auto rvalue = expr::make(cur->attrs["rvalue"]);
        LOGvvvv << "move_rely" << rvalue->to_string();
        rvalue->dfs([&](expr::Expr* node) {
            if (!node->is_sym()) return;
            auto dep = inner_loop->find_define(node->str);
            // TODO: definition between inner loop and outer loop
            if (dep==nullptr || !dep->father || dep->father != inner_loop)
                return;
            if (moved.count(dep)) return;
            outer_loop->push_front(dep->move_out(), &outer_loop->before);
            worklist.push_back(dep);
            moved[dep]=1;
        });
    }
}
|
||||
|
||||
// find init value of correspondence op and var
|
||||
// Return the initial-value expression for the reduction performed by `op` on
// `var`, e.g. 0 for add or numeric_max<float32> for min.
// op->name_ex() must have the form "<op>.<reduce>", for example
// "reindex_reduce.minimum"; the part after the dot selects the init_<reduce>
// macro, which is expanded from ops/binary_op_defs.h for var's dtype.
string find_init_value(Op* op, Var* var, bool is_cuda) {
    auto names = split(op->name_ex(), ".");
    ASSERT(names.size()==2) << names;
    // JIT key that is set to "1" for the target device
    const char* device_key = is_cuda ? "JIT_cuda" : "JIT_cpu";
    return OpCompiler::precompile(
        {
            {"OP", names.back()},
            {"T", var->dtype().to_cstring()},
            {device_key, "1"}
        },
        "#include \"ops/binary_op_defs.h\"\n@expand_macro(init_@OP, @T)");
}
|
||||
|
||||
// sorder: Array that saves the allocation order of "tn"
|
||||
// sfunc: Array of function names
|
||||
// Move atomic updates of one kernel function out of its inner loops.
//
// For every top-level "loop" child of `ir` that contains statements whose
// code mentions "atomic" and that have a "rely" attribute:
//   1. collect the chain of singly nested loops below it,
//   2. reorder the loops so the ones the atomics rely on come first,
//   3. replace each atomic by an accumulation into a temporary `tmpN` and emit
//      a single atomic after the loop the temporary belongs to.
//
//  pass    : owning pass, used to resolve an op/var from a variable name
//  ir      : IR of the function to tune
//  is_cuda : selects the CUDA or CPU spelling of the atomics
//  tdim    : unused
//  sorder  : output, receives one "tn" allocation order per tuned loop nest
//  sfunc   : output, receives ir's "lvalue" for each entry of `sorder`
static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim, vector<vector<int>> &sorder, vector<string> &sfunc) {
    LOGvvvv << "tune_atomic" << ir->children;
    vector<string> idx_name;
    vector<KernelIR*> atomics;
    vector<KernelIR*> loops;
    vector<int> nrely;
    vector<int> order;
    int tmp_cnt=0;
    for (uint i=0; i<ir->children.size(); i++) {
        auto& c = ir->children[i];
        if (c->type != "loop") continue;
        idx_name.clear();
        atomics.clear();
        loops.clear();
        order.clear();
        nrely.clear();

        // collect atomic statements that have a "rely" attribute
        c->dfs([&](unique_ptr<KernelIR>& p) {
            auto& code = p->attrs["code"];
            if (code.find("atomic") != string::npos && p->has_attr("rely"))
                atomics.push_back(p.get());
        });
        if (atomics.empty()) continue;

        // get loops & idx_name
        KernelIR* loop = c.get();
        loops.push_back(loop);
        idx_name.push_back(loop->attrs["lvalue"]);
        order.push_back(loops.size()-1);
        nrely.push_back(-1);
        bool ok = true;
        while (true) {
            loop = loops.back();
            KernelIR* loop2 = nullptr;
            for (auto& p : loop->children) {
                if (p->type != "loop")
                    continue;
                // TODO: only support single loop children
                if (loop2 != nullptr) ok = false;
                loop2 = p.get();
            }
            if (loop2 == nullptr) break;
            // TODO: only support single loop children
            if (loop->children.size() != 1) ok = false;
            if (!ok) break;
            ASSERT(loop->children.size()==1);
            loops.push_back(loop2);
            idx_name.push_back(loop2->attrs["lvalue"]);
            order.push_back(loops.size()-1);
            nrely.push_back(-1);
        }
        // TODO: only support single loop children
        if (!ok) continue;

        // reorder: move every loop an atomic relies on to the front of `order`
        for (uint j=0; j<atomics.size(); j++) {
            KernelIR* p = atomics[j];
            auto si = split(p->get_attr("rely"), ",");
            for (int k=static_cast<int>(si.size())-2; k>=0; k--) {
                // ignore empty string
                if (si[k].empty())
                    continue;
                int sidx=-1;
                int sord=-1;
                for (uint l=0; l<idx_name.size(); l++)
                    if (idx_name[l]==si[k]) sidx=l;
                ASSERT(sidx != -1);
                for (uint l=0; l<order.size(); l++)
                    if (order[l]==sidx) sord=l;
                ASSERT(sord != -1);
                for (int l=sord; l; l--) {
                    order[l]=order[l-1];
                    nrely[l]=nrely[l-1];
                }
                order[0]=sidx;
                nrely[0]=j;
            }
        }
        LOGvvvv << "atomic tuner order" << order;

        // tn order: the leading run sharing nrely[0] reversed, followed by the
        // remaining entries reversed
        vector<int> tnorder;
        const int order_size = static_cast<int>(order.size());
        int split_pos = 0;
        while (split_pos < order_size && nrely[split_pos] == nrely[0])
            split_pos++;
        for (int j=split_pos-1; j>=0; j--) tnorder.push_back(order[j]);
        for (int j=order_size-1; j>=split_pos; j--) tnorder.push_back(order[j]);
        sorder.push_back(tnorder);
        sfunc.push_back(ir->attrs["lvalue"]);

        // sort loop with order
        int count=0;
        for (auto j : order) {
            uint k;
            for (k=count; k<loops.size(); k++)
                if (loops[k]->check_attr("loop_id", S(j)))
                    break;
            if (k<loops.size())
                loops[k]->swap(*loops[count++]);
        }

        // move atomic
        for (uint j=0; j<atomics.size(); j++) {
            KernelIR* p = atomics[j];
            auto si = split(p->get_attr("rely"), ",");
            // position in `order` of the innermost loop this atomic relies on
            int sidx=-1;
            for (int k=static_cast<int>(si.size())-2; k>=0; k--)
                for (int l=0; l<order_size; l++)
                    if (idx_name[order[l]]==si[k] && l>sidx) sidx=l;

            vector<unique_ptr<expr::Expr>> results;
            string stmp = "tmp"+std::to_string(tmp_cnt++);
            auto& code = p->attrs["code"];
            LOGvvvv << "atomic code" << code;
            // drop the last character of the statement before parsing it
            auto e = expr::make(code.substr(0, code.size()-1));
            // add atomic code
            //  t          : template of the accumulation that replaces the atomic
            //  args       : symbols to solve when matching
            //  cpu, cuda  : pattern of the atomic call on each device
            //  acpu, acuda: template of the atomic emitted after the loop
            auto check = [&](const string& t, const vector<string>& args, const string& cpu, const string& cuda, const string& acpu, const string& acuda) -> bool {
                auto target = is_cuda ? expr::make(cuda) : expr::make(cpu);
                if (!expr::match(e.get(), target.get(), args, {}, results))
                    return false;
                unordered_map<string,string> defs;
                for (size_t idx=0; idx<args.size(); idx++)
                    defs[args[idx]] = results[idx]->to_string();

                string a=defs["a"];
                if (!expr::match(expr::make(a).get(), expr::make("(c[d])").get(), {"c","d"}, {}, results))
                    return false;
                // dvar[didx]
                string dvar=results[0]->to_string();
                string didx=results[1]->to_string();

                auto def=p->father->find_define(didx);
                ASSERT(def != nullptr);
                // the index is defined directly in the relied loop: keep as is
                if (sidx>=0 && def->father == loops[sidx])
                    return true;
                auto& loop_i = loops.at(sidx+1);
                code = OpCompiler::precompile(defs, t) + ";";
                loop_i->push_back(
                    OpCompiler::precompile(defs, is_cuda ? acuda : acpu) + ";",
                    &loop_i->after);
                uint op_id, opvar_id;
                Op* op;
                Var* var;
                pass->pm->oc->get_op_var_by_name(dvar.substr(0,dvar.length()-1), op_id, opvar_id, op, var);
                auto init_code = find_init_value(op, var, is_cuda);
                loop_i->push_back(string(var->dtype().to_cstring())+" "+stmp+"="+init_code+";", &loop_i->before);
                string sa=is_cuda ? cuda : cpu;
                LOGvvv << "atomictuner: move "+sa.substr(0,sa.find("("))+" to loop "+std::to_string(sidx);
                move_rely(def->father, loop_i, def);
                return true;
            };
            string sstd=is_cuda ? "" : "std";
            // NOTE: the max and min patterns were each listed twice with
            // identical arguments; one copy of each is kept.
            if (
                check(stmp+"="+stmp+"+@b", {"a","b"}, "cpu_atomic_add(&a,b)", "atomicAdd(&a,b)", "cpu_atomic_add(&@a,"+stmp+")", "atomicAdd(&@a,"+stmp+")") ||
                check(stmp+"="+stmp+"-@b", {"a","b"}, "cpu_atomic_sub(&a,b)", "atomicSub(&a,b)", "cpu_atomic_sub(&@a,"+stmp+")", "atomicSub(&@a,"+stmp+")") ||
                check(stmp+"="+stmp+"*@b", {"a","b"}, "cpu_atomic_mul(&a,b)", "cuda_atomic_mul(&a,b)", "cpu_atomic_mul(&@a,"+stmp+")", "cuda_atomic_mul(&@a,"+stmp+")") ||
                check(stmp+"="+sstd+"::max(@T@@("+stmp+"),@T@@(@b))", {"a","b","T"}, "cpu_atomic_max(&a,T(b))", "cuda_atomic_max(&a,T(b))", "cpu_atomic_max(&@a,@T@@("+stmp+"))", "cuda_atomic_max(&@a,@T@@("+stmp+"))") ||
                check(stmp+"="+sstd+"::min(@T@@("+stmp+"),@T@@(@b))", {"a","b","T"}, "cpu_atomic_min(&a,T(b))", "cuda_atomic_min(&a,T(b))", "cpu_atomic_min(&@a,@T@@("+stmp+"))", "cuda_atomic_min(&@a,@T@@("+stmp+"))") ||
                check(stmp+"="+stmp+"&@b", {"a","b"}, "cpu_atomic_and(&a,b)", "atomicAnd(&a,b)", "cpu_atomic_and(&@a,"+stmp+")", "atomicAnd(&@a,"+stmp+")") ||
                check(stmp+"="+stmp+"|@b", {"a","b"}, "cpu_atomic_or(&a,b)", "atomicOr(&a,b)", "cpu_atomic_or(&@a,"+stmp+")", "atomicOr(&@a,"+stmp+")") ||
                check(stmp+"="+stmp+"^@b", {"a","b"}, "cpu_atomic_xor(&a,b)", "atomicXor(&a,b)", "cpu_atomic_xor(&@a,"+stmp+")", "atomicXor(&@a,"+stmp+")") ||
                check(stmp+"="+stmp+"&&@b", {"a","b"}, "cpu_atomic_and(&a,bool(b))", "atomicAnd(&a,bool(b))", "cpu_atomic_and(&@a,bool("+stmp+"))", "atomicAnd(&@a,bool("+stmp+"))") ||
                check(stmp+"="+stmp+"||@b", {"a","b"}, "cpu_atomic_or(&a,bool(b))", "atomicOr(&a,bool(b))", "cpu_atomic_or(&@a,bool("+stmp+"))", "atomicOr(&@a,bool("+stmp+"))") ||
                check(stmp+"=((bool("+stmp+"))!=(bool(@b)))", {"a","b"}, "cpu_atomic_xor(&@a,bool(@b))", "atomicXor(&@a,bool(@b))", "cpu_atomic_xor(&@a,bool(@b))", "atomicXor(&@a,bool(@b))")
            ) continue;
            LOGf << "Atomic not match" << e;
        }
    }
}
|
||||
|
||||
void AtomicTunerPass::run() {
|
||||
auto choice = op->get_loop_option("parallel");
|
||||
bool is_cuda = op->flags.get(NodeFlags::_cuda);
|
||||
if (is_cuda) choice=1;
|
||||
if (!choice) return;
|
||||
|
||||
vector<vector<int>> sorder;
|
||||
vector<string> sfunc;
|
||||
for (uint i=0; i<ir->before.size(); i++) {
|
||||
auto& func_call = ir->before[i];
|
||||
// TODO: remove this if
|
||||
if (func_call->get_attr("dtype").find("__global__ void") == string::npos) continue;
|
||||
tune_atomic(this, func_call.get(), is_cuda, 4, sorder, sfunc);
|
||||
}
|
||||
|
||||
// Re-adjust the allocation order of "tn" according to the situation of atomic coverage, preferentially allocate the range not covered by atomic, for example:
|
||||
// for (op0_index_t id0 = tid0; id0<range0; id0+=tnum0) {
|
||||
// for (op1_index_t id1 = tid1; id1<range1; id1+=tnum1) {
|
||||
// for (op2_index_t id2 = tid2; id2<range2; id2+=tnum2) {
|
||||
// for (op3_index_t id3 = tid3; id3<range3; id3+=tnum3) {
|
||||
// ...
|
||||
// }
|
||||
// }
|
||||
// atomicAdd(...);
|
||||
// }
|
||||
// }
|
||||
// The allocation order of "tn" will be: tn1, tn0, tn3, tn2
|
||||
for (uint j=0;j<sfunc.size();j++)
|
||||
for (uint i=0; i<ir->children.size(); i++) {
|
||||
auto& func_call = ir->children[i];
|
||||
int bo=0;
|
||||
for (uint k=0; k<func_call->children.size(); k++){
|
||||
auto& save = func_call->children[k];
|
||||
if (save->has_attr("loop_func") && save->attrs["loop_func"]==sfunc[j]){
|
||||
bo=1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!bo) continue;
|
||||
uint k;
|
||||
for (k=0; k<func_call->children.size(); k++){
|
||||
auto& save = func_call->children[k];
|
||||
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn")==0) break;
|
||||
}
|
||||
for (uint l=0;l<sorder[j].size();l++){
|
||||
for (uint p=0; p<func_call->children.size(); p++){
|
||||
auto& save = func_call->children[p];
|
||||
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn"+S(sorder[j][l]))==0){
|
||||
func_call->children[p]->swap(*func_call->children[k++]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ir->remove_all_unused();
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,52 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <sstream>
|
||||
#include "opt/expr.h"
|
||||
#include "var.h"
|
||||
#include "opt/pass_manager.h"
|
||||
#include "opt/pass/const_var_pass.h"
|
||||
#include "ops/array_op.h"
|
||||
#include "jit_key.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
using namespace expr;
|
||||
|
||||
void ConstVarPass::run() {
|
||||
int changed = 0;
|
||||
for (int i=0; i<op->ops.size(); i++) {
|
||||
auto opi = op->ops[i];
|
||||
if (opi->name() != string("array"))
|
||||
continue;
|
||||
string s;
|
||||
auto* v = opi->output(0);
|
||||
if (v->num != 1)
|
||||
continue;
|
||||
auto array_op = (ArrayOp*)opi;
|
||||
jk.clear();
|
||||
array_op->jit_prepare(jk);
|
||||
if (jk.to_string().find("[o:") == string::npos)
|
||||
continue;
|
||||
if (v->dtype() == ns_int32) {
|
||||
s = S(array_op->ptr<int32>()[0]);
|
||||
} else
|
||||
if (v->dtype() == ns_float32) {
|
||||
s = S(array_op->ptr<float32>()[0]);
|
||||
} else
|
||||
continue;
|
||||
auto def = ir->find_define("op"+S(i)+"_outputd");
|
||||
ASSERT(def);
|
||||
def->attrs["dtype"] = v->dtype().to_cstring();
|
||||
def->attrs["rvalue"] = s;
|
||||
changed ++;
|
||||
LOGvvvv << def->to_string();
|
||||
}
|
||||
if (changed) {
|
||||
ir->remove_all_unused();
|
||||
}
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,16 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "opt/pass/pass.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct ConstVarPass : Pass {
|
||||
ConstVarPass() : Pass("const_var_pass") {};
|
||||
void run() override;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -19,6 +19,11 @@ void LoopToFuncPass::run() {
|
|||
if (cc_type=="clang") choice=1;
|
||||
if (!choice) return;
|
||||
int func_num=0;
|
||||
string hash_name;
|
||||
std::stringstream ss;
|
||||
op->do_prepare(jk);
|
||||
ss << std::hex << std::hash<string>()(jk.to_string());
|
||||
hash_name = ss.str();
|
||||
|
||||
ir->push_back("using namespace jittor;", &ir->before);
|
||||
if ((cc_type=="icc" || cc_type=="g++") && choice)
|
||||
|
@ -41,7 +46,7 @@ void LoopToFuncPass::run() {
|
|||
continue;
|
||||
|
||||
// func definition
|
||||
ir->push_back("INLINE_FUNC func"+S(func_num++)+"() {}", &ir->before);
|
||||
ir->push_back("INLINE_FUNC func_"+hash_name+"_"+S(func_num++)+"() {}", &ir->before);
|
||||
auto& func = ir->before.back();
|
||||
|
||||
// generate function arguments
|
||||
|
@ -97,7 +102,7 @@ void LoopToFuncPass::run() {
|
|||
auto& fc = ir->children[i];
|
||||
fc->attrs["loop_func"] = func->attrs["lvalue"];
|
||||
}
|
||||
ir->remove_all_unused();
|
||||
// ir->remove_all_unused();
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -12,6 +12,8 @@
|
|||
|
||||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(int, para_opt_level);
|
||||
|
||||
void LoopVarAnalyzePass::run() {
|
||||
// loop_vars: opi_xx->shape[j]
|
||||
vector<string> loop_vars;
|
||||
|
@ -183,13 +185,45 @@ void LoopVarAnalyzePass::run() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (para_opt_level) {
|
||||
map<Var*, Op*> same_inputs;
|
||||
for (auto o : op->ops) {
|
||||
if (!pm->oc->op_exist(o))
|
||||
continue;
|
||||
int i_id = 0;
|
||||
for (auto i : o->inputs()) {
|
||||
i_id ++;
|
||||
auto fi_id = op->get_node_id(i);
|
||||
if (op->vars.at(fi_id).type != 0)
|
||||
continue;
|
||||
if (same_inputs.count(i)) {
|
||||
auto j = same_inputs[i];
|
||||
auto name1 = pm->oc->get_name_by_op_input(o, i_id-1);
|
||||
auto name2 = pm->oc->get_name_by_op_var(j, i);
|
||||
if (name1[0] == '_' || name2[0] == '_')
|
||||
continue;
|
||||
// replace name1 -> name2
|
||||
replace_vars.emplace_back(name1+'p', name2+'p');
|
||||
} else {
|
||||
auto name2 = pm->oc->get_name_by_op_var(o, i);
|
||||
if (name2[0] == '_')
|
||||
continue;
|
||||
same_inputs[i] = o;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& t : op->edges) {
|
||||
uint i,j,k,l;
|
||||
std::tie(i,j,k,l) = t;
|
||||
// virtual op holds all inputs
|
||||
if (i>=op->ops.size())
|
||||
continue;
|
||||
// loop var may not exist(relayed)
|
||||
auto opa = op->ops[i];
|
||||
auto opb = op->ops[k];
|
||||
auto opa = op->ops.at(i);
|
||||
auto opb = op->ops.at(k);
|
||||
if (!pm->oc->op_exist(opa) || !pm->oc->op_exist(opb))
|
||||
continue;
|
||||
// replace op{j}_{kname}* -> op{i}_{oname}*
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <sstream>
|
||||
#include "opt/expr.h"
|
||||
#include "var.h"
|
||||
#include "opt/pass_manager.h"
|
||||
#include "opt/pass/merge_loop_var_pass.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
using namespace expr;
|
||||
|
||||
// Returns a clone of expression `e` in which symbols are substituted by the
// "rvalue" of their definition, looked up through ir->find_define().
// A symbol is left untouched when:
//   - it is a 6-character name starting with "range",
//   - its name ends with "outputd",
//   - no definition is found, the definition's type is not "define",
//     or the definition has no "rvalue" attribute,
//   - the definition lives in its father's `inner` list.
// `e` itself is not modified; only the clone is rewritten.
static unique_ptr<expr::Expr> trace_and_expand(KernelIR* ir, expr::Expr* e) {
|
||||
auto a = e->clone();
|
||||
std::function<void(expr::Expr*)> func =
|
||||
[&](expr::Expr* c) {
|
||||
if (!c->is_sym()) return;
|
||||
if (startswith(c->str, "range") && c->str.size() == 6)
|
||||
// don't expand range
|
||||
return;
|
||||
if (endswith(c->str, "outputd"))
|
||||
return;
|
||||
auto def = ir->find_define(c->str);
|
||||
if (!def) return;
|
||||
if (def->type!="define")
|
||||
return;
|
||||
if (!def->has_attr("rvalue")) return;
|
||||
auto& rvalue = def->attrs["rvalue"];
|
||||
LOGvvvv << *c << "->" << rvalue;
|
||||
if (def->father && def->flist==&def->father->inner) {
|
||||
// don't expand loop or func
|
||||
return;
|
||||
}
|
||||
// replace the symbol node in place by the parsed rvalue expression
|
||||
c->swap(expr::make(rvalue).get());
|
||||
// if the replacement has no children, try to expand it again
|
||||
if (!c->children.size()) func(c);
|
||||
};
|
||||
a->dfs(func);
|
||||
return a;
|
||||
}
|
||||
|
||||
// Merges a loop with its single directly nested child loop into one loop
// over a combined range, when every index expression inside can be matched
// against the pattern id_a*range_b*d + id_b*d + c.
// Does nothing when the loop option "merge_loop_var" evaluates to 0
// (the second argument 1 is presumably the default -- confirm against
// get_loop_option).
void MergeLoopVarPass::run() {
|
||||
// LOGir << ir->to_string();
|
||||
auto choice = op->get_loop_option("merge_loop_var", 1);
|
||||
if (!choice) return;
|
||||
for (int ci=0; ci<ir->children.size(); ci++) {
|
||||
auto& c = ir->children[ci];
|
||||
if (c->type != "loop")
|
||||
continue;
|
||||
// collect candidate inner loops: a loop that is the only child of a
|
||||
// father loop and has nothing in its `before` / `after` lists
|
||||
vector<KernelIR*> to_opt;
|
||||
c->dfs([&](unique_ptr<KernelIR>& i) {
|
||||
if (i->type == "loop" && i->father && i->father->type == "loop"
|
||||
&& i->father->children.size() == 1 &&
|
||||
i->before.size() == 0 && i->after.size() == 0) {
|
||||
to_opt.push_back(i.get());
|
||||
}
|
||||
});
|
||||
// walk the candidates in reverse collection order
|
||||
for (int ii=0; ii<to_opt.size(); ii++) {
|
||||
auto i = to_opt[to_opt.size()-1-ii];
|
||||
auto fa = i->father;
|
||||
LOGvvvv << "check opt" << i->attrs["rvalue"] << fa->attrs["rvalue"];
|
||||
auto range_b = i->attrs["rvalue"];
|
||||
auto id_b = i->attrs["lvalue"];
|
||||
auto range_a = fa->attrs["rvalue"];
|
||||
auto id_a = fa->attrs["lvalue"];
|
||||
// re-check the candidate condition (the tree may have been rewritten
|
||||
// by an earlier merge) and additionally require the father loop to
|
||||
// have exactly 3 `inner` entries
|
||||
if (!(i->type == "loop" && i->father && i->father->type == "loop"
|
||||
&& i->father->children.size() == 1 && i->father->inner.size() == 3 &&
|
||||
i->before.size() == 0 && i->after.size() == 0)) {
|
||||
continue;
|
||||
}
|
||||
if (range_b.size() > 6) {
|
||||
// range23 -> range2*range3
|
||||
string tmp = range_b.substr(0, 6);
|
||||
// NOTE(review): this `i` shadows the loop node `i` declared above
|
||||
for (int i=6; i<range_b.size(); i++) {
|
||||
tmp += "*range";
|
||||
tmp += range_b[i];
|
||||
}
|
||||
range_b = tmp;
|
||||
}
|
||||
/*
|
||||
for (id_a : range_a)
|
||||
for (id_b : range_b)
|
||||
match(id_a * range_b * d + id_b * d + c)
|
||||
*/
|
||||
auto te = expr::make(id_a+"*"+range_b+"*d+"+id_b+"*d+c");
|
||||
vector<unique_ptr<Expr>> results;
|
||||
vector<string> solve_symbols = {"d", "c"};
|
||||
vector<string> exclude_symbols = {id_a, id_b};
|
||||
|
|
||||
bool can_opt = true;
|
||||
// every define whose lvalue ends with "id" or "_i" must match the
|
||||
// target expression; any "if" node disables the merge
|
||||
i->dfs([&](unique_ptr<KernelIR>& c) {
|
||||
if (!can_opt) return;
|
||||
if (c->type == "if") {
|
||||
// don't optimize reindex like op yet
|
||||
can_opt = false;
|
||||
return;
|
||||
}
|
||||
if (c->type == "define" && c->has_attr("rvalue")) {
|
||||
auto& s = c->attrs["rvalue"];
|
||||
auto& lv = c->attrs["lvalue"];
|
||||
if (!(endswith(lv, "id") || endswith(lv, "_i")))
|
||||
return;
|
||||
auto se = expr::make(s);
|
||||
se = trace_and_expand(c.get(), se.get())->simplify();
|
||||
LOGvvvv << "expand" << s << "->" << se;
|
||||
// LOGir << "expand" << s << "->" << se;
|
||||
results.clear();
|
||||
auto ret = expr::match(se.get(), te.get(), solve_symbols, exclude_symbols, results);
|
||||
if (ret) {
|
||||
LOGvvvv << "check rvalue" << se << '\n' <<
|
||||
te << '\n' <<
|
||||
ret << results;
|
||||
} else {
|
||||
can_opt = false;
|
||||
LOGvvvv << "cannot match" << se << '\n' <<
|
||||
te;
|
||||
}
|
||||
}
|
||||
});
|
||||
if (!can_opt)
|
||||
continue;
|
||||
// build the merged loop from a clone of the inner loop
|
||||
auto ni = i->clone();
|
||||
auto aid = fa->attrs["loop_id"];
|
||||
auto bid = i->attrs["loop_id"];
|
||||
auto newid = aid+bid;
|
||||
auto new_range = "range" + newid;
|
||||
auto x = i->find_define(new_range);
|
||||
if (!x) {
|
||||
// define the merged range once, as range_b * range_a
|
||||
ir->push_back(i->attrs["dtype"]+" "+new_range+" = "+range_b+" * "+range_a+";");
|
||||
}
|
||||
// inner range becomes the merged range; the outer index becomes 0
|
||||
ni->replace({{"range"+bid, new_range}, {"id"+aid, "0"}}, true, true);
|
||||
ni->attrs["loop_id"] = newid;
|
||||
ni->attrs["rvalue"] = new_range;
|
||||
// simplify 0 * x -> 0
|
||||
// ni->dfs([&](unique_ptr<KernelIR>& c) {
|
||||
// if (!can_opt) return;
|
||||
// if (c->type == "define" && c->has_attr("rvalue")) {
|
||||
// auto& s = c->attrs["rvalue"];
|
||||
// auto se = expr::make(s)->simplify();
|
||||
// s = se->to_string();
|
||||
// }
|
||||
// });
|
||||
LOGvvvv << "new merged loop" << ni;
|
||||
// the merged loop takes the place of the father loop
|
||||
ni->swap(*fa, true);
|
||||
}
|
||||
}
|
||||
ir->move_loop_back();
|
||||
ir->remove_all_unused();
|
||||
// LOGir << ir->to_string();
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,16 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "opt/pass/pass.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Optimization pass whose Pass base is constructed with the name
// "merge_loop_var"; run() is implemented in merge_loop_var_pass.cc.
struct MergeLoopVarPass : Pass {
|
||||
MergeLoopVarPass() : Pass("merge_loop_var") {};
|
||||
// Entry point of the pass (overrides Pass::run).
void run() override;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -1,356 +0,0 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <sstream>
|
||||
#include <functional>
|
||||
#include <omp.h>
|
||||
#include "var.h"
|
||||
#include "opt/expr.h"
|
||||
#include "opt/pass_manager.h"
|
||||
#include "opt/pass/parallel_pass.h"
|
||||
#include "opt/pass/loop_var_analyze_pass.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// Source text of the helper get_thread_range_log(), kept as a macro so it can
// be both expanded here and stringified (through STR_MACRO) for injection
// into generated kernel code.
// The helper computes nbits = NanoVector::get_nbits(min(thread_num, range)) - 2,
// shifts thread_num right by nbits in place, and returns nbits.
#define __get_thread_range_log \
|
||||
inline static int get_thread_range_log(int& thread_num, int64 range) { \
|
||||
int nbits = NanoVector::get_nbits(std::min((int64)thread_num, range)) - 2; \
|
||||
thread_num >>= nbits; \
|
||||
return nbits; \
|
||||
}
|
||||
|
|
||||
// expand the macro once so the helper is also defined in this file
__get_thread_range_log
|
||||
|
|
||||
// STR stringifies its argument as written; STR_MACRO expands the argument
// first and then stringifies the result.
#define STR(a) #a
|
||||
#define STR_MACRO(a) STR(a)
|
||||
|
||||
// Returns a clone of `e` with symbols replaced by the "rvalue" of their
// definition (found via ir->find_define).
// A symbol defined in the `inner` list of a "loop" node is replaced by
//   rvalue + loop_cnt * stride
// provided the loop condition matches "<sym> < range" and the loop step
// matches "<sym>++" (stride 1) or "<sym> += stride"; otherwise it is left
// untouched. Symbols defined in the `inner` list of any other node type are
// never expanded.
// Side effect: stores in ir->attrs["rely"] the comma-separated names of the
// loop indices the expression was found to depend on.
unique_ptr<expr::Expr> trace_and_expand(KernelIR* ir, expr::Expr* e) {
|
||||
auto a = e->clone();
|
||||
string rely=",";
|
||||
std::function<void(expr::Expr*)> func =
|
||||
[&](expr::Expr* c) {
|
||||
if (!c->is_sym()) return;
|
||||
auto def = ir->find_define(c->str);
|
||||
if (!def) return;
|
||||
ASSERT(def->type=="define");
|
||||
if (!def->has_attr("rvalue")) return;
|
||||
auto& rvalue = def->attrs["rvalue"];
|
||||
if (def->father && def->flist==&def->father->inner) {
|
||||
if (def->father->type=="func") return;
|
||||
if (def->father->type!="loop") return;
|
||||
LOGvvvv << "expand loop expr" << def->father->inner;
|
||||
// find x < range
|
||||
vector<unique_ptr<expr::Expr>> r1;
|
||||
if (!expr::match(
|
||||
expr::make(def->father->inner.at(1)->attrs.at("code")).get(),
|
||||
expr::make(c->str+"<range").get(),
|
||||
{"range"}, {}, r1))
|
||||
return;
|
||||
rely+=c->str+",";
|
||||
// find x++ or x+=stride
|
||||
vector<unique_ptr<expr::Expr>> r2;
|
||||
if (expr::match(
|
||||
expr::make(def->father->inner.at(2)->attrs.at("code")).get(),
|
||||
expr::make(c->str+"++").get())) {
|
||||
r2.push_back(expr::make("1"));
|
||||
} else
|
||||
if (!expr::match(
|
||||
expr::make(def->father->inner.at(2)->attrs.at("code")).get(),
|
||||
expr::make(c->str+"+=stride").get(),
|
||||
{"stride"}, {}, r2))
|
||||
return;
|
||||
// tid + loop_cnt * tnum
|
||||
auto new_expr = expr::make_op("+",
|
||||
expr::make(rvalue),
|
||||
expr::make_op("*",
|
||||
expr::make("loop_cnt"),
|
||||
r2.at(0)->clone()
|
||||
)
|
||||
);
|
||||
c->swap(new_expr.get());
|
||||
return;
|
||||
}
|
||||
c->swap(expr::make(rvalue).get());
|
||||
// if the replacement has no children, try to expand it again
|
||||
if (!c->children.size()) func(c);
|
||||
};
|
||||
a->dfs(func);
|
||||
// names of the loop indices relied on, separated by ","
|
||||
ir->attrs["rely"] = rely;
|
||||
return a;
|
||||
}
|
||||
|
||||
// Rewrites read-modify-write statements inside `ir` into atomic calls when
// needed.
// A statement is considered when it has an empty type, a "code" attribute of
// the form `a=b` where `b` mentions `a`, and `a` has the form `ptr[offset]`.
// The offset is expanded with trace_and_expand(); if, for each of the `tdim`
// thread dimensions, it matches (tid_i + tnum_i*a)*b + c, the statement is
// left unchanged. Otherwise the code is replaced by the cpu or cuda atomic
// template (chosen by `is_cuda`) matching the statement; LOGf is emitted if
// no template matches.
static void check_atomic(KernelIR* ir, bool is_cuda, int tdim) {
|
||||
ir->dfs([&](unique_ptr<KernelIR>& c) {
|
||||
if (c->type != "") return;
|
||||
if (!c->has_attr("code")) return;
|
||||
auto& code = c->attrs["code"];
|
||||
auto e = expr::make(code.substr(0, code.size()-1)); // remove ';'
|
||||
vector<unique_ptr<expr::Expr>> results;
|
||||
auto target = expr::make("a=b");
|
||||
if (!expr::match(e.get(), target.get(), {"a", "b"}, {}, results))
|
||||
return;
|
||||
// only statements whose right-hand side mentions the left-hand side
|
||||
bool has_a = 0;
|
||||
results[1]->dfs([&](expr::Expr* p) {
|
||||
if (p->to_string()==results[0]->to_string())
|
||||
has_a = 1;
|
||||
});
|
||||
if (!has_a) return;
|
||||
vector<unique_ptr<expr::Expr>> ptr_and_offset;
|
||||
if (!expr::match(results[0].get(), expr::make("a[b]").get(), {"a", "b"}, {}, ptr_and_offset))
|
||||
return;
|
||||
LOGvvvv << "ptr_and_offset" << ptr_and_offset;
|
||||
auto offset = trace_and_expand(c.get(), ptr_and_offset.at(1).get())
|
||||
->simplify();
|
||||
LOGvvvv << "rely" << c->get_attr("rely");
|
||||
LOGvvvv << "full offset expr" << offset->to_string(1);
|
||||
// try to optimize unnecessary atomic operation
|
||||
bool need_atomic = false;
|
||||
for (int i=0; i<tdim; i++) {
|
||||
vector<unique_ptr<expr::Expr>> xres;
|
||||
if (!expr::match(
|
||||
offset.get(),
|
||||
expr::make("(tid"+S(i)+"+tnum"+S(i)+"*a)*b+c").get(),
|
||||
{"a","b","c"}, {"tid"+S(i)}, xres
|
||||
)) {
|
||||
LOGvvvv << "offset" << offset << "not match, need atomic";
|
||||
need_atomic = true;
|
||||
break;
|
||||
}
|
||||
LOGvvvv << "atomic optimize match:" << i << xres;
|
||||
// set tid=0 and simplify
|
||||
offset = offset->assign_symbol({{"tid"+S(i),"0"}})->simplify();
|
||||
LOGvvvv << "new offset" << offset;
|
||||
}
|
||||
if (!need_atomic) return;
|
||||
|
|
||||
// add atomic code
|
||||
// check(): if the statement matches pattern `t`, replace `code` by the
|
||||
// cpu or cuda template with the matched symbols substituted
|
||||
auto check = [&](const string& t, const vector<string>& args, const string& cpu, const string& cuda) -> bool {
|
||||
auto target = expr::make(t);
|
||||
if (!expr::match(e.get(), target.get(), args, {}, results))
|
||||
return false;
|
||||
unordered_map<string,string> defs;
|
||||
for (int i=0; i<args.size(); i++)
|
||||
defs[args[i]] = results[i]->to_string();
|
||||
code = OpCompiler::precompile(defs, is_cuda ? cuda : cpu) + ";";
|
||||
LOGvvvv << "matched" << results << code;
|
||||
return true;
|
||||
};
|
||||
if (
|
||||
check("a=a+b", {"a","b"}, "cpu_atomic_add(&@a,@b)", "atomicAdd(&@a,@b)") ||
|
||||
check("a=a-b", {"a","b"}, "cpu_atomic_sub(&@a,@b)", "atomicSub(&@a,@b)") ||
|
||||
check("a=a*b", {"a","b"}, "cpu_atomic_mul(&@a,@b)", "cuda_atomic_mul(&@a,@b)") ||
|
||||
check("a=std::max(T(a),T(b))", {"a","b","T"}, "cpu_atomic_max(&@a,@T@@(@b))", "cuda_atomic_max(&@a,@T@@(@b))") ||
|
||||
check("a=::max(T(a),T(b))", {"a","b","T"}, "cpu_atomic_max(&@a,@T@@(@b))", "cuda_atomic_max(&@a,@T@@(@b))") ||
|
||||
check("a=std::min(T(a),T(b))", {"a","b","T"}, "cpu_atomic_min(&@a,@T@@(@b))", "cuda_atomic_min(&@a,@T@@(@b))") ||
|
||||
check("a=::min(T(a),T(b))", {"a","b","T"}, "cpu_atomic_min(&@a,@T@@(@b))", "cuda_atomic_min(&@a,@T@@(@b))") ||
|
||||
check("a=a&b", {"a","b"}, "cpu_atomic_and(&@a,@b)", "atomicAnd(&@a,@b)") ||
|
||||
check("a=a|b", {"a","b"}, "cpu_atomic_or(&@a,@b)", "atomicOr(&@a,@b)") ||
|
||||
check("a=a^b", {"a","b"}, "cpu_atomic_xor(&@a,@b)", "atomicXor(&@a,@b)") ||
|
||||
check("a=a&&b", {"a","b"}, "cpu_atomic_and(&@a,bool(@b))", "atomicAnd(&@a,bool(@b))") ||
|
||||
check("a=a||b", {"a","b"}, "cpu_atomic_or(&@a,bool(@b))", "atomicOr(&@a,bool(@b))") ||
|
||||
check("a=((bool(a))!=(bool(b)))", {"a","b"}, "cpu_atomic_xor(&@a,bool(@b))", "atomicXor(&@a,bool(@b))")
|
||||
)
|
||||
return;
|
||||
// an atomic is required but no known pattern matched the statement
|
||||
LOGf << "Expr not match" << e;
|
||||
});
|
||||
}
|
||||
|
||||
// Returns 1 << (NanoVector::get_nbits(x) - 2).
// NOTE(review): presumably this rounds x down to a power of two -- confirm
// against the definition of NanoVector::get_nbits.
int to_pow(int x) {
|
||||
return 1 << (NanoVector::get_nbits(x) - 2);
|
||||
}
|
||||
|
||||
// Parallelizes loop functions of the kernel IR.
// For every child of `ir` carrying a "loop_func" attribute, the referenced
// function definition (looked up in ir->before) is rewritten so that up to
// max_parallel_depth nested loops are strided by per-dimension thread ids
// (tid<j>/tnum<j>), and the call site is replaced by a block that computes
// the thread split and launches the function through OpenMP (CPU) or a CUDA
// kernel launch. The pass always runs for CUDA ops; on CPU it only runs when
// the "parallel" loop option is set.
void ParallelPass::run() {
|
||||
auto choice = op->get_loop_option("parallel");
|
||||
auto fix_thread_num = op->get_loop_option("fix_thread_num", 0);
|
||||
bool is_cuda = op->flags.get(NodeFlags::_cuda);
|
||||
if (is_cuda) choice=1;
|
||||
if (!choice) return;
|
||||
|
|
||||
int cuda_block_num = to_pow(op->get_loop_option("cuda_block_num", 256));
|
||||
int cuda_thread_num = to_pow(op->get_loop_option("cuda_thread_num", 1024));
|
||||
int cpu_thread_num = to_pow(op->get_loop_option("cpu_thread_num", omp_get_max_threads()));
|
||||
int max_parallel_depth;
|
||||
if (!is_cuda) {
|
||||
// omp include
|
||||
ir->push_front("#include \"misc/cpu_atomic.h\"", &ir->before);
|
||||
ir->push_front("#include <omp.h>", &ir->before);
|
||||
max_parallel_depth = op->get_loop_option("max_parallel_depth", 2);
|
||||
auto* lva_pass = pm->get_pass<LoopVarAnalyzePass>("loop_var_analyze");
|
||||
auto number_of_ranges = lva_pass->number_of_ranges;
|
||||
// unless the user set max_parallel_depth explicitly, cap it at
|
||||
// number_of_ranges-1
|
||||
if (!op->loop_options->count("max_parallel_depth")) {
|
||||
if (number_of_ranges<=max_parallel_depth)
|
||||
max_parallel_depth = number_of_ranges-1;
|
||||
}
|
||||
if (max_parallel_depth<=0) return;
|
||||
} else {
|
||||
ir->push_front("#include \"helper_cuda.h\"", &ir->before);
|
||||
ir->push_front("#include \"misc/cuda_limits.h\"", &ir->before);
|
||||
ir->push_front("#include \"misc/cuda_atomic.h\"", &ir->before);
|
||||
max_parallel_depth = op->get_loop_option("max_parallel_depth", 4);
|
||||
}
|
||||
ir->push_back("#pragma GCC diagnostic ignored \"-Wunused-function\"", &ir->before, true);
|
||||
// inject the source text of get_thread_range_log into the generated code
|
||||
ir->push_back(STR_MACRO(__get_thread_range_log), &ir->before, true);
|
||||
|
|
||||
for (uint i=0; i<ir->children.size(); i++) {
|
||||
auto& func_call = ir->children[i];
|
||||
if (!func_call->has_attr("loop_func")) continue;
|
||||
auto& func_name = func_call->attrs["loop_func"];
|
||||
// locate the definition of the called loop function in ir->before
|
||||
uint j=0;
|
||||
while (j<ir->before.size() && !ir->before[j]->check_attr("lvalue", func_name))
|
||||
j++;
|
||||
ASSERT(j<ir->before.size()) << "loop func" << func_name << "not found.";
|
||||
|
|
||||
auto& func_def = ir->before[j];
|
||||
auto c = func_def->children.back().get();
|
||||
ASSERTop(c->type,==,"loop");
|
||||
// only one loop
|
||||
ASSERT(func_def->children.size()==1 ||
|
||||
func_def->children[func_def->children.size()-2]->type!="loop");
|
||||
// collect the chain of nested loops that can be parallelized, together
|
||||
// with their ranges (rvalues) and strides
|
||||
vector<KernelIR*> cs;
|
||||
vector<string> rvalues, strides;
|
||||
for (int j=0; j<max_parallel_depth; j++) {
|
||||
if (!c->has_attr("rvalue")) break;
|
||||
if (!c->has_attr("lvalue")) break;
|
||||
auto& lvalue = c->attrs["lvalue"];
|
||||
auto& stride = c->inner[2]->attrs["code"];
|
||||
if (stride == lvalue+"++;") {
|
||||
strides.push_back("1");
|
||||
} else {
|
||||
if (!c->has_attr("rvalue2")) break;
|
||||
auto& rvalue2 = c->attrs["rvalue2"];
|
||||
if (stride != lvalue+"+="+rvalue2+";") break;
|
||||
strides.push_back(rvalue2);
|
||||
}
|
||||
rvalues.push_back(c->attrs["rvalue"]);
|
||||
cs.push_back(c);
|
||||
LOGvvvv << "Parallel loop dep=">>j<<"range=" >> rvalues.back() <<
|
||||
"stride=" >> strides.back()
|
||||
<< "code:" << c->inner;
|
||||
if (c->children.size()==1 && c->children[0]->type=="loop") {
|
||||
c = c->children[0].get();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(void)get_thread_range_log;
|
||||
// work on clones; they are swapped into place at the end
|
||||
KernelIR new_block("{}");
|
||||
auto new_func_call = func_call->clone();
|
||||
auto new_func_def = func_def->clone();
|
||||
// the same loop chain inside the cloned function definition
|
||||
vector<KernelIR*> ncs;
|
||||
c = new_func_def->children.back().get();
|
||||
for (int j=0; j<cs.size(); j++) {
|
||||
ncs.push_back(c);
|
||||
if (c->children.size()==0) break;
|
||||
c = c->children[0].get();
|
||||
}
|
||||
auto& func_call_code = new_func_call->attrs["code"];
|
||||
int thread_num = is_cuda ?
|
||||
cuda_block_num * cuda_thread_num
|
||||
: cpu_thread_num;
|
||||
// resolve undefined rvalues
|
||||
for (auto& rv : rvalues) {
|
||||
auto e = expr::make(rv);
|
||||
if (!e->is(expr::_number)) {
|
||||
auto rdef = func_def->find_define(rv);
|
||||
ASSERT(rdef);
|
||||
if (rdef->has_attr("rvalue"))
|
||||
rv = rdef->attrs["rvalue"];
|
||||
}
|
||||
}
|
||||
|
|
||||
// calc max thread num
|
||||
string nums = rvalues.at(0);
|
||||
for (int i=1; i<rvalues.size(); i++)
|
||||
nums+="*"+rvalues[i];
|
||||
new_block.push_back("int thread_num=" + S(thread_num) + ";");
|
||||
new_block.push_back("int thread_num_left=thread_num;");
|
||||
|
|
||||
// innermost loop first: emit one tn<j> per loop, append it to the call
|
||||
// arguments and declare it as a parameter of the function
|
||||
for (int j=ncs.size()-1; j>=0; j--) {
|
||||
auto& rv = rvalues[j];
|
||||
new_block.push_back("int tn"+S(j)+
|
||||
"=get_thread_range_log(thread_num_left, "+rv+");");
|
||||
func_call_code = func_call_code.substr(0, func_call_code.size()-2)
|
||||
+ ",tn" + S(j) + ");";
|
||||
new_func_def->push_back("int tn"+S(j)+";", &new_func_def->inner);
|
||||
}
|
||||
// accumulate: tn<j> += tn<j+1>
|
||||
for (int j=ncs.size()-2; j>0; j--) {
|
||||
new_block.push_back("tn"+S(j)+"=tn"+S(j)+"+tn"+S(j+1)+";");
|
||||
}
|
||||
new_block.push_back("tn0=NanoVector::get_nbits(thread_num)-2;");
|
||||
new_block.push_back("int p1 = std::max(thread_num/1024, 1);");
|
||||
new_block.push_back("int p2 = std::min(thread_num, 1024);");
|
||||
KernelIR new_tid_def("{}");
|
||||
if (!is_cuda) {
|
||||
// omp thread id
|
||||
new_tid_def.push_front("int thread_id = omp_get_thread_num();");
|
||||
// omp func call
|
||||
// we set num_threads in code
|
||||
new_func_call->push_back(
|
||||
"#pragma omp parallel num_threads(thread_num)",
|
||||
&new_func_call->before
|
||||
);
|
||||
} else {
|
||||
new_func_def->get_attr("dtype") = "__launch_bounds__("+S(cuda_thread_num)+") __global__ void";
|
||||
new_tid_def.push_front("int thread_id = blockIdx.x * blockDim.x + threadIdx.x;");
|
||||
// cuda kernel launch
|
||||
auto& code = func_call_code;
|
||||
auto pos = code.find("(");
|
||||
ASSERT(pos != string::npos);
|
||||
code = code.substr(0, pos) +
|
||||
"<<<p1,p2>>>" +
|
||||
code.substr(pos);
|
||||
}
|
||||
|
|
||||
new_block.push_back(move(new_func_call));
|
||||
LOGvvvv << "new block:" << new_block.to_string();
|
||||
new_tid_def.push_back("int tn"+S(ncs.size())+"=0;");
|
||||
// per loop: derive tnum<j>/tid<j> from thread_id and rewrite the loop's
|
||||
// init and stride
|
||||
for (int j=0; j<ncs.size(); j++) {
|
||||
new_tid_def.push_back("int tnum"+S(j)+
|
||||
" = 1<<(tn"+S(j)+"-tn"+S(j+1)+");");
|
||||
new_tid_def.push_back("int tid"+S(j)+
|
||||
" = (thread_id>>tn"+S(j+1)+") & (tnum"+S(j)+"-1);");
|
||||
auto c = ncs[j];
|
||||
auto& lvalue = c->attrs["lvalue"];
|
||||
auto& stride = c->inner[2]->attrs["code"];
|
||||
string new_stride, new_init;
|
||||
// change
|
||||
// for (T i=0; i<range; i+=stride)
|
||||
// to
|
||||
// for (T i=stride*thread_id; i<range; i+=stride*thread_num)
|
||||
// TODO: check loop deps
|
||||
if (stride == lvalue+"++;") {
|
||||
new_stride = lvalue+"+=tnum"+S(j)+";";
|
||||
new_init = lvalue+"=tid"+S(j)+";";
|
||||
} else {
|
||||
if (!c->has_attr("rvalue2")) continue;
|
||||
auto& rvalue2 = c->attrs["rvalue2"];
|
||||
if (stride != lvalue+"+="+rvalue2+";") continue;
|
||||
new_stride = lvalue+"+="+rvalue2+"*tnum"+S(j)+";";
|
||||
new_init = lvalue+"="+rvalue2+"*tid"+S(j)+";";
|
||||
}
|
||||
LOGvvvv << "Parallel loop" << c->attrs["loop_id"] << "with new stride" << new_stride;
|
||||
if (c->inner[0]->type == "define")
|
||||
new_init = c->inner[0]->attrs["dtype"] + " " + new_init;
|
||||
stride = new_stride;
|
||||
c->inner[0]->try_parse_define(new_init);
|
||||
}
|
||||
LOGvvvv << "new_tid_def:" << new_tid_def.to_string();
|
||||
// turn conflicting read-modify-write statements into atomics
|
||||
check_atomic(new_func_def.get(), is_cuda, ncs.size());
|
||||
new_func_def->insert(0, new_tid_def.children);
|
||||
// install the rewritten definition and call block in place of the originals
|
||||
new_func_def->swap(*func_def, true);
|
||||
new_block.swap(*func_call, true);
|
||||
auto code = func_def->to_string();
|
||||
bool has_atomic = code.find("atomic") != string::npos;
|
||||
if (!fix_thread_num) {
|
||||
// with atomics present, divide the thread-count estimate by 16
|
||||
if (has_atomic) {
|
||||
nums += "/16";
|
||||
}
|
||||
func_call->find_define("thread_num")->attrs["rvalue"] = "min(max(1<<(NanoVector::get_nbits(" + nums + ")-2),32)," + S(thread_num) + ")";
|
||||
}
|
||||
}
|
||||
ir->remove_all_unused();
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -13,6 +13,8 @@
|
|||
#include "opt/pass/split_loop_pass.h"
|
||||
#include "opt/pass/reorder_loop_pass.h"
|
||||
#include "opt/pass/merge_loop_pass.h"
|
||||
#include "opt/pass/merge_loop_var_pass.h"
|
||||
#include "opt/pass/const_var_pass.h"
|
||||
#include "opt/pass/expand_empty_block_pass.h"
|
||||
#include "opt/pass/solve_conflict_define_pass.h"
|
||||
#include "opt/pass/remove_intermediate_pass.h"
|
||||
|
@ -88,6 +90,8 @@ void PassManager::run_passes() {
|
|||
run_pass<RemoveIntermediatePass>();
|
||||
|
||||
run_pass<SolveConflictDefinePass>();
|
||||
run_pass<MergeLoopVarPass>();
|
||||
run_pass<ConstVarPass>();
|
||||
|
||||
run_pass<RestridePass>();
|
||||
|
||||
|
|
|
@ -79,7 +79,7 @@ struct SimpleThreads {
|
|||
for (auto& t : threads) {
|
||||
auto start = clock();
|
||||
int ok = 0;
|
||||
while (clock()<start+5000) {
|
||||
while (clock()<start+5*CLOCKS_PER_SEC) {
|
||||
if (t.mtx.try_lock()) {
|
||||
t.mtx.unlock();
|
||||
ok = 1;
|
||||
|
|
|
@ -23,7 +23,11 @@ namespace jittor {
|
|||
|
||||
Profiler profiler;
|
||||
|
||||
DEFINE_FLAG(int, profiler_warmup, 0, "Profiler warmup.");
|
||||
DEFINE_FLAG(int, profiler_rerun, 0, "Profiler rerun.");
|
||||
DEFINE_FLAG(int, profiler_hide_relay, 0, "Profiler hide relayed op.");
|
||||
DEFINE_FLAG_WITH_SETTER(int, profiler_enable, 0, "Enable profiler.");
|
||||
|
||||
void setter_profiler_enable(int value) {
|
||||
if (value)
|
||||
Profiler::start();
|
||||
|
@ -39,6 +43,8 @@ Profiler::~Profiler() {
|
|||
}
|
||||
|
||||
void Profiler::start(int64 warmup, int64 rerun) {
|
||||
if (warmup==0) warmup = profiler_warmup;
|
||||
if (rerun==0) rerun = profiler_rerun;
|
||||
profiler_enable = 1;
|
||||
profiler.records.clear();
|
||||
profiler.warmup = warmup;
|
||||
|
@ -62,7 +68,7 @@ unique_ptr<MemoryChecker>* load_memory_checker(string name) {
|
|||
return mm;
|
||||
}
|
||||
|
||||
extern string _get_stack_info(Op* op);
|
||||
extern string _get_stack_info(Node* node);
|
||||
|
||||
static string get_stack_info(Op* op) {
|
||||
string stack_info = "stack info:\n";
|
||||
|
@ -76,10 +82,39 @@ static string get_stack_info(Op* op) {
|
|||
stack_info += kv.first;
|
||||
stack_info += '\n';
|
||||
}
|
||||
if (trace_py_var == 2) {
|
||||
std::stringstream ss;
|
||||
ss << "input from:\n";
|
||||
for (auto& vi : fop->vars) {
|
||||
if (vi.type == 0) {
|
||||
auto v = vi.var;
|
||||
ss << v->shape << ',' << v->dtype() << ',' << v->name << ',';
|
||||
if (v->input())
|
||||
ss << v->input()->name_ex() << ',' << _get_stack_info(v->input());
|
||||
else
|
||||
ss << _get_stack_info(v);
|
||||
ss << '\n';
|
||||
}
|
||||
}
|
||||
stack_info += ss.str();
|
||||
}
|
||||
return stack_info;
|
||||
} else {
|
||||
stack_info += _get_stack_info(op);
|
||||
stack_info += '\n';
|
||||
if (trace_py_var == 2) {
|
||||
std::stringstream ss;
|
||||
ss << "input from:\n";
|
||||
for (auto v : op->inputs()) {
|
||||
ss << v->shape << ',' << v->dtype() << ',' << v->name << ',';
|
||||
if (v->input())
|
||||
ss << v->input()->name_ex() << ',' << _get_stack_info(v->input());
|
||||
else
|
||||
ss << _get_stack_info(v);
|
||||
ss << '\n';
|
||||
}
|
||||
stack_info += ss.str();
|
||||
}
|
||||
return stack_info;
|
||||
}
|
||||
}
|
||||
|
@ -109,6 +144,36 @@ void Profiler::record_and_run(
|
|||
}
|
||||
}
|
||||
bool is_fused = op->name() == string("fused");
|
||||
|
||||
uint64* shape_time = nullptr;
|
||||
if (trace_py_var) {
|
||||
// record shape
|
||||
NanoVector shape;
|
||||
int64 num = 0;
|
||||
Op** ops = &op;
|
||||
int op_num = 1;
|
||||
if (is_fused) {
|
||||
ops = &(((FusedOp*)op)->ops[0]);
|
||||
op_num = ((FusedOp*)op)->ops.size();
|
||||
}
|
||||
for (int i=0; i<op_num; i++) {
|
||||
auto o = ops[i];
|
||||
for (auto v : o->inputs()) {
|
||||
if (v->num > num) {
|
||||
num = v->num;
|
||||
shape = v->shape;
|
||||
}
|
||||
}
|
||||
for (auto v : o->outputs()) {
|
||||
if (v->num > num) {
|
||||
num = v->num;
|
||||
shape = v->shape;
|
||||
}
|
||||
}
|
||||
}
|
||||
iter->second.shapes[shape].second += 1;
|
||||
shape_time = &iter->second.shapes[shape].first;
|
||||
}
|
||||
int loop = (is_fused &&
|
||||
((FusedOp*)op)->get_loop_option("insert_profile_loop")) ? 10 : 0;
|
||||
int64_t warmup = profiler.warmup ? std::max(profiler.warmup>>loop, (int64_t)1) : 0;
|
||||
|
@ -141,6 +206,7 @@ void Profiler::record_and_run(
|
|||
// 24ns function call overhead
|
||||
total_ns = std::max((int64_t)1, total_ns-24);
|
||||
iter->second.update(loop, total_ns, in, out, compute);
|
||||
if (shape_time) shape_time[0] += total_ns;
|
||||
LOGvvvv << "Duration" << total_ns >> "ns running" << op;
|
||||
}
|
||||
if (is_fused &&
|
||||
|
@ -153,7 +219,7 @@ void Profiler::record_and_run(
|
|||
}
|
||||
|
||||
vector<vector<string>> Profiler::report(const string& sort_key) {
|
||||
vector<vector<string>> rep = {{"Name", "FileName", "Count", "TotalTime", "AvgTime", "MinTime", "MaxTime", "Input", "Output", "Compute"}};
|
||||
vector<vector<string>> rep = {{"Name", "FileName", "Count", "TotalTime", "AvgTime", "MinTime", "MaxTime", "Input", "Output", "InOut", "Compute"}};
|
||||
vector<string> names, fnames;
|
||||
vector<vector<double>> info;
|
||||
vector<int> order;
|
||||
|
@ -163,18 +229,41 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
|
|||
break;
|
||||
ASSERT(sort_key_id<(int)rep[0].size()) << "Key not supported:" << sort_key;
|
||||
double total_time = 0;
|
||||
double total_mem_access = 0;
|
||||
for (auto& kv : profiler.records) {
|
||||
auto& kinfo = kv.second;
|
||||
names.push_back(kv.first);
|
||||
fnames.push_back(Op::get_filename_from_jit_key(kv.first, ".cc"));
|
||||
if (kv.second.stack_info.size()) {
|
||||
fnames.back() += '\n';
|
||||
fnames.back() += kv.second.stack_info.c_str();
|
||||
}
|
||||
auto& kinfo = kv.second;
|
||||
if (kv.second.shapes.size()) {
|
||||
// show shapes
|
||||
vector<pair<pair<uint64,uint64>,NanoVector>> shapes;
|
||||
shapes.reserve(kv.second.shapes.size());
|
||||
for (auto& kv2 : kv.second.shapes) {
|
||||
shapes.push_back(std::make_pair(kv2.second, kv2.first));
|
||||
}
|
||||
std::sort(shapes.begin(), shapes.end());
|
||||
std::stringstream ss;
|
||||
ss << "shapes:\n";
|
||||
for (int i=0; i<10; i++) {
|
||||
if (i>=shapes.size()) break;
|
||||
auto& sp = shapes[shapes.size() - i - 1];
|
||||
auto rate = sp.first.first * 100.0 / kinfo.time_total;
|
||||
ss << sp.second << ':' << sp.first.second <<
|
||||
"("<< std::setprecision(3) << rate << "%), ";
|
||||
}
|
||||
if (shapes.size()>10)
|
||||
ss << "... total " << shapes.size() << '\n';
|
||||
fnames.back() += ss.str();
|
||||
}
|
||||
order.push_back(order.size());
|
||||
// do not count relay op time
|
||||
if (kv.first.find("relay") == string::npos) {
|
||||
total_time += kinfo.time_total;
|
||||
total_mem_access += kinfo.in_total + kinfo.out_total;
|
||||
}
|
||||
info.push_back({
|
||||
(double)kinfo.count, // Count
|
||||
|
@ -184,6 +273,7 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
|
|||
(double)kinfo.time_max, // MaxTime
|
||||
(double)kinfo.in_total*1e9 / kinfo.time_total, // Input
|
||||
(double)kinfo.out_total*1e9 / kinfo.time_total, // Output
|
||||
(double)(kinfo.in_total+kinfo.out_total)*1e9 / kinfo.time_total, // InOut
|
||||
(double)kinfo.compute_total*1e9 / kinfo.time_total, // Compute
|
||||
});
|
||||
}
|
||||
|
@ -219,9 +309,15 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
|
|||
ss << "Total time:";
|
||||
output_float("num ", 1000, "s", total_time);
|
||||
ss << '\n';
|
||||
ss << "Total Memory Access:";
|
||||
output_float(" KMG", 1024, "B", total_mem_access);
|
||||
ss << '\n';
|
||||
double cum_time = 0;
|
||||
for (auto i : order) {
|
||||
auto& name = names[i];
|
||||
auto is_relay = name.find("relay") != string::npos;
|
||||
if (is_relay && profiler_hide_relay)
|
||||
continue;
|
||||
auto& fname = fnames[i];
|
||||
rep.push_back({name, fname});
|
||||
ss << std::setw(w) << name;
|
||||
|
@ -240,7 +336,7 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
|
|||
// output total ratio
|
||||
if (j == 1) {
|
||||
// do not count relay op time
|
||||
if (name.find("relay") != string::npos)
|
||||
if (is_relay)
|
||||
k = 0;
|
||||
cum_time += k;
|
||||
ss << '(' << std::setw(3)
|
||||
|
@ -248,7 +344,7 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
|
|||
<< std::setw(3)
|
||||
<< std::setprecision(p) << cum_time / total_time * 100 << "%)";
|
||||
}
|
||||
} else if (j<=6) {
|
||||
} else if (j<=7) {
|
||||
// output throughput
|
||||
output_float(" KMG", 1024, "B/s", k);
|
||||
} else {
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
#include "profiler/cache_info.h"
|
||||
#include "op_compiler.h"
|
||||
#include "misc/cstr.h"
|
||||
#include "misc/nano_vector.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -25,6 +26,7 @@ struct Profiler {
|
|||
// cache test info
|
||||
unique_ptr<CacheInfo> cache_info;
|
||||
cstr stack_info;
|
||||
unordered_map<NanoVector, pair<uint64, uint64>> shapes;
|
||||
|
||||
void update(int c, uint64_t t, uint64_t in, uint64_t out, uint64_t comp) {
|
||||
count += 1<<c;
|
||||
|
|
|
@ -95,6 +95,21 @@ static vector<Stack> get_stack_info() {
|
|||
int i=n;
|
||||
while (i) frames[--i] = frame, frame = frame->f_back;
|
||||
PyObject* prev_obj = nullptr;
|
||||
if (trace_py_var == 2) {
|
||||
// trace raw stack
|
||||
auto start = std::max(0, n-5);
|
||||
for (int i=start; i<n; i++) {
|
||||
auto f = frames[i];
|
||||
auto filename = to_string(f->f_code->co_filename);
|
||||
auto lineno = (int)PyFrame_GetLineNumber(f);
|
||||
stacks.emplace_back(Stack{
|
||||
filename+":"+S(lineno),
|
||||
to_string(f->f_code->co_name),
|
||||
filename,
|
||||
lineno});
|
||||
}
|
||||
return stacks;
|
||||
}
|
||||
for (int i=0; i<n; i++) {
|
||||
auto f = frames[i];
|
||||
if (Py_SIZE(f->f_code->co_varnames)) {
|
||||
|
@ -170,7 +185,7 @@ void TraceData::record_node(Node* node, bool record_stack) {
|
|||
NodeData data;
|
||||
data.id = node_data_cnt++;
|
||||
id_map[node] = data.id;
|
||||
if (!node->is_var()) {
|
||||
if (!node->is_var() || trace_py_var==2) {
|
||||
if (record_stack) {
|
||||
if (trace_grad_op) {
|
||||
auto iter = trace_data.id_map.find(trace_grad_op);
|
||||
|
@ -324,23 +339,27 @@ void clear_trace_data() {
|
|||
trace_data.node_data.clear();
|
||||
}
|
||||
|
||||
string _get_stack_info(Op* op) {
|
||||
string _get_stack_info(Node* node) {
|
||||
string stack_info = "";
|
||||
auto iter = trace_data.id_map.find(op);
|
||||
auto iter = trace_data.id_map.find(node);
|
||||
if (iter == trace_data.id_map.end())
|
||||
return stack_info;
|
||||
auto node_id = iter->second;
|
||||
auto iter2 = trace_data.node_data.find(node_id);
|
||||
if (iter2 == trace_data.node_data.end())
|
||||
return stack_info;
|
||||
for (auto& stack : iter2->second.stacks)
|
||||
stack_info += stack.module_name + " -> ";
|
||||
for (auto& stack : iter2->second.stacks) {
|
||||
stack_info += stack.module_name;
|
||||
stack_info += '(';
|
||||
stack_info += stack.module_type;
|
||||
stack_info += ')';
|
||||
stack_info += " -> ";
|
||||
}
|
||||
return stack_info;
|
||||
}
|
||||
|
||||
void print_node_trace(const Node* node, std::ostream& os) {
|
||||
if (!node->is_var())
|
||||
os << _get_stack_info((((Node*)node))->op());
|
||||
os << _get_stack_info((Node*)node);
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -401,6 +401,29 @@ DEF_IS(DataView, PyObject*) to_py_object(T a) {
|
|||
return oh.release();
|
||||
}
|
||||
|
||||
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
struct ItemData;
|
||||
// Convert a fetched scalar (raw 64-bit slot + dtype tag) into a Python object.
// bool -> Py_True / Py_False, float32/float64 -> Python float,
// everything else -> Python int.
DEF_IS(ItemData, PyObject*) to_py_object(T a) {
    // The payload is reinterpreted according to the dtype tag below.
    auto* slot = &a.data;

    if (a.dtype == ns_bool) {
        if (*(bool*)slot) {
            Py_RETURN_TRUE;
        }
        Py_RETURN_FALSE;
    }

    // Floating point payloads become Python floats.
    if (a.dtype == ns_float32) {
        float64 widened = (float64)*(float32*)slot;
        return PyFloat_FromDouble(widened);
    }
    if (a.dtype == ns_float64) {
        return PyFloat_FromDouble(*(float64*)slot);
    }

    // Integer payloads: narrow signed types are sign-extended to 64 bits;
    // int64 and any dtype without a dedicated branch use the slot as-is.
    int64 value = a.data;
    if (a.dtype == ns_int32) {
        value = (int64)*(int*)slot;
    } else if (a.dtype == ns_int16) {
        value = (int64)*(int16*)slot;
    } else if (a.dtype == ns_int8) {
        value = (int64)*(int8*)slot;
    }
    return PyLong_FromLongLong(value);
}
|
||||
|
||||
struct NumpyFunc;
|
||||
|
||||
DEF_IS(NumpyFunc, bool) is_type(PyObject* obj) {
|
||||
|
|
|
@ -22,7 +22,7 @@ struct PyMultiprocessRingBuffer {
|
|||
// @pyjt(pop,recv)
|
||||
PyObject* pop();
|
||||
// @pyjt(clear)
|
||||
inline void clear() { rb->l = rb->r = rb->is_stop = 0; }
|
||||
inline void clear() { rb->clear(); }
|
||||
// @pyjt(stop)
|
||||
inline void stop() { rb->stop(); }
|
||||
// @pyjt(is_stop)
|
||||
|
|
|
@ -177,6 +177,7 @@ std::vector<std::map<string,string>> log_capture_read() {
|
|||
|
||||
void log_exiting();
|
||||
|
||||
bool exited = false;
|
||||
size_t thread_local protected_page = 0;
|
||||
int segfault_happen = 0;
|
||||
string thread_local thread_name;
|
||||
|
@ -184,6 +185,7 @@ string thread_local thread_name;
|
|||
void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
||||
if (signal == SIGINT) {
|
||||
LOGe << "Caught SIGINT, exit";
|
||||
exited = true;
|
||||
exit(1);
|
||||
}
|
||||
std::cerr << "Caught segfault at address " << si->si_addr << ", "
|
||||
|
@ -194,13 +196,16 @@ void segfault_sigaction(int signal, siginfo_t *si, void *arg) {
|
|||
si->si_addr<(void*)(protected_page+4*1024)) {
|
||||
LOGf << "Accessing protect pages, maybe jit_key too long";
|
||||
}
|
||||
if (signal == SIGSEGV) {
|
||||
// only print trace in main thread
|
||||
if (thread_name.size() == 0)
|
||||
print_trace();
|
||||
std::cerr << "Segfault, exit" << std::endl;
|
||||
} else {
|
||||
std::cerr << "Get signal " << signal << ", exit" << std::endl;
|
||||
if (!exited) {
|
||||
exited = true;
|
||||
if (signal == SIGSEGV) {
|
||||
// only print trace in main thread
|
||||
if (thread_name.size() == 0)
|
||||
print_trace();
|
||||
std::cerr << "Segfault, exit" << std::endl;
|
||||
} else {
|
||||
std::cerr << "Get signal " << signal << ", exit" << std::endl;
|
||||
}
|
||||
}
|
||||
segfault_happen = 1;
|
||||
exit(1);
|
||||
|
@ -290,8 +295,8 @@ bool check_vlog(const char* fileline, int verbose) {
|
|||
}
|
||||
|
||||
int system_popen(const char* cmd) {
|
||||
static thread_local char buf[BUFSIZ];
|
||||
static thread_local string cmd2;
|
||||
char buf[BUFSIZ];
|
||||
string cmd2;
|
||||
cmd2 = cmd;
|
||||
cmd2 += " 2>&1 ";
|
||||
FILE *ptr = popen(cmd2.c_str(), "r");
|
||||
|
@ -314,11 +319,10 @@ void system_with_check(const char* cmd) {
|
|||
std::thread log_thread(log_main);
|
||||
|
||||
void log_exiting() {
|
||||
static bool exited = false;
|
||||
if (exited) return;
|
||||
exited = true;
|
||||
mwsr_list_log::stop();
|
||||
log_thread.join();
|
||||
exited = true;
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -89,8 +89,12 @@ std::ostream& operator<<(std::ostream& os, const Var& var) {
|
|||
<< ')' << var.shape;
|
||||
#ifdef NODE_MEMCHECK
|
||||
os << '<' << var.__id() << '>';
|
||||
print_node_trace(&var, os);
|
||||
#endif
|
||||
if (trace_py_var) {
|
||||
os << '{';
|
||||
print_node_trace(&var, os);
|
||||
os << '}';
|
||||
}
|
||||
return os;
|
||||
}
|
||||
std::ostream& operator<<(std::ostream& os, const Var* var) {
|
||||
|
|
|
@ -114,6 +114,23 @@ ArrayArgs VarHolder::fetch_sync() {
|
|||
return {var->mem_ptr, var->shape, var->dtype()};
|
||||
}
|
||||
|
||||
/**
 * Fetch one element of this Var as a raw 64-bit payload plus its dtype.
 *
 * Calls sync() first, then copies `dtype.dsize()` bytes starting at
 * `var->mem_ptr` into the returned ItemData. With HAS_CUDA the var is first
 * passed to migrate_to_cpu(); if its allocator still reports CUDA memory the
 * copy is done with cudaMemcpy (device to host), otherwise with std::memcpy.
 *
 * NOTE(review): nothing here checks that the var holds exactly one element;
 * confirm that callers guarantee it.
 */
ItemData VarHolder::item() {
    sync();
    ItemData data;
    // Only `dsize` bytes are written below. Zero the whole slot first so that
    // a consumer reading the full 64-bit value never sees indeterminate bytes.
    data.data = 0;
    data.dtype = var->dtype();
    auto dsize = data.dtype.dsize();
    #ifdef HAS_CUDA
    migrate_to_cpu(var, exe.allocator);
    if (var->allocator->is_cuda()) {
        checkCudaErrors(cudaMemcpy(&data.data, var->mem_ptr, dsize, cudaMemcpyDeviceToHost));
    } else
    #endif
    {
        std::memcpy(&data.data, var->mem_ptr, dsize);
    }
    return data;
}
|
||||
|
||||
// from fetch_op.cc
|
||||
extern list<VarPtr> fetcher;
|
||||
|
||||
|
|
|
@ -22,6 +22,11 @@ struct DataView {
|
|||
NanoString dtype;
|
||||
};
|
||||
|
||||
struct ItemData {
|
||||
int64 data;
|
||||
NanoString dtype;
|
||||
};
|
||||
|
||||
// @pyjt(Var)
|
||||
// @attrs(heaptype)
|
||||
struct VarHolder {
|
||||
|
@ -145,6 +150,10 @@ struct VarHolder {
|
|||
return {this, var->mem_ptr, var->shape, var->dtype()};
|
||||
}
|
||||
|
||||
/** Get one item data */
|
||||
// @pyjt(item)
|
||||
ItemData item();
|
||||
|
||||
// @pyjt(__get__ndim)
|
||||
inline int ndim() {
|
||||
return var->shape.size();
|
||||
|
|
Loading…
Reference in New Issue