diff --git a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc index c8ff5ce8..06a8b8c9 100644 --- a/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc +++ b/extern/cuda/cudnn/ops/cudnn_conv_backward_x_op.cc @@ -178,7 +178,7 @@ void CudnnConvBackwardXOp::jit_run() { CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED }; - int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; + int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; int perf_count; cudnnConvolutionBwdDataAlgoPerf_t perf_results[num_algos]; cudnnConvolutionBwdDataAlgo_t algo; diff --git a/extern/cuda/cutt/ops/cutt_transpose_op.cc b/extern/cuda/cutt/ops/cutt_transpose_op.cc index af041680..e9b9582e 100644 --- a/extern/cuda/cutt/ops/cutt_transpose_op.cc +++ b/extern/cuda/cutt/ops/cutt_transpose_op.cc @@ -6,16 +6,12 @@ #include "var.h" #include "cutt_transpose_op.h" #include "ops/op_register.h" -#include - -#ifdef JIT #include "cutt.h" -#endif #include "cutt_warper.h" +#include "misc/stack_vector.h" namespace jittor { -#ifndef JIT static auto make_transpose = get_op_info("cutt_transpose") .get_constructor(); @@ -58,52 +54,49 @@ VarPtr CuttTransposeOp::grad(Var* out, Var* dout, Var* v, int v_index) { return make_transpose(dout, reverse); } -void CuttTransposeOp::jit_prepare(JK& jk) { - jk << _CS("[Tx:") << x->dtype(); - jk << _CS("][DIM=") << JK::hex1(axes.size()); - for (uint i=0; i cutt_plan_cache; -#else // JIT -#ifdef JIT_cuda - -extern unordered_map cutt_plan_cache; - -void CuttTransposeOp::jit_run() { - auto* __restrict__ xp = x->ptr(); - auto* __restrict__ yp = y->ptr(); - vector permutation, permutation2; - vector y_shape; - vector x_shape; - @for(i, 0, DIM, permutation.push_back(DIM-1-AXES@i);) - @for(i, 0, DIM, permutation2.push_back(permutation[DIM-1-@i@@]);) - std::vector reverse; - reverse.reserve(permutation2.size()); - for (uint i=0; ishape[DIM-1-@i@@]);) - +void CuttTransposeOp::run() { + auto* __restrict__ xp = x->mem_ptr; + auto* __restrict__ yp = y->mem_ptr; + StackVector x_shape; + StackVector new_shape, new_axes, trans, reverse; + int dim = x->shape.size(); + for (int i=0; ishape[i] != 1) + new_shape.push_back(x->shape[i]); + } + for (int i = 0; i < dim; ++i) { + if (x->shape[axes[i]] != 1) { + new_axes.push_back(trans[axes[i]]); + } + } + dim = new_shape.size(); + for (int i=0; isize, cudaMemcpyDefault, 0)); + return; + } jk.clear(); - jk << @DIM << ","; - for (uint i=0; i<@DIM; i++) jk << x_shape[i] << ","; - for (uint i=0; i<@DIM; i++) jk << reverse[i] << ","; - jk << sizeof(Tx) << "."; + jk << dim << ','; + for (int i=0; idtype().dsize() << '.'; auto iter = cutt_plan_cache.find(jk.to_string()); + LOGvvv << "Run cutt_transpose with key:" << jk.to_string(); if (iter!=cutt_plan_cache.end()){ cuttExecute(iter->second, xp, yp); } else { cuttHandle plan; - cuttPlan(&plan, @DIM, x_shape.data(), reverse.data(), sizeof(Tx), 0); + cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0); cutt_plan_cache[jk.to_string()] = plan; cuttExecute(plan, xp, yp); } } -#endif // JIT_cuda -#endif // JIT } // jittor \ No newline at end of file diff --git a/extern/cuda/cutt/ops/cutt_transpose_op.h b/extern/cuda/cutt/ops/cutt_transpose_op.h index 95fdc656..248e2bc2 100644 --- a/extern/cuda/cutt/ops/cutt_transpose_op.h +++ b/extern/cuda/cutt/ops/cutt_transpose_op.h @@ -19,7 +19,7 @@ struct CuttTransposeOp : Op { const char* name() const override { return "cutt_transpose"; } VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; void infer_shape() override; - DECLARE_jit_run; + void run() override; }; } // jittor \ No newline at end of file diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 8e429d2c..cdfcdda9 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -12,6 +12,35 @@ import numpy as np import math from collections.abc import Sequence,Iterable +def __copy__(x): + return x.copy().detach() +jt.Var.__copy__ = __copy__ + +def __deepcopy__(x,memo): + result = x.copy().detach() + memo[id(x)]=result + return result +jt.Var.__deepcopy__ = __deepcopy__ + +def __len__(x): + return x.shape[0] +jt.Var.__len__ = __len__ + +def __iter__(x): + result = [] + for i in range(x.shape[0]): + result.append(x[i]) + return result.__iter__() +jt.Var.__iter__ = __iter__ + +def all(x,dim): + return x.all_(dim).bool() +jt.Var.all = all + +def any(x,dim): + return x.any_(dim).bool() +jt.Var.any = any + def repeat(x, *shape): r''' @@ -47,10 +76,24 @@ def repeat(x, *shape): x = x.broadcast(x_shape) elif len_x_shape > len_shape: rep_shape = (len_x_shape - len_shape) * [1] + shape + + reshape_shape = [] + broadcast_shape = [] + for x_s,r_s in zip(x_shape,rep_shape): + reshape_shape.append(1) + reshape_shape.append(x_s) + + broadcast_shape.append(r_s) + broadcast_shape.append(1) + + x = x.reshape(reshape_shape) + x = x.broadcast(broadcast_shape) + tar_shape = (np.array(x_shape) * np.array(rep_shape)).tolist() - dims = [] - for i in range(len(tar_shape)): dims.append(f"i{i}%{x_shape[i]}") - return x.reindex(tar_shape, dims) + + x = x.reshape(tar_shape) + return x + jt.Var.repeat = repeat def chunk(x, chunks, dim=0): @@ -326,9 +369,8 @@ def unique(x): ''' x = x.reshape(-1) _,x = jt.argsort(x) - index2 = [i for i in range(1,x.shape[0])] - index1 = [i for i in range(x.shape[0]-1)] - y = x[1:][x[index2] != x[index1]] + index,= jt.index((x.shape[0],)) + y = x[1:][x[index[1:]] != x[index[:-1]]] x = jt.contrib.concat([x[:1],y],dim=0) return x @@ -401,12 +443,6 @@ def log2(x): jt.Var.log2 = log2 -def item(x): - assert x.ndim==1 and x.shape[0]==1 - return x.numpy().item() - -jt.Var.item = item - def meshgrid(*tensors): r''' Take N tensors, each of which can be 1-dimensional vector, and create N n-dimensional grids, diff --git a/python/jittor/nn.py b/python/jittor/nn.py index 3d46f88b..92b05cd1 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -264,17 +264,29 @@ class L1Loss(Module): def execute(self, output, target): return l1_loss(output, target) -class BCEWithLogitsLoss(Module): - def __init__(self, weight=None, size_average=True): - self.sigmoid = Sigmoid() - self.bce = BCELoss(weight, size_average) - def execute(self, output, target): - output = self.sigmoid(output) - output = self.bce(output, target) - return output +def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True): + max_val = jt.clamp(-output,min_v=0) + if pos_weight is not None: + log_weight = (pos_weight-1)*target + 1 + loss = (1-target)*output+(log_weight*(((-max_val).exp()+(-output - max_val).exp()).log()+max_val)) + else: + loss = (1-target)*output+max_val+((-max_val).exp()+(-output -max_val).exp()).log() + if weight is not None: + loss *=weight -def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True): - return BCEWithLogitsLoss(weight, size_average)(input, target) + if size_average: + return loss.mean() + else: + return loss.sum() + +class BCEWithLogitsLoss(Module): + def __init__(self, weight=None, pos_weight=None, size_average=True): + self.pos_weight = pos_weight + self.weight = weight + self.size_average = size_average + + def execute(self, output, target): + return binary_cross_entropy_with_logits(output,target,self.weight,self.pos_weight,self.size_average) def softmax(x, dim = None): if dim is None: diff --git a/python/jittor/optim.py b/python/jittor/optim.py index f82b2102..6395ecf9 100644 --- a/python/jittor/optim.py +++ b/python/jittor/optim.py @@ -210,3 +210,64 @@ class Adam(Optimizer): v.update(b1 * v + (1-b1) * g * g) step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n) p.update(p - m * step_size / (jt.sqrt(v) + eps)) + + +class LRScheduler: + def __init__(self,optimizer, last_epoch=-1): + assert isinstance(optimizer,Optimizer) + self.optimizer = optimizer + + if last_epoch==-1: + for gp in optimizer.param_groups: + gp.setdefault('initial_lr',gp.get('lr',optimizer.lr)) + else: + for gp in optimizer.param_groups: + assert 'initial_lr' in gp + + self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) + self.last_epoch = last_epoch + self.optimizer._step_count = 0 + self._step_count = 0 + self.step() + + def get_lr(self): + raise NotImplementedError + + def get_last_lr(self): + return self._last_lr + + def step(self,epoch=None): + self._step_count += 1 + + if epoch is None: + self.last_epoch += 1 + values = self.get_lr() + else: + self.last_epoch = epoch + values = self.get_lr() + + for i, data in enumerate(zip(self.optimizer.param_groups, values)): + param_group, lr = data + param_group['lr'] = lr + + self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + + +class LambdaLR(LRScheduler): + + def __init__(self, optimizer, lr_lambda, last_epoch=-1): + if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple): + self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups) + else: + if len(lr_lambda) != len(optimizer.param_groups): + raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda))) + + self.lr_lambdas = list(lr_lambda) + + super(LambdaLR, self).__init__(optimizer, last_epoch) + + + + def get_lr(self): + return [base_lr * lmbda(self.last_epoch) + for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)] \ No newline at end of file diff --git a/python/jittor/test/test_cutt_transpose_op.py b/python/jittor/test/test_cutt_transpose_op.py index bbb9d6c4..c7ef8a8c 100644 --- a/python/jittor/test/test_cutt_transpose_op.py +++ b/python/jittor/test/test_cutt_transpose_op.py @@ -30,7 +30,7 @@ class TestCuttTransposeOp(unittest.TestCase): for perm in perms: with jt.log_capture_scope( log_silent=1, - log_v=0, log_vprefix="op.cc=100" + log_v=0, log_vprefix="cutt=100" ) as raw_log: if perm: x = np.transpose(a, perm) @@ -39,7 +39,7 @@ class TestCuttTransposeOp(unittest.TestCase): x = np.transpose(a) y = jt.transpose(a).data self.assertEqual(x.shape, y.shape) - logs = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "cutt_transpose" + ".*)") + logs = find_log_with_re(raw_log, "(Run cutt_transpose with key.*)") if perm is None: continue last = -1 @@ -53,7 +53,7 @@ class TestCuttTransposeOp(unittest.TestCase): last = perm[i] if not in_order: assert len(logs)==1 - assert (x==y).all(), f"\n{x}\n{y}" + assert (x==y).all(), f"\n{x}\n{y}\n{perm}\n{a.shape}" ia = [gen_data([5, 7]), gen_data([2,2,2]), gen_data([2,3,4,5]), gen_data([5,3]), gen_data([3,1,5,3,1])] for a in ia: check(a) diff --git a/src/grad.cc b/src/grad.cc index efb593a4..34e5faf9 100644 --- a/src/grad.cc +++ b/src/grad.cc @@ -177,7 +177,8 @@ vector grad(Var* loss, vector targets) { Var* dout = grads[id]; trace_grad_op = op; VarPtr dvar = make_grad(op, out, dout, var, index); - if (dvar && dvar->num>=0 && var->num) + if (dvar && dvar->num>=0 && var->num>0) + // var->num == 0 represents a any match var ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size()) << "dvar" << dvar << "var" << var; if (!grad) diff --git a/src/misc/stack_vector.h b/src/misc/stack_vector.h index 4fbdea92..c1af6a17 100644 --- a/src/misc/stack_vector.h +++ b/src/misc/stack_vector.h @@ -17,6 +17,7 @@ struct StackVector { inline T& front() { return a[0]; } inline T& back() { return a[n-1]; } inline int size() { return n;} + inline T* data() { return a;} inline StackVector(int n=0) : n(n) {} struct Iter { diff --git a/src/ops/copy_op.cc b/src/ops/copy_op.cc index 254fb267..6289b7c4 100644 --- a/src/ops/copy_op.cc +++ b/src/ops/copy_op.cc @@ -11,6 +11,7 @@ #ifdef HAS_CUDA #include #include +#include "misc/cuda_flags.h" #endif namespace jittor { @@ -36,14 +37,14 @@ void CopyOp::run() { auto size = x->size; auto x_ptr = x->mem_ptr; auto y_ptr = outputs().front()->mem_ptr; - if (flags.get(NodeFlags::_cpu)) { + #ifdef HAS_CUDA + if (flags.get(NodeFlags::_cuda)) { + checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0)); + } else + #endif + { std::memcpy(y_ptr, x_ptr, size); } - #ifdef HAS_CUDA - else { - checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0)); - } - #endif } diff --git a/src/ops/reduce_op.cc b/src/ops/reduce_op.cc index 9a420bee..23ecd1dd 100644 --- a/src/ops/reduce_op.cc +++ b/src/ops/reduce_op.cc @@ -34,9 +34,9 @@ unordered_set reduce_ops = { "add", // @pybind(prod, product, reduce_multiply) "multiply", - // @pybind(reduce_logical_and, all) + // @pybind(reduce_logical_and, all_) "logical_and", - // @pybind(reduce_logical_or, any) + // @pybind(reduce_logical_or, any_) "logical_or", "logical_xor", "bitwise_and", @@ -65,7 +65,8 @@ ReduceOp::ReduceOp(Var* x, NanoString op, NanoVector dims, bool keepdims) reduce_mask |= 1<dtype() == ns_bool && ns == ns_add) + // if (x->dtype() == ns_bool && ns == ns_add) + if (x->dtype() == ns_bool) y = create_output(nullptr, ns_int32); else y = create_output(nullptr, binary_dtype_infer(ns, x, x)); diff --git a/src/ops/setitem_op.cc b/src/ops/setitem_op.cc index ae545e50..899ec1ff 100644 --- a/src/ops/setitem_op.cc +++ b/src/ops/setitem_op.cc @@ -69,7 +69,7 @@ void SetitemOp::infer_shape() { for (int i=0; i(); } if (cutt_transpose) { - bool need_reshape = false; - int dims = x->shape.size(); - vector in_axes; - vector in_shape; - vector out_shape; - vector trans; - int cnt = 0; - for (int i = 0; i < dims; ++i) { - if (x->shape[i] == 1) { - need_reshape = true; - trans.push_back(-1); - } else { - trans.push_back(cnt); - cnt += 1; - in_shape.push_back(x->shape[i]); - } - out_shape.push_back(x->shape[axes[i]]); - } - for (int i = 0; i < dims; ++i) { - if (x->shape[axes[i]] != 1) { - in_axes.push_back(trans[axes[i]]); - } - } - if (need_reshape) { - auto x1 = make_reshape(x, NanoVector(in_shape)); - auto x2 = cutt_transpose(x1, in_axes); - auto x3 = make_reshape(x2, NanoVector(out_shape)); - forward(x3); - } else { - auto var = cutt_transpose(x, axes); - forward(var); - } + auto var = cutt_transpose(x, axes); + forward(var); return; } } diff --git a/src/pybind/py_var_tracer.cc b/src/pybind/py_var_tracer.cc index 17dd2814..83c42f39 100644 --- a/src/pybind/py_var_tracer.cc +++ b/src/pybind/py_var_tracer.cc @@ -164,6 +164,19 @@ static vector get_stack_info() { (int)PyFrame_GetLineNumber(prev_f)}); } } + if (stacks.size() == 0) { + auto m = std::min(3,n); + for (int i=0; if_code->co_filename); + auto num = (int)PyFrame_GetLineNumber(f); + stacks.emplace_back(Stack{ + s+":"+S(num), + "", + s, + num}); + } + } return stacks; } diff --git a/src/pyjt/py_ring_buffer.cc b/src/pyjt/py_ring_buffer.cc index 2e59a7ab..502b2cbe 100644 --- a/src/pyjt/py_ring_buffer.cc +++ b/src/pyjt/py_ring_buffer.cc @@ -23,7 +23,7 @@ static void push_py_object_pickle(RingBuffer* rb, PyObject* obj, uint64& __restr ASSERT(0 == PyBytes_AsStringAndSize(ret.obj, &s, &size)); rb->push_t(size, offset); rb->push(size, offset); - LOGir << string(rb->get_ptr(size, offset), size); + // LOGir << string(rb->get_ptr(size, offset), size); std::memcpy(rb->get_ptr(size, offset), s, size); return; }