mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor into test_init
This commit is contained in:
commit
8cc5b0cce1
|
@ -6,6 +6,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mem/allocator.h"
|
||||
#include "var.h"
|
||||
#include "cudnn_conv_backward_w_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
|
@ -195,6 +196,8 @@ void CudnnConvBackwardWOp::jit_run() {
|
|||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
|
||||
// skip this algorithm if its workspace would need more than a quarter of total CUDA memory
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mem/allocator.h"
|
||||
#include "var.h"
|
||||
#include "cudnn_conv_backward_x_op.h"
|
||||
#include "cudnn_warper.h"
|
||||
|
@ -196,6 +197,8 @@ void CudnnConvBackwardXOp::jit_run() {
|
|||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
|
||||
// skip this algorithm if its workspace would need more than a quarter of total CUDA memory
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
|
|
|
@ -199,9 +199,11 @@ void CudnnConvOp::jit_run() {
|
|||
for (int i = 0; i < num_algos; i++) {
|
||||
size_t sz;
|
||||
cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize(
|
||||
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algos[i], &sz);
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size && sz<512*1024*1024) max_ws_size = sz;
|
||||
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
|
||||
cudnnOdesc, algos[i], &sz);
|
||||
// skip this algorithm if its workspace would need more than a quarter of total CUDA memory
|
||||
if (sz*4 > mem_info.total_cuda_ram) continue;
|
||||
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
|
||||
}
|
||||
size_t allocation;
|
||||
void* ws = exe.allocator->alloc(max_ws_size, allocation);
|
||||
|
|
|
@ -105,7 +105,7 @@ template <typename T>
|
|||
void check(T result, char const *const func, const char *const file,
|
||||
int const line) {
|
||||
if (result) {
|
||||
DEVICE_RESET
|
||||
// DEVICE_RESET
|
||||
LOGf << "CUDA error at" << file >> ":" >> line << " code="
|
||||
>> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
|
||||
<< func;
|
||||
|
@ -125,7 +125,7 @@ inline void __getLastCudaError(const char *errorMessage, const char *file,
|
|||
cudaError_t err = cudaGetLastError();
|
||||
|
||||
if (cudaSuccess != err) {
|
||||
DEVICE_RESET
|
||||
// DEVICE_RESET
|
||||
LOGf << "CUDA error at" << file >> ":" >> line << " code="
|
||||
>> static_cast<unsigned int>(err) >> "(" << _cudaGetErrorEnum(err) << ")"
|
||||
<< errorMessage;
|
||||
|
@ -141,7 +141,7 @@ inline void __printLastCudaError(const char *errorMessage, const char *file,
|
|||
cudaError_t err = cudaGetLastError();
|
||||
|
||||
if (cudaSuccess != err) {
|
||||
DEVICE_RESET
|
||||
// DEVICE_RESET
|
||||
LOGf << "CUDA error at" << file >> ":" >> line << " code="
|
||||
>> static_cast<unsigned int>(err) >> "(" << _cudaGetErrorEnum(err) << ")"
|
||||
<< errorMessage;
|
||||
|
|
|
@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
|
|||
shape[0], shape[1], shape[2], shape[3]));
|
||||
}
|
||||
|
||||
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
|
||||
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
|
||||
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
|
||||
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
|
||||
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));
|
||||
|
|
|
@ -16,7 +16,7 @@ struct MklConvBackwardWOp : Op {
|
|||
int kernel_size, stride, padding, dilation;
|
||||
string xformat, wformat, yformat;
|
||||
|
||||
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
|
||||
const char* name() const override { return "mkl_conv_backward_w"; }
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
|
|||
shape[0], shape[1], shape[2], shape[3]));
|
||||
}
|
||||
|
||||
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
|
||||
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
|
||||
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
|
||||
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
|
||||
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));
|
||||
|
|
|
@ -16,7 +16,7 @@ struct MklConvBackwardXOp : Op {
|
|||
int xh, xw, stride, padding, dilation;
|
||||
string xformat, wformat, yformat;
|
||||
|
||||
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
|
||||
|
||||
const char* name() const override { return "mkl_conv_backward_x"; }
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -44,7 +44,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
|
|||
shape[0], shape[1], shape[2], shape[3]));
|
||||
}
|
||||
|
||||
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
|
||||
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
|
||||
: x(x), w(w), stride(stride), padding(padding), dilation(dilation),
|
||||
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
|
||||
y = create_output(nullptr, dtype_infer(x->ns, w->ns));
|
||||
|
|
|
@ -16,7 +16,7 @@ struct MklConvOp : Op {
|
|||
int stride, padding, dilation;
|
||||
string xformat, wformat, yformat;
|
||||
/* MklConvOp: xformat abcd represents nchw */
|
||||
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, string xformat="abcd", string wformat="oihw", string yformat="");
|
||||
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");
|
||||
|
||||
const char* name() const override { return "mkl_conv"; }
|
||||
void infer_shape() override;
|
||||
|
|
|
@ -16,8 +16,9 @@ with lock.lock_scope():
|
|||
from jittor_core import *
|
||||
from jittor_core.ops import *
|
||||
from . import compile_extern
|
||||
from .compile_extern import mkl_ops, mpi, mpi_ops, \
|
||||
cudnn, curand, cublas
|
||||
from .compile_extern import mkl_ops, mpi, mpi_ops
|
||||
if has_cuda:
|
||||
from .compile_extern import cudnn, curand, cublas
|
||||
|
||||
import contextlib
|
||||
import numpy as np
|
||||
|
@ -403,7 +404,8 @@ Var.unsqueeze = unsqueeze
|
|||
|
||||
def squeeze(x, dim):
|
||||
shape = list(x.shape)
|
||||
assert dim < len(shape)
|
||||
if dim < 0: dim += len(shape)
|
||||
assert dim < len(shape) and dim >= 0
|
||||
assert shape[dim] == 1
|
||||
return x.reshape(shape[:dim] + shape[dim+1:])
|
||||
Var.squeeze = squeeze
|
||||
|
@ -447,6 +449,13 @@ def fetch_var(var, func, *args, **kw):
|
|||
Var.fetch = fetch_var
|
||||
del fetch_var
|
||||
|
||||
def display_memory_info():
    """Print the memory report, tagged with the caller's source location.

    The location is passed to ``core.display_memory_info`` as
    ``"<file basename>:<line number>"`` of the frame that called this
    function. When the interpreter cannot provide frame objects, an empty
    string is passed instead.
    """
    import inspect, os
    # inspect.currentframe() is documented to return None on interpreters
    # without Python stack frame support, so guard before using f_back.
    frame = inspect.currentframe()
    caller = frame.f_back if frame is not None else None
    try:
        if caller is None:
            fileline = ""
        else:
            info = inspect.getframeinfo(caller)
            fileline = f"{os.path.basename(info.filename)}:{info.lineno}"
    finally:
        # Drop the frame references explicitly to avoid reference cycles.
        del frame, caller
    core.display_memory_info(fileline)
