add misc and bcelogits with pos_weight

li-xl 2020-11-10 11:34:39 +08:00
parent 4d989e94b1
commit 1045f04c4b
6 changed files with 113 additions and 18 deletions

View File

@@ -12,6 +12,26 @@ import numpy as np
 import math
 from collections.abc import Sequence,Iterable
+def __copy__(x):
+    return x.copy().detach()
+jt.Var.__copy__ = __copy__
+
+def __deepcopy__(x, memo):
+    result = x.copy().detach()
+    memo[id(x)] = result
+    return result
+jt.Var.__deepcopy__ = __deepcopy__
+
+def __len__(x):
+    return x.shape[0]
+jt.Var.__len__ = __len__
+
+def __iter__(x):
+    result = []
+    for i in range(x.shape[0]):
+        result.append(x[i])
+    return result.__iter__()
+jt.Var.__iter__ = __iter__
 def repeat(x, *shape):
     r'''
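The hunk above registers standard Python protocol hooks on jt.Var, so a Var can be shallow- or deep-copied, measured with len(), and iterated row by row. A minimal usage sketch (the tensor values and variable names are illustrative, not part of the commit):

import copy
import jittor as jt

x = jt.array([[1.0, 2.0], [3.0, 4.0]])

print(len(x))                 # 2, i.e. x.shape[0]
rows = [row for row in x]     # __iter__ yields x[0], x[1], ...

y = copy.copy(x)              # __copy__: a detached copy of the data
z = copy.deepcopy({"w": x})   # __deepcopy__ also works inside containers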

View File

@@ -264,17 +264,29 @@ class L1Loss(Module):
     def execute(self, output, target):
         return l1_loss(output, target)
-class BCEWithLogitsLoss(Module):
-    def __init__(self, weight=None, size_average=True):
-        self.sigmoid = Sigmoid()
-        self.bce = BCELoss(weight, size_average)
-    def execute(self, output, target):
-        output = self.sigmoid(output)
-        output = self.bce(output, target)
-        return output
-
-def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
-    return BCEWithLogitsLoss(weight, size_average)(input, target)
+
+def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True):
+    max_val = jt.clamp(-output, min_v=0)
+    if pos_weight is not None:
+        log_weight = (pos_weight-1)*target + 1
+        loss = (1-target)*output + log_weight*(((-max_val).exp()+(-output-max_val).exp()).log()+max_val)
+    else:
+        loss = (1-target)*output + max_val + ((-max_val).exp()+(-output-max_val).exp()).log()
+    if weight is not None:
+        loss *= weight
+    if size_average:
+        return loss.mean()
+    else:
+        return loss.sum()
+
+class BCEWithLogitsLoss(Module):
+    def __init__(self, weight=None, pos_weight=None, size_average=True):
+        self.pos_weight = pos_weight
+        self.weight = weight
+        self.size_average = size_average
+    def execute(self, output, target):
+        return binary_cross_entropy_with_logits(output, target, self.weight, self.pos_weight, self.size_average)
+
 def softmax(x, dim = None):
     if dim is None:
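The replacement above computes BCE-with-logits directly from the logits using the log-sum-exp trick instead of composing Sigmoid with BCELoss, and adds pos_weight to up-weight positive targets. A small sanity-check sketch, assuming the new function is exported from jittor.nn (the file name is not shown in this view), comparing it against the naive sigmoid formulation in NumPy:

import numpy as np
import jittor as jt
from jittor import nn

x = np.array([0.5, -1.2, 2.0], dtype="float32")   # logits
t = np.array([1.0, 0.0, 1.0], dtype="float32")    # targets
pw = 2.0                                          # weight positive samples 2x

loss = nn.binary_cross_entropy_with_logits(jt.array(x), jt.array(t), pos_weight=pw)

# naive reference: -[pw*t*log(sigmoid(x)) + (1-t)*log(1-sigmoid(x))], averaged
s = 1.0 / (1.0 + np.exp(-x))
ref = -(pw * t * np.log(s) + (1 - t) * np.log(1 - s)).mean()
print(loss.data, ref)   # should agree up to float32 rounding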

View File

@@ -210,3 +210,64 @@ class Adam(Optimizer):
                 v.update(b1 * v + (1-b1) * g * g)
                 step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
                 p.update(p - m * step_size / (jt.sqrt(v) + eps))
+
+class LRScheduler:
+    def __init__(self, optimizer, last_epoch=-1):
+        assert isinstance(optimizer, Optimizer)
+        self.optimizer = optimizer
+        if last_epoch == -1:
+            for gp in optimizer.param_groups:
+                gp.setdefault('initial_lr', gp.get('lr', optimizer.lr))
+        else:
+            for gp in optimizer.param_groups:
+                assert 'initial_lr' in gp
+        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+        self.last_epoch = last_epoch
+        self.optimizer._step_count = 0
+        self._step_count = 0
+        self.step()
+
+    def get_lr(self):
+        raise NotImplementedError
+
+    def get_last_lr(self):
+        return self._last_lr
+
+    def step(self, epoch=None):
+        self._step_count += 1
+        if epoch is None:
+            self.last_epoch += 1
+            values = self.get_lr()
+        else:
+            self.last_epoch = epoch
+            values = self.get_lr()
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
+            param_group['lr'] = lr
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+class LambdaLR(LRScheduler):
+    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
+            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
+        else:
+            if len(lr_lambda) != len(optimizer.param_groups):
+                raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda)))
+            self.lr_lambdas = list(lr_lambda)
+        super(LambdaLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [base_lr * lmbda(self.last_epoch)
+                for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
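LRScheduler and LambdaLR above mirror the PyTorch scheduler API on top of jittor optimizer param_groups. An illustrative sketch, assuming the classes live in the same module as the optimizers shown in this hunk (the actual file path is not visible in this view) and using a throwaway linear model:

import jittor as jt
from jittor import nn, optim
from jittor.optim import LambdaLR   # assumed location; the diff does not name the file

model = nn.Linear(10, 2)
opt = optim.SGD(model.parameters(), lr=0.1)

# scale the initial lr of every param group by 0.95**epoch
scheduler = LambdaLR(opt, lr_lambda=lambda epoch: 0.95 ** epoch)

for epoch in range(3):
    # ... forward/backward and opt.step(loss) would go here ...
    scheduler.step()
    print(epoch, scheduler.get_last_lr())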

View File

@@ -177,7 +177,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
         Var* dout = grads[id];
         trace_grad_op = op;
         VarPtr dvar = make_grad(op, out, dout, var, index);
-        if (dvar && dvar->num>=0 && var->num)
+        if (dvar && dvar->num>=0 && var->num>0)
+            // var->num == 0 represents an any-match var
             ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
                 << "dvar" << dvar << "var" << var;
         if (!grad)

View File

@@ -11,6 +11,7 @@
 #ifdef HAS_CUDA
 #include <cuda_runtime.h>
 #include <helper_cuda.h>
+#include "misc/cuda_flags.h"
 #endif
 
 namespace jittor {
@@ -36,14 +37,14 @@ void CopyOp::run() {
     auto size = x->size;
     auto x_ptr = x->mem_ptr;
     auto y_ptr = outputs().front()->mem_ptr;
-    if (flags.get(NodeFlags::_cpu)) {
+    #ifdef HAS_CUDA
+    if (flags.get(NodeFlags::_cuda)) {
+        checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
+    } else
+    #endif
+    {
         std::memcpy(y_ptr, x_ptr, size);
     }
-    #ifdef HAS_CUDA
-    else {
-        checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
-    }
-    #endif
 }
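With the reordered branch above, CopyOp issues cudaMemcpyAsync whenever the op is flagged for CUDA and falls back to std::memcpy otherwise. From Python this path is exercised by any Var copy; a tiny sketch (the use_cuda toggle only applies to GPU builds):

import jittor as jt

# jt.flags.use_cuda = 1   # uncomment on a CUDA build to take the cudaMemcpyAsync path
x = jt.array([1.0, 2.0, 3.0])
y = x.copy()              # runs CopyOp: async device copy on CUDA, std::memcpy on CPU
print(y.data)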

View File

@@ -69,7 +69,7 @@ void SetitemOp::infer_shape() {
     for (int i=0; i<data_dim; i++) {
         int j = i - data_dim + out_shape.size();
         if (!(data_shape[i]==1 && out_shape[j]!=-1)) {
-            CHECK(data_shape[i]<0 || data_shape[i]==out_shape[j])
+            CHECK(data_shape[i]<0 || out_shape[j]<0 || data_shape[i]==out_shape[j])
                 << "Data shape not match" << data_shape << out_shape;
             bmask |= 1<<j;
         }