mirror of https://github.com/Jittor/Jittor
add misc and bcelogits with pos_weight
parent 4d989e94b1
commit 1045f04c4b
@@ -12,6 +12,26 @@ import numpy as np
 import math
 from collections.abc import Sequence, Iterable
 
+def __copy__(x):
+    return x.copy().detach()
+jt.Var.__copy__ = __copy__
+
+def __deepcopy__(x, memo):
+    result = x.copy().detach()
+    memo[id(x)] = result
+    return result
+jt.Var.__deepcopy__ = __deepcopy__
+
+def __len__(x):
+    return x.shape[0]
+jt.Var.__len__ = __len__
+
+def __iter__(x):
+    result = []
+    for i in range(x.shape[0]):
+        result.append(x[i])
+    return result.__iter__()
+jt.Var.__iter__ = __iter__
 
 def repeat(x, *shape):
     r'''
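
(Not part of the diff: a minimal usage sketch of the Var hooks added above, assuming jittor is imported as jt; the array values are illustrative.)

    import copy
    import jittor as jt

    x = jt.array([1.0, 2.0, 3.0])
    y = copy.deepcopy(x)   # goes through jt.Var.__deepcopy__: a detached copy, cached in memo
    n = len(x)             # goes through jt.Var.__len__: size of the first dimension, here 3
    for row in x:          # goes through jt.Var.__iter__: yields x[0], x[1], x[2]
        print(row)
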
@@ -264,17 +264,29 @@ class L1Loss(Module):
     def execute(self, output, target):
         return l1_loss(output, target)
 
-class BCEWithLogitsLoss(Module):
-    def __init__(self, weight=None, size_average=True):
-        self.sigmoid = Sigmoid()
-        self.bce = BCELoss(weight, size_average)
-    def execute(self, output, target):
-        output = self.sigmoid(output)
-        output = self.bce(output, target)
-        return output
+def binary_cross_entropy_with_logits(output, target, weight=None, pos_weight=None, size_average=True):
+    max_val = jt.clamp(-output, min_v=0)
+    if pos_weight is not None:
+        log_weight = (pos_weight-1)*target + 1
+        loss = (1-target)*output + (log_weight*(((-max_val).exp()+(-output-max_val).exp()).log()+max_val))
+    else:
+        loss = (1-target)*output + max_val + ((-max_val).exp()+(-output-max_val).exp()).log()
+    if weight is not None:
+        loss *= weight
 
-def binary_cross_entropy_with_logits(input, target, weight=None, size_average=True):
-    return BCEWithLogitsLoss(weight, size_average)(input, target)
+    if size_average:
+        return loss.mean()
+    else:
+        return loss.sum()
+
+class BCEWithLogitsLoss(Module):
+    def __init__(self, weight=None, pos_weight=None, size_average=True):
+        self.pos_weight = pos_weight
+        self.weight = weight
+        self.size_average = size_average
+
+    def execute(self, output, target):
+        return binary_cross_entropy_with_logits(output, target, self.weight, self.pos_weight, self.size_average)
 
 def softmax(x, dim = None):
     if dim is None:
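
(Not part of the diff: the max_val term above is the usual log-sum-exp stabilization of log(1 + exp(-x)), and pos_weight scales the loss on positive targets. A rough usage sketch with made-up values, assuming the patched module is jittor.nn:)

    import jittor as jt
    from jittor import nn

    logits = jt.array([0.5, -1.2, 2.0])   # raw scores, no sigmoid applied
    labels = jt.array([1.0, 0.0, 1.0])
    # count each positive example three times as heavily as a negative one
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=jt.array([3.0]))
    loss = loss_fn(logits, labels)        # mean reduction, since size_average=True
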
@@ -210,3 +210,64 @@ class Adam(Optimizer):
             v.update(b1 * v + (1-b1) * g * g)
             step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
             p.update(p - m * step_size / (jt.sqrt(v) + eps))
+
+
+class LRScheduler:
+    def __init__(self, optimizer, last_epoch=-1):
+        assert isinstance(optimizer, Optimizer)
+        self.optimizer = optimizer
+
+        if last_epoch == -1:
+            for gp in optimizer.param_groups:
+                gp.setdefault('initial_lr', gp.get('lr', optimizer.lr))
+        else:
+            for gp in optimizer.param_groups:
+                assert 'initial_lr' in gp
+
+        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
+        self.last_epoch = last_epoch
+        self.optimizer._step_count = 0
+        self._step_count = 0
+        self.step()
+
+    def get_lr(self):
+        raise NotImplementedError
+
+    def get_last_lr(self):
+        return self._last_lr
+
+    def step(self, epoch=None):
+        self._step_count += 1
+
+        if epoch is None:
+            self.last_epoch += 1
+            values = self.get_lr()
+        else:
+            self.last_epoch = epoch
+            values = self.get_lr()
+
+        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+            param_group, lr = data
+            param_group['lr'] = lr
+
+        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+
+
+class LambdaLR(LRScheduler):
+
+    def __init__(self, optimizer, lr_lambda, last_epoch=-1):
+        if not isinstance(lr_lambda, list) and not isinstance(lr_lambda, tuple):
+            self.lr_lambdas = [lr_lambda] * len(optimizer.param_groups)
+        else:
+            if len(lr_lambda) != len(optimizer.param_groups):
+                raise ValueError("Expected {} lr_lambdas, but got {}".format(len(optimizer.param_groups), len(lr_lambda)))
+
+            self.lr_lambdas = list(lr_lambda)
+
+        super(LambdaLR, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self):
+        return [base_lr * lmbda(self.last_epoch)
+            for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
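
(Not part of the diff: a possible usage sketch for the scheduler classes added above. The model, optimizer, and decay schedule are made up, jt.optim.SGD is assumed to be the optimizer entry point, and LambdaLR is assumed to be importable from the module this hunk patches.)

    import jittor as jt
    from jittor import nn

    model = nn.Linear(10, 2)
    optimizer = jt.optim.SGD(model.parameters(), lr=0.1)
    # halve the learning rate every 10 epochs: lr = 0.1 * 0.5 ** (epoch // 10)
    scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.5 ** (epoch // 10))

    for epoch in range(30):
        # ... run training steps with optimizer.step(loss) here ...
        scheduler.step()   # writes base_lr * lr_lambda(last_epoch) into each param_group['lr']
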
@@ -177,7 +177,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
             Var* dout = grads[id];
             trace_grad_op = op;
             VarPtr dvar = make_grad(op, out, dout, var, index);
-            if (dvar && dvar->num>=0 && var->num)
+            if (dvar && dvar->num>=0 && var->num>0)
+                // var->num == 0 represents an any-match var
                 ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
                     << "dvar" << dvar << "var" << var;
             if (!grad)
@@ -11,6 +11,7 @@
 #ifdef HAS_CUDA
 #include <cuda_runtime.h>
 #include <helper_cuda.h>
+#include "misc/cuda_flags.h"
 #endif
 
 namespace jittor {
@@ -36,14 +37,14 @@ void CopyOp::run() {
     auto size = x->size;
     auto x_ptr = x->mem_ptr;
     auto y_ptr = outputs().front()->mem_ptr;
-    if (flags.get(NodeFlags::_cpu)) {
+    #ifdef HAS_CUDA
+    if (flags.get(NodeFlags::_cuda)) {
+        checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
+    } else
+    #endif
+    {
         std::memcpy(y_ptr, x_ptr, size);
     }
-    #ifdef HAS_CUDA
-    else {
-        checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
-    }
-    #endif
 }
 
@@ -69,7 +69,7 @@ void SetitemOp::infer_shape() {
     for (int i=0; i<data_dim; i++) {
         int j = i - data_dim + out_shape.size();
         if (!(data_shape[i]==1 && out_shape[j]!=-1)) {
-            CHECK(data_shape[i]<0 || data_shape[i]==out_shape[j])
+            CHECK(data_shape[i]<0 || out_shape[j]<0 || data_shape[i]==out_shape[j])
                 << "Data shape not match" << data_shape << out_shape;
            bmask |= 1<<j;
        }