|
||||
|
||||
def import_vars(data):
|
||||
''' Load variables into current scopes
|
||||
example:
|
||||
|
|
|
@ -491,38 +491,10 @@ class ReflectionPad2d(Module):
|
|||
r = self.pl + w - 1
|
||||
t = self.pt
|
||||
b = self.pt + h - 1
|
||||
x_idx = np.zeros((oh,ow))
|
||||
y_idx = np.zeros((oh,ow))
|
||||
for j in range(oh):
|
||||
for i in range(ow):
|
||||
if i >= l and i <= r and j >= t and j <= b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = j
|
||||
elif i < l and j < t:
|
||||
x_idx[j,i] = 2 * l - i
|
||||
y_idx[j,i] = 2 * t - j
|
||||
elif i < l and j > b:
|
||||
x_idx[j,i] = 2 * l - i
|
||||
y_idx[j,i] = 2 * b - j
|
||||
elif i > r and j < t:
|
||||
x_idx[j,i] = 2 * r - i
|
||||
y_idx[j,i] = 2 * t - j
|
||||
elif i > r and j > b:
|
||||
x_idx[j,i] = 2 * r - i
|
||||
y_idx[j,i] = 2 * b - j
|
||||
elif i < l:
|
||||
x_idx[j,i] = 2 * l - i
|
||||
y_idx[j,i] = j
|
||||
elif i > r:
|
||||
x_idx[j,i] = 2 * r - i
|
||||
y_idx[j,i] = j
|
||||
elif j < t:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = 2 * t - j
|
||||
elif j > b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = 2 * b - j
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1",
|
||||
f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}",
|
||||
f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}",
|
||||
])
|
||||
|
||||
class ZeroPad2d(Module):
|
||||
def __init__(self, padding):
|
||||
|
@ -580,38 +552,10 @@ class ReplicationPad2d(Module):
|
|||
r = self.pl + w - 1
|
||||
t = self.pt
|
||||
b = self.pt + h - 1
|
||||
x_idx = np.zeros((oh,ow))
|
||||
y_idx = np.zeros((oh,ow))
|
||||
for j in range(oh):
|
||||
for i in range(ow):
|
||||
if i >= l and i <= r and j >= t and j <= b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = j
|
||||
elif i < l and j < t:
|
||||
x_idx[j,i] = l
|
||||
y_idx[j,i] = t
|
||||
elif i < l and j > b:
|
||||
x_idx[j,i] = l
|
||||
y_idx[j,i] = b
|
||||
elif i > r and j < t:
|
||||
x_idx[j,i] = r
|
||||
y_idx[j,i] = t
|
||||
elif i > r and j > b:
|
||||
x_idx[j,i] = r
|
||||
y_idx[j,i] = b
|
||||
elif i < l:
|
||||
x_idx[j,i] = l
|
||||
y_idx[j,i] = j
|
||||
elif i > r:
|
||||
x_idx[j,i] = r
|
||||
y_idx[j,i] = j
|
||||
elif j < t:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = t
|
||||
elif j > b:
|
||||
x_idx[j,i] = i
|
||||
y_idx[j,i] = b
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
|
||||
return x.reindex([n,c,oh,ow], ["i0","i1",
|
||||
f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}",
|
||||
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
|
||||
])
|
||||
|
||||
class PixelShuffle(Module):
|
||||
def __init__(self, upscale_factor):
|
||||
|
@ -638,7 +582,7 @@ class Sigmoid(Module):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
def execute(self, x) :
|
||||
return 1 / (1 + jt.exp(-x))
|
||||
return x.sigmoid()
|
||||
|
||||
def resize(x, size, mode="nearest"):
|
||||
img = x
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
class TestMem(unittest.TestCase):
    """Memory stress test: keeps allocating after device memory is exhausted."""

    def tearDown(self):
        # Release everything the test allocated before the next test runs.
        jt.clean()
        jt.gc()

    @unittest.skipIf(not jt.has_cuda, "no cuda found")
    @jt.flag_scope(use_cuda=1)
    def test_oom(self):
        """Allocate about 1.5x the total CUDA memory in 1 GiB float32 arrays."""
        jt.flags.use_cuda = 1

        gib = 1024 ** 3
        # float32 is 4 bytes per element, so gib // 4 elements is one GiB.
        one_g = np.ones((gib // 4,), "float32")

        meminfo = jt.get_mem_info()
        n = int(meminfo.total_cuda_ram // gib * 1.5)

        # Hold every (input, output) pair so nothing is freed during the loop.
        backups = []
        for _ in range(n):
            a = jt.array(one_g)
            b = a + 1
            b.sync()
            backups.append((a, b))
        # Drop all references before tearDown runs.
        backups = []
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -141,6 +141,11 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
|
|||
a = m(jt.array([1000]))
|
||||
assert np.isnan(a.data).sum()==0, a
|
||||
|
||||
def test_sigmoid_nan(self):
|
||||
a = jt.float32([1,-1, -1000.1])
|
||||
da = jt.grad(a.sigmoid(), a)
|
||||
assert np.isnan(da.data).sum()==0, da.data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -39,6 +39,8 @@ class TestPad(unittest.TestCase):
|
|||
arr = np.random.randn(16,3,224,224)
|
||||
check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
|
||||
check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5)))
|
||||
check_equal(arr, jnn.ReplicationPad2d((1,0,1,5)), tnn.ReplicationPad2d((1,0,1,5)))
|
||||
check_equal(arr, jnn.ReplicationPad2d((100)), tnn.ReplicationPad2d((100)))
|
||||
|
||||
# ***************************************************************
|
||||
# Test ConstantPad2d Layer
|
||||
|
@ -60,6 +62,8 @@ class TestPad(unittest.TestCase):
|
|||
arr = np.random.randn(16,3,224,224)
|
||||
check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
|
||||
check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1)))
|
||||
check_equal(arr, jnn.ReflectionPad2d((10,123,34,1)), tnn.ReflectionPad2d((10,123,34,1)))
|
||||
check_equal(arr, jnn.ReflectionPad2d((100)), tnn.ReflectionPad2d((100)))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -40,6 +40,7 @@ class TestUnaryOp(unittest.TestCase):
|
|||
"sin", "arcsin", "sinh", "arcsinh",
|
||||
"tan", "arctan", "tanh", "arctanh",
|
||||
"cos", "arccos", "cosh", "arccosh",
|
||||
"sigmoid",
|
||||
]
|
||||
a = [1.1, 2.2, 3.3, 4.4]
|
||||
for op in ops:
|
||||
|
@ -52,6 +53,8 @@ class TestUnaryOp(unittest.TestCase):
|
|||
else:
|
||||
b = np.array(a)
|
||||
func = lambda x: eval(f"np.{op}(x[0]).sum()")
|
||||
if op == "sigmoid":
|
||||
func = lambda x: (1/(1+np.exp(-x[0]))).sum()
|
||||
x, (da,) = ngrad(func, [b], 1e-8)
|
||||
ja = jt.array(b)
|
||||
jb = eval(f"jt.{op}(ja)")
|
||||
|
|
|
@ -1 +1 @@
|
|||
ee002b49b2fd09c70af20f5067a1667dcd07ec05
|
||||
08f4ca8b2c0a2978cd3fbc9a3a6e76bd1463ca12
|
||||
|
|
2
setup.py
2
setup.py
|
@ -21,7 +21,7 @@ with open(os.path.join(path, "README.src.md")) as fh:
|
|||
|
||||
setuptools.setup(
|
||||
name='jittor',
|
||||
version='1.0.1',
|
||||
version='1.1.1',
|
||||
# scripts=[],
|
||||
author="Jittor Group",
|
||||
author_email="ran.donglang@gmail.com",
|
||||
|
|
|
@ -380,6 +380,12 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
// var->finish_pending_liveness();
|
||||
var->finish_pending_liveness();
|
||||
} catch (const std::exception& e) {
|
||||
// log memory info
|
||||
display_memory_info(__FILELINE__);
|
||||
// log jit_key and file location
|
||||
op->do_prepare();
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
LOGe << "[Error] source file location:" << jit_src_path;
|
||||
if (is_fused_op) {
|
||||
LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"
|
||||
<< "failed:" << fused_op.ops << "\n\nReason: " >> e.what();
|
||||
|
|
|
@ -122,7 +122,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
if (var->tflag == nt)
|
||||
grad = move(grads[var->custom_data]);
|
||||
if (!grad) {
|
||||
LOGvvv << var << "grads[">>i>>"] set to zero";
|
||||
LOGw << "grads[">>i>>"] doesn't have gradient. It will be set to zero:" << var;
|
||||
grad = make_number(0.f, var);
|
||||
assign_attrs(grad.ptr, var);
|
||||
registe_node_trace_grad(grad.ptr, var, 0);
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
// ***************************************************************
|
||||
#include <typeinfo>
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
#include "mem/allocator/aligned_allocator.h"
|
||||
#ifdef HAS_CUDA
|
||||
#include "mem/allocator/cuda_managed_allocator.h"
|
||||
|
@ -84,5 +85,5 @@ Allocator* get_allocator() {
|
|||
void gc_all() {
|
||||
for (auto& kv : allocators) kv.second->gc();
|
||||
}
|
||||
|
||||
|
||||
} // jittor
|
|
@ -5,6 +5,7 @@
|
|||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
#include "mem/mem_info.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
|
|
@ -16,7 +16,14 @@ const char* CudaDeviceAllocator::name() const {return "cuda_device";}
|
|||
|
||||
void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) {
|
||||
void* ptr;
|
||||
checkCudaErrors(cudaMalloc(&ptr, size));
|
||||
try {
|
||||
checkCudaErrors(cudaMalloc(&ptr, size));
|
||||
return ptr;
|
||||
} catch (...) {}
|
||||
LOGw << "Unable to alloc cuda device memory, use unify memory instead. "
|
||||
"This may cause low performance.";
|
||||
display_memory_info(__FILELINE__);
|
||||
checkCudaErrors(cudaMallocManaged(&ptr, size));
|
||||
return ptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,5 +19,9 @@ struct StatAllocator : Allocator {
|
|||
};
|
||||
|
||||
DECLARE_FLAG(int, use_stat_allocator);
|
||||
DECLARE_FLAG(size_t, stat_allocator_total_alloc_call);
|
||||
DECLARE_FLAG(size_t, stat_allocator_total_alloc_byte);
|
||||
DECLARE_FLAG(size_t, stat_allocator_total_free_call);
|
||||
DECLARE_FLAG(size_t, stat_allocator_total_free_byte);
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,110 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <sys/sysinfo.h>
|
||||
|
||||
#include "var.h"
|
||||
#include "op.h"
|
||||
#include "var_holder.h"
|
||||
#include "graph.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "mem/allocator/sfrl_allocator.h"
|
||||
#include "mem/allocator/stat_allocator.h"
|
||||
#include "mem/mem_info.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
/// A quantity that streams as a scaled, human-readable number, e.g.
/// FloatOutput{1536, " KMG", 1024, "B"} prints as "  1.5KB".
struct FloatOutput {
    /// Raw quantity to print.
    double value;
    /// One prefix character per power of `base`; index 0 is the unscaled prefix.
    std::string scale;
    /// Divisor between two consecutive prefixes (1024 for bytes, 1000 for counts).
    int base;
    /// Text appended after the prefix character.
    std::string suffix;
    /// Number of significant digits used for the scaled value.
    int p=4;
};

/// Writes `o.value` divided by `o.base` until it is below `o.base` (or the
/// last prefix in `o.scale` is reached), right-aligned so that number, prefix
/// and suffix fill 7 columns whenever they fit, followed by the prefix
/// character and the suffix.
/// The stream's precision is restored before returning.
std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
    const int total_width = 8;
    // Compute the field width in signed arithmetic and clamp it at zero:
    // subtracting suffix.size() as an unsigned value would wrap around for
    // suffixes longer than six characters.
    int number_width = total_width - 2 - static_cast<int>(o.suffix.size());
    if (number_width < 0) number_width = 0;

    // Pick the largest prefix that keeps the printed number below `base`.
    double scaled = o.value;
    std::size_t idx = 0;
    for (; idx+1 < o.scale.size(); idx++) {
        if (scaled < o.base) break;
        scaled /= o.base;
    }

    // precision() is sticky on the stream, so put the caller's value back.
    const std::streamsize saved_precision = os.precision(o.p);
    os << std::setw(number_width) << scaled;
    os.precision(saved_precision);

    // An empty scale has no prefix character to print.
    if (idx < o.scale.size()) os << o.scale[idx];
    return os << o.suffix;
}
|
||||
|
||||
void display_memory_info(const char* fileline) {
|
||||
int p = 3;
|
||||
Log log(fileline, 'i', 0);
|
||||
log << "\n=== display_memory_info ===\n";
|
||||
log << "total_cpu_ram:" <<
|
||||
FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"};
|
||||
log << "total_cuda_ram:" <<
|
||||
FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n";
|
||||
log << "hold_vars:" << VarHolder::hold_vars.size()
|
||||
<< "lived_vars:" << Var::number_of_lived_vars
|
||||
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
|
||||
|
||||
#ifdef NODE_MEMCHECK
|
||||
// get the oldest var
|
||||
vector<Node*> queue;
|
||||
auto t = ++Node::tflag_count;
|
||||
for (auto& vh : VarHolder::hold_vars)
|
||||
if (vh->var->tflag != t) {
|
||||
vh->var->tflag = t;
|
||||
queue.push_back(vh->var);
|
||||
}
|
||||
bfs_both(queue, [](Node*){return true;});
|
||||
vector<pair<int64, Node*>> nodes;
|
||||
nodes.reserve(queue.size());
|
||||
for (auto* node : queue)
|
||||
nodes.push_back({node->__id(), node});
|
||||
std::sort(nodes.begin(), nodes.end());
|
||||
log << "list of the oldest nodes:\n";
|
||||
for (int i=0; i<10 && i<nodes.size(); i++) {
|
||||
log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
|
||||
}
|
||||
#endif
|
||||
|
||||
if (use_stat_allocator) {
|
||||
log << "stat:" << use_stat_allocator;
|
||||
log << "total alloc:" << FloatOutput{(double)(stat_allocator_total_alloc_byte
|
||||
- stat_allocator_total_free_byte), " KMG", 1024, "B"};
|
||||
log << "total alloc call:" << FloatOutput{(double)(stat_allocator_total_alloc_call
|
||||
- stat_allocator_total_free_call), " KMG", 1000, ""} >> '\n';
|
||||
}
|
||||
for (auto& a : SFRLAllocator::sfrl_allocators) {
|
||||
auto total = a->used_memory + a->unused_memory;
|
||||
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
|
||||
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
|
||||
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
|
||||
<< "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"}
|
||||
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
|
||||
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
|
||||
}
|
||||
log >> "===========================\n";
|
||||
log.end();
|
||||
}
|
||||
|
||||
/// Fills in the total amount of host RAM and, when built with HAS_CUDA, the
/// total global memory of CUDA device 0. A value stays 0 when the
/// corresponding query fails.
MemInfo::MemInfo() {
    total_cpu_ram = 0;
    total_cuda_ram = 0;

    struct sysinfo info = {0};
    if (sysinfo(&info) == 0) {
        // totalram is expressed in units of mem_unit bytes, so scale it to
        // get bytes. Treat a zero mem_unit as 1 byte.
        int64 unit = info.mem_unit ? (int64)info.mem_unit : 1;
        total_cpu_ram = (int64)info.totalram * unit;
    }

    #ifdef HAS_CUDA
    cudaDeviceProp prop = {0};
    // Only trust the properties when the query succeeded.
    if (cudaGetDeviceProperties(&prop, 0) == cudaSuccess)
        total_cuda_ram = prop.totalGlobalMem;
    #endif
}
|
||||
|
||||
MemInfo mem_info;
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,31 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// @pyjt(display_memory_info)
|
||||
void display_memory_info(const char* fileline="");
|
||||
|
||||
// @pyjt(MemInfo)
|
||||
struct MemInfo {
|
||||
// @pyjt(total_cpu_ram)
|
||||
int64 total_cpu_ram;
|
||||
// @pyjt(total_cuda_ram)
|
||||
int64 total_cuda_ram;
|
||||
|
||||
inline MemInfo(const MemInfo&) = default;
|
||||
|
||||
MemInfo();
|
||||
};
|
||||
|
||||
extern MemInfo mem_info;
|
||||
|
||||
// @pyjt(get_mem_info)
|
||||
inline MemInfo get_mem_info() { return mem_info; }
|
||||
|
||||
} // jittor
|
|
@ -73,6 +73,7 @@ static unordered_set<string> unary_ops = {
|
|||
"acos",
|
||||
"cosh",
|
||||
"acosh",
|
||||
"sigmoid",
|
||||
};
|
||||
|
||||
static unordered_set<string> unary_float_ops = {
|
||||
|
|
|
@ -76,6 +76,7 @@ namespace jittor {
|
|||
m(acos) \
|
||||
m(cosh) \
|
||||
m(acosh) \
|
||||
m(sigmoid) \
|
||||
|
||||
struct NanoString;
|
||||
#define DECLEAR_NS(T) extern NanoString ns_##T;
|
||||
|
|
|
@ -108,7 +108,12 @@ struct NanoVector {
|
|||
|
||||
// @pyjt(__map_getitem__)
|
||||
inline NanoVector slice(Slice slice) {
|
||||
if (slice.mask&2) slice.stop = size();
|
||||
if (slice.step>0) {
|
||||
if (slice.mask&2) slice.stop = size();
|
||||
} else {
|
||||
if (slice.mask&1) slice.start = size()-1;
|
||||
if (slice.mask&2) slice.stop = 0;
|
||||
}
|
||||
if (slice.start<0) slice.start += size();
|
||||
if (slice.stop<0) slice.stop += size();
|
||||
ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<=size())
|
||||
|
|
11
src/node.h
11
src/node.h
|
@ -12,6 +12,7 @@
|
|||
namespace jittor {
|
||||
|
||||
extern unordered_map<void*, int64> lived_nodes;
|
||||
extern int64 total_node;
|
||||
|
||||
struct NodeFlags {
|
||||
typedef uint16 nf_t;
|
||||
|
@ -107,7 +108,7 @@ struct Node {
|
|||
|
||||
#ifdef NODE_MEMCHECK
|
||||
inline Node() {
|
||||
lived_nodes[(void*)this] = lived_nodes.size()+1;
|
||||
lived_nodes[(void*)this] = ++total_node;
|
||||
registe_node_trace(this);
|
||||
}
|
||||
|
||||
|
@ -132,7 +133,13 @@ struct Node {
|
|||
#endif
|
||||
}
|
||||
void memcheck_all_exist() const;
|
||||
int64 __id() const;
|
||||
inline int64 __id() const {
|
||||
#ifdef NODE_MEMCHECK
|
||||
return lived_nodes.at((void*)this);
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
// release from counter and memory checker
|
||||
void __release();
|
||||
#define CHECK_NODE_EXIST(node) \
|
||||
|
|
|
@ -758,7 +758,7 @@ string OpCompiler::__get_fused_src(
|
|||
"for", "const", "auto", "get_random_engine",
|
||||
"int", "float", "bool", "CHECK", "STRINGIZE",
|
||||
"void", "__restrict__", "if", "true", "false",
|
||||
"Op", "Var", "Node", "itof"
|
||||
"Op", "Var", "Node", "itof", "assert", "ASSERT"
|
||||
};
|
||||
auto not_change = [&](const string& s) -> bool {
|
||||
if (unchanged.count(s)) return true;
|
||||
|
@ -914,7 +914,9 @@ string OpCompiler::__get_fused_src(
|
|||
fused_kernel = fused_kernel_args + "\n" + fused_kernel;
|
||||
LOGvvvv << "Fused kernel:\n" >> fused_kernel;
|
||||
|
||||
auto fused_src = fused_begin + fused_includes + "\n#include \"fused_op.h\"\n" +
|
||||
auto fused_src = fused_begin + fused_includes +
|
||||
"\n#include <assert.h>\n" +
|
||||
"\n#include \"fused_op.h\"\n" +
|
||||
fused_defines + '\n' +
|
||||
"void jittor::FusedOp::jit_run() {\n" + fused_kernel + "\n}\n";
|
||||
|
||||
|
|
|
@ -76,7 +76,7 @@ struct CodeOp : Op {
|
|||
#include <algorithm>
|
||||
@alias(a, in0)
|
||||
@alias(b, out)
|
||||
"""",
|
||||
""",
|
||||
cpu_src="""
|
||||
for (int i=0; i<a_shape0; i++)
|
||||
@b(i) = @a(i);
|
||||
|
|
|
@ -22,7 +22,7 @@ struct OpInfo {
|
|||
for (uint i=0; i<constructors.size(); i++)
|
||||
if (std::type_index(*(constructors[i].first)) == std::type_index(tid))
|
||||
return func_t(constructors[i].second);
|
||||
LOGf << "constructor" << tid.name() << "not found.";
|
||||
LOGf << "constructor" << name << tid.name() << "not found.";
|
||||
return func_t(nullptr);
|
||||
}
|
||||
};
|
||||
|
|
|
@ -67,6 +67,7 @@ static unordered_set<string> unary_ops = {
|
|||
"cosh",
|
||||
// @pybind(acosh, arccosh)
|
||||
"acosh",
|
||||
"sigmoid",
|
||||
};
|
||||
|
||||
UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
|
||||
|
@ -183,6 +184,12 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
x2 = make_binary(one, x2, ns_subtract);
|
||||
return make_binary(dout, x2, ns_divide);
|
||||
}
|
||||
// dsigmoid(x) = sigmoid(x) - sigmoid(x)^2
|
||||
if (ns == ns_sigmoid) {
|
||||
auto r = make_binary(out, out, ns_multiply);
|
||||
r = make_binary(out, r, ns_subtract);
|
||||
return make_binary(dout, r, ns_multiply);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -35,6 +35,8 @@ namespace jittor {
|
|||
#define tanh(T,x) ((T) ::tanhf((x)))
|
||||
#define atanh(T,x) ((T) ::atanhf((x)))
|
||||
|
||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf(-(x)))))
|
||||
|
||||
#else
|
||||
#define abs(T,x) std::abs(x)
|
||||
#define log(T,x) std::log((T)(x))
|
||||
|
@ -59,6 +61,8 @@ namespace jittor {
|
|||
#define tanh(T,x) ((T) std::tanh((x)))
|
||||
#define atanh(T,x) ((T) std::atanh((x)))
|
||||
|
||||
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(-(x)))))
|
||||
|
||||
#endif
|
||||
|
||||
#define cast(T,x) ((T)(x))
|
||||
|
|
|
@ -170,6 +170,10 @@ void LoopVarAnalyzePass::run() {
|
|||
for (uint j=0; j<ndim; j++)
|
||||
if (!(mask>>j&1) && j<loop_var_names.size()) {
|
||||
for (auto& vname : vnames) {
|
||||
// cannot replace extras shape
|
||||
// TODO: optimize it
|
||||
if (vname.find("extras") != string::npos)
|
||||
continue;
|
||||
// replace op{i}_{vname}shape{j} -> {loop_var_names[j]}
|
||||
std::stringstream name1;
|
||||
name1 << vname<<"shape"<<j;
|
||||
|
@ -193,7 +197,7 @@ void LoopVarAnalyzePass::run() {
|
|||
replace_vars.emplace_back(name1, name2);
|
||||
}
|
||||
|
||||
LOGvvvv << "replace_vars" << replace_vars;
|
||||
LOGvvv << "replace_vars" << replace_vars;
|
||||
ir->replace(replace_vars);
|
||||
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
|
||||
// move define
|
||||
|
|
|
@ -160,6 +160,24 @@ struct OpInspector {
|
|||
return 0;
|
||||
}
|
||||
|
||||
// Scatter the characters of `fmt` into a new format string:
// character i of `fmt` is written to position order[i] of the result.
// Returns "" when `order` and `fmt` differ in length, when
// check_overlap(order) reports a problem, or when an entry of `order` is
// outside [0, fmt.size()). `failed` is set to 1 for the length and range
// errors.
// NOTE(review): check_overlap presumably sets `failed` itself -- confirm.
string format2(const string& fmt, const vector<int>& order) {
    if (order.size() != fmt.size()) {
        failed = 1;
        return "";
    }
    if (check_overlap(order))
        return "";
    string new_fmt = fmt;
    for (size_t i=0; i<order.size(); i++) {
        const int pos = order[i];
        // Reject negative positions as well as positions past the end;
        // a negative index would write before the start of the string.
        if (pos < 0 || pos >= (int)new_fmt.size()) {
            failed = 1;
            return "";
        }
        new_fmt[pos] = fmt[i];
    }
    return new_fmt;
}
|
||||
|
||||
string format(const string& fmt, const vector<int>& order) {
|
||||
string new_fmt = fmt;
|
||||
if (order.size() != fmt.size()) {
|
||||
|
@ -234,19 +252,19 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
auto xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
auto xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
auto wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
auto wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
auto yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
auto yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
|
@ -307,7 +325,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
|
|||
continue;
|
||||
auto make_conv = get_op_info(relay_conv_name)
|
||||
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
|
||||
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
|
||||
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
|
||||
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}});
|
||||
if (rid>=0) {
|
||||
auto srid = "relay"+S(rid);
|
||||
|
@ -359,19 +377,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
|
@ -434,19 +452,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
|
|||
xh = xoi.mm[zh];
|
||||
xw = xoi.mm[zw];
|
||||
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
|
||||
xformat = xoi.format("abcd", {xn, xc, xh, xw});
|
||||
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
|
||||
LOGvvvv << "xformat =" << xformat;
|
||||
wci = woi.mm[zci];
|
||||
wco = woi.mm[zco];
|
||||
wh = woi.mm[zwh];
|
||||
ww = woi.mm[zww];
|
||||
wformat = xoi.format("iohw", {wci, wco, wh, ww});
|
||||
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
|
||||
LOGvvvv << "wformat =" << wformat;
|
||||
yn = yoi.mm[zn];
|
||||
yc = yoi.mm[zco];
|
||||
yh = yoi.mm[zh];
|
||||
yw = yoi.mm[zw];
|
||||
yformat = xoi.format("abcd", {yn, yc, yh, yw});
|
||||
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
|
||||
LOGvvvv << "yformat =" << yformat;
|
||||
// mkl doesn't support "cdab" format
|
||||
if (yformat == "cdab") continue;
|
||||
|
|
|
@ -15,23 +15,33 @@ using namespace pybind11::literals;
|
|||
|
||||
namespace jittor {
|
||||
|
||||
// Integer flag `trace_py_var` (default 0) declared through the project's
// DEFINE_FLAG macro. Two declarations with different help strings are visible.
// NOTE(review): this looks like a rendered diff showing the removed and the
// added line together - confirm that only one exists in the real file.
DEFINE_FLAG(int, trace_py_var, 0, "Trace py stack for debug.");
|
||||
DEFINE_FLAG(int, trace_py_var, 0, "Trace py stack max depth for debug.");
|
||||
|
||||
// Per-node trace text, keyed by node pointer; filled in by
// __registe_node_trace below.
unordered_map<const Node*, string> trace_data;
|
||||
|
||||
// Captures Python stack information through the `traceback` module and stores
// the resulting text in trace_data[node].
//
// NOTE(review): this span appears to be a rendered diff in which two versions
// of the body are interleaved: the statements that build `basename` from a
// single `py_stack` frame look like the removed version, and the statements
// that build `info` from the `py_stacks` list look like the added version.
// Confirm against the real source file before relying on statement order.
void __registe_node_trace(Node* node) {
|
||||
auto py_stack =
|
||||
auto py_stacks =
|
||||
py::module::import("traceback")
|
||||
// Single-frame variant: asks extract_stack for one entry and takes element 0.
.attr("extract_stack")(nullptr, 1).attr("__getitem__")(0);
|
||||
auto filename = py_stack.attr("filename").cast<string>();
|
||||
// Keep only the last path component of the file name.
auto basename = split(filename, "/").back();
|
||||
basename += ':';
|
||||
basename += py_stack.attr("name").cast<string>();
|
||||
basename += ':';
|
||||
basename += S(py_stack.attr("lineno").cast<int>());
|
||||
basename += ':';
|
||||
basename += py_stack.attr("line").cast<string>();
|
||||
trace_data[node] = basename;
|
||||
// Multi-frame variant: the trace_py_var flag is passed as the second argument
// of extract_stack (the flag's help text calls it the max stack depth).
.attr("extract_stack")(nullptr, trace_py_var);
|
||||
auto len = py_stacks.attr("__len__")().cast<int>();
|
||||
string info;
|
||||
// Append one "file:name:lineno:line" record per returned frame.
for (int i=0; i<len; i++) {
|
||||
auto py_stack = py_stacks.attr("__getitem__")(i);
|
||||
auto filename = py_stack.attr("filename").cast<string>();
|
||||
// With exactly one frame only the last path component is written; otherwise
// each frame starts on a new line and keeps the full file name.
if (len==1)
|
||||
info += split(filename, "/").back();
|
||||
else {
|
||||
info += "\n ";
|
||||
info += filename;
|
||||
}
|
||||
info += ':';
|
||||
info += py_stack.attr("name").cast<string>();
|
||||
info += ':';
|
||||
info += S(py_stack.attr("lineno").cast<int>());
|
||||
info += ':';
|
||||
info += py_stack.attr("line").cast<string>();
|
||||
}
|
||||
trace_data[node] = info;
|
||||
}
|
||||
|
||||
void __unregiste_node_trace(Node* node) {
|
||||
|
|
|
@ -158,6 +158,25 @@ DEF_IS(DumpGraphs, const T&) from_py_object(PyObject* obj) {
|
|||
return GET_RAW_PTR(T, obj);
|
||||
}
|
||||
|
||||
// MemInfo
// Python binding helpers for MemInfo: a forward declaration of the struct,
// the externally defined Python type object, and the type predicate.
struct MemInfo;
|
||||
extern PyTypeObject PyjtMemInfo;
|
||||
// Returns true when `obj`'s Python type is exactly PyjtMemInfo (pointer
// comparison on Py_TYPE, so objects of any other type yield false).
DEF_IS(MemInfo, bool) is_type(PyObject* obj) {
|
||||
return Py_TYPE(obj) == &PyjtMemInfo;
|
||||
}
|
||||
|
||||
|
||||
// Wraps a copy of `a` in a newly created Python object of type PyjtMemInfo
// and returns that object.
DEF_IS(MemInfo, PyObject*) to_py_object(const T& a) {
|
||||
// The new object is held in a PyObjHolder until it is released at the end.
// NOTE(review): presumably the holder handles a failed allocation - confirm.
PyObjHolder obj(_PyObject_New(&PyjtMemInfo));
|
||||
auto ptr = GET_RAW_PTR(T, obj.obj);
|
||||
// Placement-new: copy-construct `a` into the storage GET_RAW_PTR points at.
new (ptr) T(a);
|
||||
return obj.release();
|
||||
}
|
||||
|
||||
// Returns the result of GET_RAW_PTR(T, obj) as a const reference to T.
// No type check is done here.
// NOTE(review): callers presumably check is_type first - verify.
DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) {
|
||||
return GET_RAW_PTR(T, obj);
|
||||
}
|
||||
|
||||
|
||||
// NanoString
|
||||
struct NanoString;
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <sstream>
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
|
@ -122,4 +123,10 @@ vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh) {
|
|||
return ret;
|
||||
}
|
||||
|
||||
// Returns the text produced by streaming the `var` member into a
// std::stringstream with operator<<.
string VarHolder::debug_msg() {
|
||||
std::stringstream ss;
|
||||
ss << var;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -154,6 +154,9 @@ struct VarHolder {
|
|||
#endif
|
||||
std::memcpy(var->mem_ptr, array.ptr, size);
|
||||
}
|
||||
|
||||
// Returns a string describing the held var; the definition streams `var`
// into a stringstream.
// @pyjt(debug_msg)
|
||||
string debug_msg();
|
||||
};
|
||||
|
||||
// @pyjt(sync)
|
||||
|
|
Loading…
Reference in New Issue