Merge branch 'master' of https://github.com/Jittor/jittor into test_init

guowei yang 2020-05-11 15:56:47 +08:00
commit 8cc5b0cce1
41 changed files with 373 additions and 115 deletions

View File

@ -6,6 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv_backward_w_op.h"
#include "cudnn_warper.h"
@ -195,6 +196,8 @@ void CudnnConvBackwardWOp::jit_run() {
for (int i = 0; i < num_algos; i++) {
size_t sz;
cudnnStatus_t ret = cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_, cudnnIdesc, cudnnOdesc, cudnnConvDesc, cudnnFdesc, algos[i], &sz);
// skip this algorithm if it would use too much workspace
if (sz*4 > mem_info.total_cuda_ram) continue;
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
}
size_t allocation;

View File

@ -6,6 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mem/allocator.h"
#include "var.h"
#include "cudnn_conv_backward_x_op.h"
#include "cudnn_warper.h"
@ -196,6 +197,8 @@ void CudnnConvBackwardXOp::jit_run() {
for (int i = 0; i < num_algos; i++) {
size_t sz;
cudnnStatus_t ret = cudnnGetConvolutionBackwardDataWorkspaceSize(handle_, cudnnFdesc, cudnnOdesc, cudnnConvDesc, cudnnIdesc, algos[i], &sz);
// skip this algorithm if it would use too much workspace
if (sz*4 > mem_info.total_cuda_ram) continue;
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
}
size_t allocation;

View File

@ -199,9 +199,11 @@ void CudnnConvOp::jit_run() {
for (int i = 0; i < num_algos; i++) {
size_t sz;
cudnnStatus_t ret = cudnnGetConvolutionForwardWorkspaceSize(
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
cudnnOdesc, algos[i], &sz);
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size && sz<512*1024*1024) max_ws_size = sz;
handle_, cudnnIdesc, cudnnFdesc, cudnnConvDesc,
cudnnOdesc, algos[i], &sz);
// skip this algorithm if it would use too much workspace
if (sz*4 > mem_info.total_cuda_ram) continue;
if (CUDNN_STATUS_SUCCESS == ret && sz > max_ws_size) max_ws_size = sz;
}
size_t allocation;
void* ws = exe.allocator->alloc(max_ws_size, allocation);
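A minimal Python sketch (illustrative only, not part of the commit) of the workspace cap applied in the three cuDNN convolution hunks above: an algorithm is skipped if its workspace request exceeds a quarter of total GPU RAM, and the largest surviving request becomes max_ws_size.

    # Hypothetical (algo, workspace_bytes) pairs, as cudnnGet*WorkspaceSize would report them.
    total_cuda_ram = 8 * 1024**3                      # assume an 8 GB device
    candidates = [("algo0", 256 * 1024**2),
                  ("algo1", 3 * 1024**3),
                  ("algo2", 64 * 1024**2)]

    max_ws_size = 0
    for name, sz in candidates:
        if sz * 4 > total_cuda_ram:                   # skip algorithms needing more than 1/4 of GPU RAM
            continue
        max_ws_size = max(max_ws_size, sz)

    print(max_ws_size)                                # 268435456 -- the 3 GB request was filtered out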

View File

@ -105,7 +105,7 @@ template <typename T>
void check(T result, char const *const func, const char *const file,
int const line) {
if (result) {
DEVICE_RESET
// DEVICE_RESET
LOGf << "CUDA error at" << file >> ":" >> line << " code="
>> static_cast<unsigned int>(result) >> "(" << _cudaGetErrorEnum(result) << ")"
<< func;
@ -125,7 +125,7 @@ inline void __getLastCudaError(const char *errorMessage, const char *file,
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
DEVICE_RESET
// DEVICE_RESET
LOGf << "CUDA error at" << file >> ":" >> line << " code="
>> static_cast<unsigned int>(err) >> "(" << _cudaGetErrorEnum(err) << ")"
<< errorMessage;
@ -141,7 +141,7 @@ inline void __printLastCudaError(const char *errorMessage, const char *file,
cudaError_t err = cudaGetLastError();
if (cudaSuccess != err) {
DEVICE_RESET
// DEVICE_RESET
LOGf << "CUDA error at" << file >> ":" >> line << " code="
>> static_cast<unsigned int>(err) >> "(" << _cudaGetErrorEnum(err) << ")"
<< errorMessage;

View File

@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
MklConvBackwardWOp::MklConvBackwardWOp(Var* x, Var* dy, int kernel_size, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), dy(dy), kernel_size(kernel_size), stride(stride), padding(padding), dilation(dilation),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
dw = create_output(nullptr, dtype_infer(dy->ns, x->ns));

View File

@ -16,7 +16,7 @@ struct MklConvBackwardWOp : Op {
int kernel_size, stride, padding, dilation;
string xformat, wformat, yformat;
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
MklConvBackwardWOp(Var* x, Var* y, int kernel_size, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "mkl_conv_backward_w"; }
void infer_shape() override;

View File

@ -45,7 +45,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
MklConvBackwardXOp::MklConvBackwardXOp(Var* w, Var* dy, int height, int width, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: w(w), dy(dy), xh(height), xw(width), stride(stride), padding(padding), dilation(dilation),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
dx = create_output(nullptr, dtype_infer(dy->ns, w->ns));

View File

@ -16,7 +16,7 @@ struct MklConvBackwardXOp : Op {
int xh, xw, stride, padding, dilation;
string xformat, wformat, yformat;
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, string xformat="abcd", string wformat="oihw", string yformat="abcd");
MklConvBackwardXOp(Var* w, Var* y, int height, int width, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="abcd");
const char* name() const override { return "mkl_conv_backward_x"; }
void infer_shape() override;

View File

@ -44,7 +44,7 @@ static inline void set_shape(Var* x, const char* f, const string& format, int a,
shape[0], shape[1], shape[2], shape[3]));
}
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, string xformat, string wformat, string yformat)
MklConvOp::MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat, string wformat, string yformat)
: x(x), w(w), stride(stride), padding(padding), dilation(dilation),
xformat(move(xformat)), wformat(move(wformat)), yformat(move(yformat)) {
y = create_output(nullptr, dtype_infer(x->ns, w->ns));

View File

@ -16,7 +16,7 @@ struct MklConvOp : Op {
int stride, padding, dilation;
string xformat, wformat, yformat;
/* MklConvOp: xformat abcd represents nchw */
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation=1, string xformat="abcd", string wformat="oihw", string yformat="");
MklConvOp(Var* x, Var* w, int stride, int padding, int dilation, int groups, string xformat="abcd", string wformat="oihw", string yformat="");
const char* name() const override { return "mkl_conv"; }
void infer_shape() override;

View File

@ -16,8 +16,9 @@ with lock.lock_scope():
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops, mpi, mpi_ops, \
cudnn, curand, cublas
from .compile_extern import mkl_ops, mpi, mpi_ops
if has_cuda:
from .compile_extern import cudnn, curand, cublas
import contextlib
import numpy as np
@ -403,7 +404,8 @@ Var.unsqueeze = unsqueeze
def squeeze(x, dim):
shape = list(x.shape)
assert dim < len(shape)
if dim < 0: dim += len(shape)
assert dim < len(shape) and dim >= 0
assert shape[dim] == 1
return x.reshape(shape[:dim] + shape[dim+1:])
Var.squeeze = squeeze
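A small usage sketch of the squeeze change above: negative dimensions are now normalized before the size-1 check (illustrative, assumes a plain Jittor install).

    import numpy as np
    import jittor as jt

    a = jt.array(np.ones((3, 1, 4), dtype="float32"))
    b = a.squeeze(-2)        # -2 is normalized to dim 1, which has size 1 and is removed
    print(b.shape)           # [3, 4]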
@ -447,6 +449,13 @@ def fetch_var(var, func, *args, **kw):
Var.fetch = fetch_var
del fetch_var
def display_memory_info():
import inspect, os
f = inspect.currentframe()
fileline = inspect.getframeinfo(f.f_back)
fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
core.display_memory_info(fileline)
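The Python wrapper above only prepends the caller's file:line before delegating to core.display_memory_info; a typical call site looks like this (sketch):

    import jittor as jt

    # Prints total CPU/CUDA RAM, live var/op counts and per-allocator usage,
    # tagged with the file and line of this call.
    jt.display_memory_info()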
def import_vars(data):
''' Load variables into current scopes
example:

View File

@ -491,38 +491,10 @@ class ReflectionPad2d(Module):
r = self.pl + w - 1
t = self.pt
b = self.pt + h - 1
x_idx = np.zeros((oh,ow))
y_idx = np.zeros((oh,ow))
for j in range(oh):
for i in range(ow):
if i >= l and i <= r and j >= t and j <= b:
x_idx[j,i] = i
y_idx[j,i] = j
elif i < l and j < t:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = 2 * t - j
elif i < l and j > b:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = 2 * b - j
elif i > r and j < t:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = 2 * t - j
elif i > r and j > b:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = 2 * b - j
elif i < l:
x_idx[j,i] = 2 * l - i
y_idx[j,i] = j
elif i > r:
x_idx[j,i] = 2 * r - i
y_idx[j,i] = j
elif j < t:
x_idx[j,i] = i
y_idx[j,i] = 2 * t - j
elif j > b:
x_idx[j,i] = i
y_idx[j,i] = 2 * b - j
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
return x.reindex([n,c,oh,ow], ["i0","i1",
f"i2<{t} ? {t}-i2 : i2 > {b} ? {h-1+b}-i2 : i2-{t}",
f"i3<{l} ? {l}-i3 : i3 > {r} ? {w-1+r}-i3 : i3-{l}",
])
class ZeroPad2d(Module):
def __init__(self, padding):
@ -580,38 +552,10 @@ class ReplicationPad2d(Module):
r = self.pl + w - 1
t = self.pt
b = self.pt + h - 1
x_idx = np.zeros((oh,ow))
y_idx = np.zeros((oh,ow))
for j in range(oh):
for i in range(ow):
if i >= l and i <= r and j >= t and j <= b:
x_idx[j,i] = i
y_idx[j,i] = j
elif i < l and j < t:
x_idx[j,i] = l
y_idx[j,i] = t
elif i < l and j > b:
x_idx[j,i] = l
y_idx[j,i] = b
elif i > r and j < t:
x_idx[j,i] = r
y_idx[j,i] = t
elif i > r and j > b:
x_idx[j,i] = r
y_idx[j,i] = b
elif i < l:
x_idx[j,i] = l
y_idx[j,i] = j
elif i > r:
x_idx[j,i] = r
y_idx[j,i] = j
elif j < t:
x_idx[j,i] = i
y_idx[j,i] = t
elif j > b:
x_idx[j,i] = i
y_idx[j,i] = b
return x.reindex([n,c,oh,ow], ["i0","i1","@e1(i2,i3)","@e0(i2,i3)"], extras=[jt.array(x_idx - self.pl), jt.array(y_idx - self.pt)])
return x.reindex([n,c,oh,ow], ["i0","i1",
f"i2<{t} ? 0 : i2 > {b} ? {h-1} : i2-{t}",
f"i3<{l} ? 0 : i3 > {r} ? {w-1} : i3-{l}"
])
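A numpy check (illustrative only, not part of the commit) that the two closed-form index expressions above reproduce reflection and replication padding; t, b, l, r and the helper functions mirror the f-strings passed to reindex:

    import numpy as np

    h, w, pt, pb, pl, pr = 4, 5, 2, 3, 1, 2
    x = np.arange(h * w, dtype=np.float32).reshape(h, w)
    t, b = pt, pt + h - 1                  # valid row range in the padded output
    l, r = pl, pl + w - 1                  # valid column range in the padded output
    oh, ow = h + pt + pb, w + pl + pr

    def reflect_row(i2):                   # i2<t ? t-i2 : i2>b ? (h-1+b)-i2 : i2-t
        return t - i2 if i2 < t else (h - 1 + b) - i2 if i2 > b else i2 - t

    def reflect_col(i3):                   # i3<l ? l-i3 : i3>r ? (w-1+r)-i3 : i3-l
        return l - i3 if i3 < l else (w - 1 + r) - i3 if i3 > r else i3 - l

    def replicate_row(i2):                 # i2<t ? 0 : i2>b ? h-1 : i2-t
        return 0 if i2 < t else h - 1 if i2 > b else i2 - t

    def replicate_col(i3):                 # i3<l ? 0 : i3>r ? w-1 : i3-l
        return 0 if i3 < l else w - 1 if i3 > r else i3 - l

    refl = np.array([[x[reflect_row(j), reflect_col(i)] for i in range(ow)] for j in range(oh)])
    repl = np.array([[x[replicate_row(j), replicate_col(i)] for i in range(ow)] for j in range(oh)])
    assert np.allclose(refl, np.pad(x, ((pt, pb), (pl, pr)), mode="reflect"))
    assert np.allclose(repl, np.pad(x, ((pt, pb), (pl, pr)), mode="edge"))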
class PixelShuffle(Module):
def __init__(self, upscale_factor):
@ -638,7 +582,7 @@ class Sigmoid(Module):
def __init__(self):
super().__init__()
def execute(self, x) :
return 1 / (1 + jt.exp(-x))
return x.sigmoid()
def resize(x, size, mode="nearest"):
img = x

View File

@ -0,0 +1,37 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
class TestMem(unittest.TestCase):
def tearDown(self):
jt.clean()
jt.gc()
@unittest.skipIf(not jt.has_cuda, "no cuda found")
@jt.flag_scope(use_cuda=1)
def test_oom(self):
backups = []
jt.flags.use_cuda = 1
one_g = np.ones((1024*1024*1024//4,), "float32")
meminfo = jt.get_mem_info()
n = int(meminfo.total_cuda_ram // (1024**3) * 1.5)
for i in range(n):
a = jt.array(one_g)
b = a + 1
b.sync()
backups.append((a,b))
backups = []
if __name__ == "__main__":
unittest.main()
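The test above oversubscribes the device on purpose: get_mem_info (added later in this commit) reports total CUDA RAM, and roughly 1.5x that amount is allocated in 1 GB chunks so the allocator has to fall back to unified memory. A sketch of the binding's use:

    import jittor as jt

    meminfo = jt.get_mem_info()            # MemInfo with total_cpu_ram / total_cuda_ram in bytes
    gb = meminfo.total_cuda_ram / 1024**3
    n = int(gb * 1.5)                      # number of 1 GB arrays needed to exceed device memory
    print(gb, n)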

View File

@ -141,6 +141,11 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
a = m(jt.array([1000]))
assert np.isnan(a.data).sum()==0, a
def test_sigmoid_nan(self):
a = jt.float32([1,-1, -1000.1])
da = jt.grad(a.sigmoid(), a)
assert np.isnan(da.data).sum()==0, da.data
if __name__ == "__main__":
unittest.main()

View File

@ -39,6 +39,8 @@ class TestPad(unittest.TestCase):
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
check_equal(arr, jnn.ReplicationPad2d((1,23,4,5)), tnn.ReplicationPad2d((1,23,4,5)))
check_equal(arr, jnn.ReplicationPad2d((1,0,1,5)), tnn.ReplicationPad2d((1,0,1,5)))
check_equal(arr, jnn.ReplicationPad2d((100)), tnn.ReplicationPad2d((100)))
# ***************************************************************
# Test ConstantPad2d Layer
@ -60,6 +62,8 @@ class TestPad(unittest.TestCase):
arr = np.random.randn(16,3,224,224)
check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
check_equal(arr, jnn.ReflectionPad2d((2,3,34,1)), tnn.ReflectionPad2d((2,3,34,1)))
check_equal(arr, jnn.ReflectionPad2d((10,123,34,1)), tnn.ReflectionPad2d((10,123,34,1)))
check_equal(arr, jnn.ReflectionPad2d((100)), tnn.ReflectionPad2d((100)))
if __name__ == "__main__":
unittest.main()

View File

@ -40,6 +40,7 @@ class TestUnaryOp(unittest.TestCase):
"sin", "arcsin", "sinh", "arcsinh",
"tan", "arctan", "tanh", "arctanh",
"cos", "arccos", "cosh", "arccosh",
"sigmoid",
]
a = [1.1, 2.2, 3.3, 4.4]
for op in ops:
@ -52,6 +53,8 @@ class TestUnaryOp(unittest.TestCase):
else:
b = np.array(a)
func = lambda x: eval(f"np.{op}(x[0]).sum()")
if op == "sigmoid":
func = lambda x: (1/(1+np.exp(-x[0]))).sum()
x, (da,) = ngrad(func, [b], 1e-8)
ja = jt.array(b)
jb = eval(f"jt.{op}(ja)")

View File

@ -1 +1 @@
ee002b49b2fd09c70af20f5067a1667dcd07ec05
08f4ca8b2c0a2978cd3fbc9a3a6e76bd1463ca12

View File

@ -21,7 +21,7 @@ with open(os.path.join(path, "README.src.md")) as fh:
setuptools.setup(
name='jittor',
version='1.0.1',
version='1.1.1',
# scripts=[],
author="Jittor Group",
author_email="ran.donglang@gmail.com",

View File

@ -380,6 +380,12 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
// var->finish_pending_liveness();
var->finish_pending_liveness();
} catch (const std::exception& e) {
// log memory info
display_memory_info(__FILELINE__);
// log jit_key and file location
op->do_prepare();
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path;
if (is_fused_op) {
LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"
<< "failed:" << fused_op.ops << "\n\nReason: " >> e.what();

View File

@ -122,7 +122,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
if (var->tflag == nt)
grad = move(grads[var->custom_data]);
if (!grad) {
LOGvvv << var << "grads[">>i>>"] set to zero";
LOGw << "grads[">>i>>"] doesn't have gradient. It will be set to zero:" << var;
grad = make_number(0.f, var);
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, 0);

View File

@ -5,6 +5,7 @@
// ***************************************************************
#include <typeinfo>
#include "misc/cuda_flags.h"
#include "mem/allocator/aligned_allocator.h"
#ifdef HAS_CUDA
#include "mem/allocator/cuda_managed_allocator.h"
@ -84,5 +85,5 @@ Allocator* get_allocator() {
void gc_all() {
for (auto& kv : allocators) kv.second->gc();
}
} // jittor

View File

@ -5,6 +5,7 @@
// ***************************************************************
#pragma once
#include "common.h"
#include "mem/mem_info.h"
namespace jittor {

View File

@ -16,7 +16,14 @@ const char* CudaDeviceAllocator::name() const {return "cuda_device";}
void* CudaDeviceAllocator::alloc(size_t size, size_t& allocation) {
void* ptr;
checkCudaErrors(cudaMalloc(&ptr, size));
try {
checkCudaErrors(cudaMalloc(&ptr, size));
return ptr;
} catch (...) {}
LOGw << "Unable to alloc cuda device memory, use unify memory instead. "
"This may cause low performance.";
display_memory_info(__FILELINE__);
checkCudaErrors(cudaMallocManaged(&ptr, size));
return ptr;
}

View File

@ -19,5 +19,9 @@ struct StatAllocator : Allocator {
};
DECLARE_FLAG(int, use_stat_allocator);
DECLARE_FLAG(size_t, stat_allocator_total_alloc_call);
DECLARE_FLAG(size_t, stat_allocator_total_alloc_byte);
DECLARE_FLAG(size_t, stat_allocator_total_free_call);
DECLARE_FLAG(size_t, stat_allocator_total_free_byte);
} // jittor

src/mem/mem_info.cc (new file, 110 lines)
View File

@ -0,0 +1,110 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <iomanip>
#include <algorithm>
#include <sys/sysinfo.h>
#include "var.h"
#include "op.h"
#include "var_holder.h"
#include "graph.h"
#include "misc/cuda_flags.h"
#include "mem/allocator/sfrl_allocator.h"
#include "mem/allocator/stat_allocator.h"
#include "mem/mem_info.h"
namespace jittor {
struct FloatOutput {
double value;
string scale;
int base;
string suffix;
int p=4;
};
std::ostream& operator<<(std::ostream& os, const FloatOutput& o) {
int w = 8;
os << std::setw(w-2-o.suffix.size());
os << std::setprecision(o.p);
uint i=0;
double k = o.value;
for (; i+1<o.scale.size(); i++) {
if (k<o.base) break;
k /= o.base;
}
os << k << o.scale[i];
return os << o.suffix;
}
void display_memory_info(const char* fileline) {
int p = 3;
Log log(fileline, 'i', 0);
log << "\n=== display_memory_info ===\n";
log << "total_cpu_ram:" <<
FloatOutput{(double)mem_info.total_cpu_ram, " KMG", 1024, "B"};
log << "total_cuda_ram:" <<
FloatOutput{(double)mem_info.total_cuda_ram, " KMG", 1024, "B"} >> "\n";
log << "hold_vars:" << VarHolder::hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
#ifdef NODE_MEMCHECK
// get the oldest var
vector<Node*> queue;
auto t = ++Node::tflag_count;
for (auto& vh : VarHolder::hold_vars)
if (vh->var->tflag != t) {
vh->var->tflag = t;
queue.push_back(vh->var);
}
bfs_both(queue, [](Node*){return true;});
vector<pair<int64, Node*>> nodes;
nodes.reserve(queue.size());
for (auto* node : queue)
nodes.push_back({node->__id(), node});
std::sort(nodes.begin(), nodes.end());
log << "list of the oldest nodes:\n";
for (int i=0; i<10 && i<nodes.size(); i++) {
log << "ID#" >> nodes[i].first >> ":" << nodes[i].second << "\n";
}
#endif
if (use_stat_allocator) {
log << "stat:" << use_stat_allocator;
log << "total alloc:" << FloatOutput{(double)(stat_allocator_total_alloc_byte
- stat_allocator_total_free_byte), " KMG", 1024, "B"};
log << "total alloc call:" << FloatOutput{(double)(stat_allocator_total_alloc_call
- stat_allocator_total_free_call), " KMG", 1000, ""} >> '\n';
}
for (auto& a : SFRLAllocator::sfrl_allocators) {
auto total = a->used_memory + a->unused_memory;
log << "name:" << a->name() << "is_cuda:" << a->is_cuda()
<< "used:" << FloatOutput{(double)a->used_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->used_memory*100.0 / total >> "%)"
<< "unused:" << FloatOutput{(double)a->unused_memory, " KMG", 1024, "B"}
>> "(" >> std::setprecision(p) >> a->unused_memory*100.0 / total >> "%)"
<< "total:" << FloatOutput{(double)total, " KMG", 1024, "B"} >> "\n";
}
log >> "===========================\n";
log.end();
}
MemInfo::MemInfo() {
struct sysinfo info = {0};
sysinfo(&info);
total_cpu_ram = info.totalram;
total_cuda_ram = 0;
#ifdef HAS_CUDA
cudaDeviceProp prop = {0};
cudaGetDeviceProperties(&prop, 0);
total_cuda_ram = prop.totalGlobalMem;
#endif
}
MemInfo mem_info;
} // jittor
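A rough Python equivalent (not part of the commit) of the FloatOutput scaling above: divide by the base while the value still exceeds one unit, then print it with the matching letter from the scale string:

    def float_output(value, scale=" KMG", base=1024, suffix="B"):
        i = 0
        while i + 1 < len(scale) and value >= base:   # same loop as operator<<(ostream&, FloatOutput)
            value /= base
            i += 1
        return f"{value:.4g}{scale[i]}{suffix}"

    print(float_output(16 * 1024**3))                 # "16GB"
    print(float_output(25000, " KMG", 1000, ""))      # "25K" (base 1000 is used for call counts)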

src/mem/mem_info.h (new file, 31 lines)
View File

@ -0,0 +1,31 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
// @pyjt(display_memory_info)
void display_memory_info(const char* fileline="");
// @pyjt(MemInfo)
struct MemInfo {
// @pyjt(total_cpu_ram)
int64 total_cpu_ram;
// @pyjt(total_cuda_ram)
int64 total_cuda_ram;
inline MemInfo(const MemInfo&) = default;
MemInfo();
};
extern MemInfo mem_info;
// @pyjt(get_mem_info)
inline MemInfo get_mem_info() { return mem_info; }
} // jittor

View File

@ -73,6 +73,7 @@ static unordered_set<string> unary_ops = {
"acos",
"cosh",
"acosh",
"sigmoid",
};
static unordered_set<string> unary_float_ops = {

View File

@ -76,6 +76,7 @@ namespace jittor {
m(acos) \
m(cosh) \
m(acosh) \
m(sigmoid) \
struct NanoString;
#define DECLEAR_NS(T) extern NanoString ns_##T;

View File

@ -108,7 +108,12 @@ struct NanoVector {
// @pyjt(__map_getitem__)
inline NanoVector slice(Slice slice) {
if (slice.mask&2) slice.stop = size();
if (slice.step>0) {
if (slice.mask&2) slice.stop = size();
} else {
if (slice.mask&1) slice.start = size()-1;
if (slice.mask&2) slice.stop = 0;
}
if (slice.start<0) slice.start += size();
if (slice.stop<0) slice.stop += size();
ASSERT(slice.start>=0 && slice.stop>=0 && slice.start<size() && slice.stop<=size())
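The new branch mirrors Python slice defaults: with a negative step, an omitted start falls back to the last element and an omitted stop walks to the front, e.g. (plain Python, for illustration):

    v = [2, 3, 4, 5]
    print(v[::-1])      # [5, 4, 3, 2] -- start defaults to the last index when step < 0
    print(v[::2])       # [2, 4]       -- stop defaults to len(v) when step > 0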

View File

@ -12,6 +12,7 @@
namespace jittor {
extern unordered_map<void*, int64> lived_nodes;
extern int64 total_node;
struct NodeFlags {
typedef uint16 nf_t;
@ -107,7 +108,7 @@ struct Node {
#ifdef NODE_MEMCHECK
inline Node() {
lived_nodes[(void*)this] = lived_nodes.size()+1;
lived_nodes[(void*)this] = ++total_node;
registe_node_trace(this);
}
@ -132,7 +133,13 @@ struct Node {
#endif
}
void memcheck_all_exist() const;
int64 __id() const;
inline int64 __id() const {
#ifdef NODE_MEMCHECK
return lived_nodes.at((void*)this);
#else
return 0;
#endif
}
// release from counter and memory checker
void __release();
#define CHECK_NODE_EXIST(node) \

View File

@ -758,7 +758,7 @@ string OpCompiler::__get_fused_src(
"for", "const", "auto", "get_random_engine",
"int", "float", "bool", "CHECK", "STRINGIZE",
"void", "__restrict__", "if", "true", "false",
"Op", "Var", "Node", "itof"
"Op", "Var", "Node", "itof", "assert", "ASSERT"
};
auto not_change = [&](const string& s) -> bool {
if (unchanged.count(s)) return true;
@ -914,7 +914,9 @@ string OpCompiler::__get_fused_src(
fused_kernel = fused_kernel_args + "\n" + fused_kernel;
LOGvvvv << "Fused kernel:\n" >> fused_kernel;
auto fused_src = fused_begin + fused_includes + "\n#include \"fused_op.h\"\n" +
auto fused_src = fused_begin + fused_includes +
"\n#include <assert.h>\n" +
"\n#include \"fused_op.h\"\n" +
fused_defines + '\n' +
"void jittor::FusedOp::jit_run() {\n" + fused_kernel + "\n}\n";

View File

@ -76,7 +76,7 @@ struct CodeOp : Op {
#include <algorithm>
@alias(a, in0)
@alias(b, out)
"""",
""",
cpu_src="""
for (int i=0; i<a_shape0; i++)
@b(i) = @a(i);

View File

@ -22,7 +22,7 @@ struct OpInfo {
for (uint i=0; i<constructors.size(); i++)
if (std::type_index(*(constructors[i].first)) == std::type_index(tid))
return func_t(constructors[i].second);
LOGf << "constructor" << tid.name() << "not found.";
LOGf << "constructor" << name << tid.name() << "not found.";
return func_t(nullptr);
}
};

View File

@ -67,6 +67,7 @@ static unordered_set<string> unary_ops = {
"cosh",
// @pybind(acosh, arccosh)
"acosh",
"sigmoid",
};
UnaryOp::UnaryOp(Var* x, NanoString op) : x(x) {
@ -183,6 +184,12 @@ VarPtr UnaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
x2 = make_binary(one, x2, ns_subtract);
return make_binary(dout, x2, ns_divide);
}
// dsigmoid(x) = sigmoid(x) - sigmoid(x)^2
if (ns == ns_sigmoid) {
auto r = make_binary(out, out, ns_multiply);
r = make_binary(out, r, ns_subtract);
return make_binary(dout, r, ns_multiply);
}
return nullptr;
}
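The gradient above reuses the forward output: with s = sigmoid(x), ds/dx = s*(1-s) = s - s^2, so exp(-x) is never recomputed and large negative inputs (as in test_sigmoid_nan) cannot produce inf*0 = NaN. A quick numpy check (illustrative only):

    import numpy as np

    x = np.array([1.0, -1.0, -30.0])
    s = 1.0 / (1.0 + np.exp(-x))
    analytic = s - s * s                                          # formula used in UnaryOp::grad
    eps = 1e-6
    numeric = (1 / (1 + np.exp(-(x + eps))) - 1 / (1 + np.exp(-(x - eps)))) / (2 * eps)
    assert np.allclose(analytic, numeric, atol=1e-5)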

View File

@ -35,6 +35,8 @@ namespace jittor {
#define tanh(T,x) ((T) ::tanhf((x)))
#define atanh(T,x) ((T) ::atanhf((x)))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+::expf(-(x)))))
#else
#define abs(T,x) std::abs(x)
#define log(T,x) std::log((T)(x))
@ -59,6 +61,8 @@ namespace jittor {
#define tanh(T,x) ((T) std::tanh((x)))
#define atanh(T,x) ((T) std::atanh((x)))
#define sigmoid(T,x) ((T) (1.0f/(1.0f+std::exp(-(x)))))
#endif
#define cast(T,x) ((T)(x))

View File

@ -170,6 +170,10 @@ void LoopVarAnalyzePass::run() {
for (uint j=0; j<ndim; j++)
if (!(mask>>j&1) && j<loop_var_names.size()) {
for (auto& vname : vnames) {
// cannot replace extras shape
// TODO: optimize it
if (vname.find("extras") != string::npos)
continue;
// replace op{i}_{vname}shape{j} -> {loop_var_names[j]}
std::stringstream name1;
name1 << vname<<"shape"<<j;
@ -193,7 +197,7 @@ void LoopVarAnalyzePass::run() {
replace_vars.emplace_back(name1, name2);
}
LOGvvvv << "replace_vars" << replace_vars;
LOGvvv << "replace_vars" << replace_vars;
ir->replace(replace_vars);
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
// move define

View File

@ -160,6 +160,24 @@ struct OpInspector {
return 0;
}
string format2(const string& fmt, const vector<int>& order) {
string new_fmt = fmt;
if (order.size() != fmt.size()) {
failed = 1;
return "";
}
if (check_overlap(order))
return "";
for (uint i=0; i<order.size(); i++) {
if (order[i]>=(int)new_fmt.size()) {
failed = 1;
return "";
}
new_fmt[order[i]] = fmt[i];
}
return new_fmt;
}
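A Python transcription (illustrative only) of format2 above: each fmt[i] is scattered to position order[i], and the call fails on a length mismatch, an out-of-range index, or overlapping targets:

    def format2(fmt, order):
        if len(order) != len(fmt):
            return None                                # 'failed' path
        if len(set(order)) != len(order):
            return None                                # check_overlap path
        if any(o < 0 or o >= len(fmt) for o in order):
            return None
        new_fmt = list(fmt)
        for i, o in enumerate(order):
            new_fmt[o] = fmt[i]                        # scatter: destination position comes from order
        return "".join(new_fmt)

    print(format2("abcd", [0, 2, 3, 1]))               # "adbc"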
string format(const string& fmt, const vector<int>& order) {
string new_fmt = fmt;
if (order.size() != fmt.size()) {
@ -234,19 +252,19 @@ void ConvTuner::forwardTune(FusedOp* fop) {
xh = xoi.mm[zh];
xw = xoi.mm[zw];
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
auto xformat = xoi.format("abcd", {xn, xc, xh, xw});
auto xformat = xoi.format2("abcd", {xn, xc, xh, xw});
LOGvvvv << "xformat =" << xformat;
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
auto wformat = xoi.format("iohw", {wci, wco, wh, ww});
auto wformat = xoi.format2("iohw", {wci, wco, wh, ww});
LOGvvvv << "wformat =" << wformat;
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
auto yformat = xoi.format("abcd", {yn, yc, yh, yw});
auto yformat = xoi.format2("abcd", {yn, yc, yh, yw});
LOGvvvv << "yformat =" << yformat;
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;
@ -307,7 +325,7 @@ void ConvTuner::forwardTune(FusedOp* fop) {
continue;
auto make_conv = get_op_info(relay_conv_name)
.get_constructor<VarPtr, Var*, Var*, int, int, int, int, string, string, string>();
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
auto rvar = make_conv(x, w, stride, padding, dilation, 1, xformat, wformat, yformat);
auto rid = fop->context->vrm.add_relay_group({{rvar, rop->y}});
if (rid>=0) {
auto srid = "relay"+S(rid);
@ -359,19 +377,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
xh = xoi.mm[zh];
xw = xoi.mm[zw];
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
xformat = xoi.format("abcd", {xn, xc, xh, xw});
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
LOGvvvv << "xformat =" << xformat;
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
wformat = xoi.format("iohw", {wci, wco, wh, ww});
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
LOGvvvv << "wformat =" << wformat;
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
yformat = xoi.format("abcd", {yn, yc, yh, yw});
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
LOGvvvv << "yformat =" << yformat;
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;
@ -434,19 +452,19 @@ void ConvTuner::backwardTune(FusedOp* fop) {
xh = xoi.mm[zh];
xw = xoi.mm[zw];
LOGvvvv << "xnchw =" << vector<int>{xn,xc,xh,xw};
xformat = xoi.format("abcd", {xn, xc, xh, xw});
xformat = xoi.format2("abcd", {xn, xc, xh, xw});
LOGvvvv << "xformat =" << xformat;
wci = woi.mm[zci];
wco = woi.mm[zco];
wh = woi.mm[zwh];
ww = woi.mm[zww];
wformat = xoi.format("iohw", {wci, wco, wh, ww});
wformat = xoi.format2("iohw", {wci, wco, wh, ww});
LOGvvvv << "wformat =" << wformat;
yn = yoi.mm[zn];
yc = yoi.mm[zco];
yh = yoi.mm[zh];
yw = yoi.mm[zw];
yformat = xoi.format("abcd", {yn, yc, yh, yw});
yformat = xoi.format2("abcd", {yn, yc, yh, yw});
LOGvvvv << "yformat =" << yformat;
// mkl doesn't support "cdab" format
if (yformat == "cdab") continue;

View File

@ -15,23 +15,33 @@ using namespace pybind11::literals;
namespace jittor {
DEFINE_FLAG(int, trace_py_var, 0, "Trace py stack for debug.");
DEFINE_FLAG(int, trace_py_var, 0, "Trace py stack max depth for debug.");
unordered_map<const Node*, string> trace_data;
void __registe_node_trace(Node* node) {
auto py_stack =
auto py_stacks =
py::module::import("traceback")
.attr("extract_stack")(nullptr, 1).attr("__getitem__")(0);
auto filename = py_stack.attr("filename").cast<string>();
auto basename = split(filename, "/").back();
basename += ':';
basename += py_stack.attr("name").cast<string>();
basename += ':';
basename += S(py_stack.attr("lineno").cast<int>());
basename += ':';
basename += py_stack.attr("line").cast<string>();
trace_data[node] = basename;
.attr("extract_stack")(nullptr, trace_py_var);
auto len = py_stacks.attr("__len__")().cast<int>();
string info;
for (int i=0; i<len; i++) {
auto py_stack = py_stacks.attr("__getitem__")(i);
auto filename = py_stack.attr("filename").cast<string>();
if (len==1)
info += split(filename, "/").back();
else {
info += "\n ";
info += filename;
}
info += ':';
info += py_stack.attr("name").cast<string>();
info += ':';
info += S(py_stack.attr("lineno").cast<int>());
info += ':';
info += py_stack.attr("line").cast<string>();
}
trace_data[node] = info;
}
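A hedged Python sketch of the new multi-frame formatting above (the C++ builds the same string through pybind11): take up to trace_py_var frames from traceback.extract_stack and join filename:function:lineno:line entries, using only the basename when a single frame is requested:

    import traceback

    def node_trace(limit):
        stack = traceback.extract_stack(None, limit)
        if len(stack) == 1:
            f = stack[0]
            return f"{f.filename.split('/')[-1]}:{f.name}:{f.lineno}:{f.line}"
        return "".join(f"\n    {f.filename}:{f.name}:{f.lineno}:{f.line}" for f in stack)

    print(node_trace(2))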
void __unregiste_node_trace(Node* node) {

View File

@ -158,6 +158,25 @@ DEF_IS(DumpGraphs, const T&) from_py_object(PyObject* obj) {
return GET_RAW_PTR(T, obj);
}
// MemInfo
struct MemInfo;
extern PyTypeObject PyjtMemInfo;
DEF_IS(MemInfo, bool) is_type(PyObject* obj) {
return Py_TYPE(obj) == &PyjtMemInfo;
}
DEF_IS(MemInfo, PyObject*) to_py_object(const T& a) {
PyObjHolder obj(_PyObject_New(&PyjtMemInfo));
auto ptr = GET_RAW_PTR(T, obj.obj);
new (ptr) T(a);
return obj.release();
}
DEF_IS(MemInfo, const T&) from_py_object(PyObject* obj) {
return GET_RAW_PTR(T, obj);
}
// NanoString
struct NanoString;

View File

@ -3,6 +3,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <sstream>
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include <helper_cuda.h>
@ -122,4 +123,10 @@ vector<ArrayArgs> fetch_sync(const vector<VarHolder*>& vh) {
return ret;
}
string VarHolder::debug_msg() {
std::stringstream ss;
ss << var;
return ss.str();
}
} // jittor

View File

@ -154,6 +154,9 @@ struct VarHolder {
#endif
std::memcpy(var->mem_ptr, array.ptr, size);
}
// @pyjt(debug_msg)
string debug_msg();
};
// @pyjt(sync)