mirror of https://github.com/Jittor/Jittor
commit
ce0eb785f3
|
@ -177,7 +177,7 @@ def std(x):
|
|||
matsize *= i
|
||||
out=(x-x.mean()).sqr().sum()
|
||||
out=out/(matsize-1)
|
||||
out=out.sqrt()
|
||||
out=out.maximum(1e-6).sqrt()
|
||||
return out
|
||||
Var.std = std
|
||||
|
||||
|
@ -186,7 +186,7 @@ def norm(x, k, dim):
|
|||
if k==1:
|
||||
return x.abs().sum(dim)
|
||||
if k==2:
|
||||
return x.sqr().sum(dim).sqrt()
|
||||
return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
|
||||
Var.norm = norm
|
||||
|
||||
origin_reshape = reshape
|
||||
|
|
|
@ -15,10 +15,14 @@ __all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152
|
|||
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
return nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
|
||||
jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
|
||||
return conv
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
return nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
conv=nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
|
||||
return conv
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
@ -102,6 +106,7 @@ class ResNet(nn.Module):
|
|||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out")
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.Relu()
|
||||
self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')
|
||||
|
|
|
@ -84,8 +84,11 @@ def cross_entropy_loss(output, target, ignore_index=None):
|
|||
def mse_loss(output, target):
|
||||
return (output-target).sqr().mean()
|
||||
|
||||
def bce_loss(output, target):
|
||||
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
|
||||
def bce_loss(output, target, size_average=True):
|
||||
if size_average:
|
||||
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
|
||||
else:
|
||||
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
|
||||
|
||||
def l1_loss(output, target):
|
||||
return (output-target).abs().mean()
|
||||
|
@ -105,8 +108,8 @@ class MSELoss(Module):
|
|||
class BCELoss(Module):
|
||||
def __init__(self):
|
||||
pass
|
||||
def execute(self, output, target):
|
||||
return bce_loss(output, target)
|
||||
def execute(self, output, target, size_average=True):
|
||||
return bce_loss(output, target, size_average)
|
||||
|
||||
class L1Loss(Module):
|
||||
def __init__(self):
|
||||
|
@ -118,9 +121,9 @@ class BCEWithLogitsLoss(Module):
|
|||
def __init__(self):
|
||||
self.sigmoid = Sigmoid()
|
||||
self.bce = BCELoss()
|
||||
def execute(self, output, target):
|
||||
def execute(self, output, target, size_average=True):
|
||||
output = self.sigmoid(output)
|
||||
output = self.bce(output, target)
|
||||
output = self.bce(output, target, size_average)
|
||||
return output
|
||||
|
||||
def softmax(x, dim = None):
|
||||
|
@ -279,9 +282,14 @@ class Conv(Module):
|
|||
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
|
||||
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
|
||||
|
||||
self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
|
||||
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
|
||||
self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
|
||||
if bias:
|
||||
self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
|
||||
fan=1
|
||||
for i in self.weight.shape[1:]:
|
||||
fan *= i
|
||||
bound = 1 / math.sqrt(fan)
|
||||
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
|
|
@ -131,6 +131,42 @@ class SGD(Optimizer):
|
|||
p -= v * lr
|
||||
p.detach_inplace()
|
||||
|
||||
class RMSprop(Optimizer):
|
||||
""" RMSprop Optimizer.
|
||||
Args:
|
||||
params(list): parameters of model.
|
||||
lr(float): learning rate.
|
||||
eps(float): term added to the denominator to avoid division by zero, default 1e-8.
|
||||
alpha(float): smoothing constant, default 0.99.
|
||||
|
||||
Example:
|
||||
optimizer = nn.RMSprop(model.parameters(), lr)
|
||||
optimizer.step(loss)
|
||||
"""
|
||||
def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
|
||||
super().__init__(params, lr)
|
||||
self.eps = eps
|
||||
self.alpha = alpha
|
||||
|
||||
# initialize required arguments for each param_groups
|
||||
for pg in self.param_groups:
|
||||
values = pg["values"] = []
|
||||
for p in pg["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
|
||||
def step(self, loss):
|
||||
self.pre_step(loss)
|
||||
for pg in self.param_groups:
|
||||
# get arguments from each param_groups
|
||||
lr = pg.get("lr", self.lr)
|
||||
eps = pg.get("eps", self.eps)
|
||||
alpha = pg.get("alpha", self.alpha)
|
||||
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
|
||||
if p.is_stop_grad(): continue
|
||||
v.assign(alpha * v + (1-alpha) * g * g)
|
||||
p -= lr * g / (jt.sqrt(v) + eps)
|
||||
p.detach_inplace()
|
||||
|
||||
class Adam(Optimizer):
|
||||
""" Adam Optimizer.
|
||||
|
||||
|
|
|
@ -41,12 +41,15 @@ static void move_rely(KernelIR* inner_loop, KernelIR* outer_loop, KernelIR* def)
|
|||
}
|
||||
}
|
||||
|
||||
static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
||||
// sorder: Array that saves the allocation order of "tn"
|
||||
// sfunc: Array of function names
|
||||
static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim, vector<vector<int>> &sorder, vector<string> &sfunc) {
|
||||
LOGvvvv << "tune_atomic" << ir->children;
|
||||
vector<string> relys;
|
||||
vector<string> idx_name;
|
||||
vector<KernelIR*> atomics;
|
||||
vector<KernelIR*> loops;
|
||||
vector<int> nrely;
|
||||
vector<int> order;
|
||||
int tmp_cnt=0;
|
||||
for (uint i=0; i<ir->children.size(); i++) {
|
||||
|
@ -57,6 +60,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
atomics.clear();
|
||||
loops.clear();
|
||||
order.clear();
|
||||
nrely.clear();
|
||||
|
||||
c->dfs([&](unique_ptr<KernelIR>& p) {
|
||||
auto& code = p->attrs["code"];
|
||||
|
@ -71,6 +75,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
loops.push_back(loop);
|
||||
idx_name.push_back(loop->attrs["lvalue"]);
|
||||
order.push_back(loops.size()-1);
|
||||
nrely.push_back(-1);
|
||||
bool ok = true;
|
||||
while (1) {
|
||||
loop = loops.back();
|
||||
|
@ -90,6 +95,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
loops.push_back(loop2);
|
||||
idx_name.push_back(loop2->attrs["lvalue"]);
|
||||
order.push_back(loops.size()-1);
|
||||
nrely.push_back(-1);
|
||||
}
|
||||
// TODO: only support single loop children
|
||||
if (!ok) continue;
|
||||
|
@ -107,12 +113,25 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
|
|||
for (uint l=0;l<order.size();l++)
|
||||
if (order[l]==sidx) sord=l;
|
||||
ASSERT(sord != -1);
|
||||
for (int l=sord;l;l--) order[l]=order[l-1];
|
||||
for (int l=sord;l;l--){
|
||||
order[l]=order[l-1];
|
||||
nrely[l]=nrely[l-1];
|
||||
}
|
||||
order[0]=sidx;
|
||||
nrely[0]=j;
|
||||
}
|
||||
}
|
||||
LOGvvvv << "atomic tuner order" << order;
|
||||
|
||||
vector<int> tnorder;
|
||||
uint si;
|
||||
for (si=0;si<order.size();si++)
|
||||
if (nrely[si]!=nrely[0]) break;
|
||||
for (int j=si-1;j>=0;j--) tnorder.push_back(order[j]);
|
||||
for (int j=order.size()-1;j>=si;j--) tnorder.push_back(order[j]);
|
||||
sorder.push_back(tnorder);
|
||||
sfunc.push_back(ir->attrs["lvalue"]);
|
||||
|
||||
// sort loop with order
|
||||
int count=0;
|
||||
for (auto j : order) {
|
||||
|
@ -199,12 +218,54 @@ void AtomicTunerPass::run() {
|
|||
if (is_cuda) choice=1;
|
||||
if (!choice) return;
|
||||
|
||||
vector<vector<int>> sorder;
|
||||
vector<string> sfunc;
|
||||
for (uint i=0; i<ir->before.size(); i++) {
|
||||
auto& func_call = ir->before[i];
|
||||
// TODO: remove this if
|
||||
if (func_call->get_attr("dtype") != "__global__ void") continue;
|
||||
tune_atomic(this, func_call.get(), is_cuda, 4);
|
||||
tune_atomic(this, func_call.get(), is_cuda, 4, sorder, sfunc);
|
||||
}
|
||||
|
||||
// Re-adjust the allocation order of "tn" according to the situation of atomic coverage, preferentially allocate the range not covered by atomic, for example:
|
||||
// for (op0_index_t id0 = tid0; id0<range0; id0+=tnum0) {
|
||||
// for (op1_index_t id1 = tid1; id1<range1; id1+=tnum1) {
|
||||
// for (op2_index_t id2 = tid2; id2<range2; id2+=tnum2) {
|
||||
// for (op3_index_t id3 = tid3; id3<range3; id3+=tnum3) {
|
||||
// ...
|
||||
// }
|
||||
// }
|
||||
// atomicAdd(...);
|
||||
// }
|
||||
// }
|
||||
// The allocation order of "tn" will be: tn1, tn0, tn3, tn2
|
||||
for (uint j=0;j<sfunc.size();j++)
|
||||
for (uint i=0; i<ir->children.size(); i++) {
|
||||
auto& func_call = ir->children[i];
|
||||
int bo=0;
|
||||
for (uint k=0; k<func_call->children.size(); k++){
|
||||
auto& save = func_call->children[k];
|
||||
if (save->has_attr("loop_func") && save->attrs["loop_func"]==sfunc[j]){
|
||||
bo=1;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!bo) continue;
|
||||
uint k;
|
||||
for (k=0; k<func_call->children.size(); k++){
|
||||
auto& save = func_call->children[k];
|
||||
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn")==0) break;
|
||||
}
|
||||
for (uint l=0;l<sorder[j].size();l++){
|
||||
for (uint p=0; p<func_call->children.size(); p++){
|
||||
auto& save = func_call->children[p];
|
||||
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn"+S(sorder[j][l]))==0){
|
||||
func_call->children[p]->swap(*func_call->children[k++]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ir->remove_all_unused();
|
||||
}
|
||||
|
||||
|
|
|
@ -264,12 +264,9 @@ void ParallelPass::run() {
|
|||
string nums = rvalues.at(0);
|
||||
for (int i=1; i<rvalues.size(); i++)
|
||||
nums+="*"+rvalues[i];
|
||||
if (fix_thread_num)
|
||||
new_block.push_back("int thread_num=" + S(thread_num) + ");");
|
||||
else
|
||||
new_block.push_back("int thread_num=min(1<<(NanoVector::get_nbits("+nums+")-2)," + S(thread_num) + ");");
|
||||
|
||||
new_block.push_back("int thread_num=" + S(thread_num) + ";");
|
||||
new_block.push_back("int thread_num_left=thread_num;");
|
||||
|
||||
for (int j=ncs.size()-1; j>=0; j--) {
|
||||
auto& rv = rvalues[j];
|
||||
new_block.push_back("int tn"+S(j)+
|
||||
|
@ -344,6 +341,15 @@ void ParallelPass::run() {
|
|||
new_func_def->insert(0, new_tid_def.children);
|
||||
new_func_def->swap(*func_def, true);
|
||||
new_block.swap(*func_call, true);
|
||||
auto code = func_def->to_string();
|
||||
bool has_atomic = code.find("atomic") != string::npos;
|
||||
if (!fix_thread_num) {
|
||||
if (has_atomic) {
|
||||
func_call->find_define("thread_num")->attrs["rvalue"] = "min(1<<max((NanoVector::get_nbits(" + nums + "/16)-2),0)," + S(thread_num) + ")";
|
||||
} else {
|
||||
func_call->find_define("thread_num")->attrs["rvalue"] = "min(1<<max((NanoVector::get_nbits(" + nums + ")-2),0)," + S(thread_num) + ")";
|
||||
}
|
||||
}
|
||||
}
|
||||
ir->remove_all_unused();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue