This commit is contained in:
Dun Liang 2020-06-04 13:48:53 +08:00
commit b093b12c04
7 changed files with 224 additions and 20 deletions

View File

@ -125,6 +125,44 @@ class profile_scope(_call_no_record_scope):
profiler.stop()
self.report.extend(profiler.report())
class single_process_scope(_call_no_record_scope):
""" single_process_scope
Code in this scope will only be executed by single process.
example::
with jt.single_process_scope(root=0):
......
@jt.single_process_scope(root=0)
def xxx():
...
"""
def __init__(self, rank=0):
self.rank = rank
def __enter__(self):
global mpi
from jittor.dataset import dataset
self.mpi_backup = mpi
mpi = dataset.mpi = None
def __exit__(self, *exc):
global mpi
from jittor.dataset import dataset
mpi = dataset.mpi = self.mpi_backup
def __call__(self, func):
global mpi
def inner(*args, **kw):
if mpi and mpi.world_rank() != self.rank:
return
with self:
ret = func(*args, **kw)
return ret
return inner
def clean():
import gc
# make sure python do a full collection
@ -177,7 +215,7 @@ def std(x):
matsize *= i
out=(x-x.mean()).sqr().sum()
out=out/(matsize-1)
out=out.sqrt()
out=out.maximum(1e-6).sqrt()
return out
Var.std = std
@ -186,7 +224,7 @@ def norm(x, k, dim):
if k==1:
return x.abs().sum(dim)
if k==2:
return x.sqr().sum(dim).sqrt()
return (x.sqr()).sum(dim).maximum(1e-6).sqrt()
Var.norm = norm
origin_reshape = reshape

View File

@ -15,10 +15,14 @@ __all__ = ['ResNet', 'Resnet18', 'Resnet34', 'Resnet50', 'Resnet101', 'Resnet152
'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
return nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
conv=nn.Conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=dilation, groups=groups, bias=False, dilation=dilation)
jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
return conv
def conv1x1(in_planes, out_planes, stride=1):
return nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
conv=nn.Conv(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
jt.init.relu_invariant_gauss_(conv.weight, mode="fan_out")
return conv
class BasicBlock(nn.Module):
expansion = 1
@ -102,6 +106,7 @@ class ResNet(nn.Module):
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
jt.init.relu_invariant_gauss_(self.conv1.weight, mode="fan_out")
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.Relu()
self.maxpool = nn.Pool(kernel_size=3, stride=2, padding=1, op='maximum')

View File

@ -84,8 +84,11 @@ def cross_entropy_loss(output, target, ignore_index=None):
def mse_loss(output, target):
return (output-target).sqr().mean()
def bce_loss(output, target):
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
def bce_loss(output, target, size_average=True):
if size_average:
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).mean()
else:
return - (target * jt.log(jt.maximum(output, 1e-20)) + (1 - target) * jt.log(jt.maximum(1 - output, 1e-20))).sum()
def l1_loss(output, target):
return (output-target).abs().mean()
@ -105,8 +108,8 @@ class MSELoss(Module):
class BCELoss(Module):
def __init__(self):
pass
def execute(self, output, target):
return bce_loss(output, target)
def execute(self, output, target, size_average=True):
return bce_loss(output, target, size_average)
class L1Loss(Module):
def __init__(self):
@ -118,9 +121,9 @@ class BCEWithLogitsLoss(Module):
def __init__(self):
self.sigmoid = Sigmoid()
self.bce = BCELoss()
def execute(self, output, target):
def execute(self, output, target, size_average=True):
output = self.sigmoid(output)
output = self.bce(output, target)
output = self.bce(output, target, size_average)
return output
def softmax(x, dim = None):
@ -279,9 +282,14 @@ class Conv(Module):
assert in_channels % groups == 0, 'in_channels must be divisible by groups'
assert out_channels % groups == 0, 'out_channels must be divisible by groups'
self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
# self.weight = init.relu_invariant_gauss([out_channels, in_channels//groups, Kh, Kw], dtype="float", mode="fan_out")
self.weight = init.invariant_uniform([out_channels, in_channels//groups, Kh, Kw], dtype="float")
if bias:
self.bias = init.uniform([out_channels], dtype="float", low=-1, high=1)
fan=1
for i in self.weight.shape[1:]:
fan *= i
bound = 1 / math.sqrt(fan)
self.bias = init.uniform([out_channels], dtype="float", low=-bound, high=bound)
else:
self.bias = None

View File

@ -131,6 +131,42 @@ class SGD(Optimizer):
p -= v * lr
p.detach_inplace()
class RMSprop(Optimizer):
""" RMSprop Optimizer.
Args:
params(list): parameters of model.
lr(float): learning rate.
eps(float): term added to the denominator to avoid division by zero, default 1e-8.
alpha(float): smoothing constant, default 0.99.
Example:
optimizer = nn.RMSprop(model.parameters(), lr)
optimizer.step(loss)
"""
def __init__(self, params, lr=1e-2, eps=1e-8, alpha=0.99):
super().__init__(params, lr)
self.eps = eps
self.alpha = alpha
# initialize required arguments for each param_groups
for pg in self.param_groups:
values = pg["values"] = []
for p in pg["params"]:
values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
def step(self, loss):
self.pre_step(loss)
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)
eps = pg.get("eps", self.eps)
alpha = pg.get("alpha", self.alpha)
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
if p.is_stop_grad(): continue
v.assign(alpha * v + (1-alpha) * g * g)
p -= lr * g / (jt.sqrt(v) + eps)
p.detach_inplace()
class Adam(Optimizer):
""" Adam Optimizer.

View File

@ -0,0 +1,50 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
import numpy as np
mpi = jt.compile_extern.mpi
from jittor.dataset.mnist import MNIST
dataloader = MNIST(train=False).set_attrs(batch_size=16)
def val1():
for i, (imgs, labels) in enumerate(dataloader):
assert(imgs.shape[0]==8)
if i == 5:
break
@jt.single_process_scope(rank=0)
def val2():
for i, (imgs, labels) in enumerate(dataloader):
assert(imgs.shape[0]==16)
if i == 5:
break
@unittest.skipIf(mpi is None, "no inside mpirun")
class TestSingleProcessScope(unittest.TestCase):
def test_single_process_scope(self):
val1()
val2()
def run_single_process_scope_test(num_procs, name):
if not jt.compile_extern.inside_mpi():
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v"
print("run cmd:", cmd)
assert os.system(cmd)==0, "run cmd failed: "+cmd
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestSingleProcessScopeEntry(unittest.TestCase):
def test_entry(self):
run_single_process_scope_test(2, "test_single_process_scope")
if __name__ == "__main__":
unittest.main()

View File

@ -41,12 +41,15 @@ static void move_rely(KernelIR* inner_loop, KernelIR* outer_loop, KernelIR* def)
}
}
static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
// sorder: Array that saves the allocation order of "tn"
// sfunc: Array of function names
static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim, vector<vector<int>> &sorder, vector<string> &sfunc) {
LOGvvvv << "tune_atomic" << ir->children;
vector<string> relys;
vector<string> idx_name;
vector<KernelIR*> atomics;
vector<KernelIR*> loops;
vector<int> nrely;
vector<int> order;
int tmp_cnt=0;
for (uint i=0; i<ir->children.size(); i++) {
@ -57,6 +60,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
atomics.clear();
loops.clear();
order.clear();
nrely.clear();
c->dfs([&](unique_ptr<KernelIR>& p) {
auto& code = p->attrs["code"];
@ -71,6 +75,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
loops.push_back(loop);
idx_name.push_back(loop->attrs["lvalue"]);
order.push_back(loops.size()-1);
nrely.push_back(-1);
bool ok = true;
while (1) {
loop = loops.back();
@ -90,6 +95,7 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
loops.push_back(loop2);
idx_name.push_back(loop2->attrs["lvalue"]);
order.push_back(loops.size()-1);
nrely.push_back(-1);
}
// TODO: only support single loop children
if (!ok) continue;
@ -107,12 +113,25 @@ static void tune_atomic(Pass* pass, KernelIR* ir, bool is_cuda, int tdim) {
for (uint l=0;l<order.size();l++)
if (order[l]==sidx) sord=l;
ASSERT(sord != -1);
for (int l=sord;l;l--) order[l]=order[l-1];
for (int l=sord;l;l--){
order[l]=order[l-1];
nrely[l]=nrely[l-1];
}
order[0]=sidx;
nrely[0]=j;
}
}
LOGvvvv << "atomic tuner order" << order;
vector<int> tnorder;
uint si;
for (si=0;si<order.size();si++)
if (nrely[si]!=nrely[0]) break;
for (int j=si-1;j>=0;j--) tnorder.push_back(order[j]);
for (int j=order.size()-1;j>=si;j--) tnorder.push_back(order[j]);
sorder.push_back(tnorder);
sfunc.push_back(ir->attrs["lvalue"]);
// sort loop with order
int count=0;
for (auto j : order) {
@ -199,12 +218,54 @@ void AtomicTunerPass::run() {
if (is_cuda) choice=1;
if (!choice) return;
vector<vector<int>> sorder;
vector<string> sfunc;
for (uint i=0; i<ir->before.size(); i++) {
auto& func_call = ir->before[i];
// TODO: remove this if
if (func_call->get_attr("dtype") != "__global__ void") continue;
tune_atomic(this, func_call.get(), is_cuda, 4);
tune_atomic(this, func_call.get(), is_cuda, 4, sorder, sfunc);
}
// Re-adjust the allocation order of "tn" according to the situation of atomic coverage, preferentially allocate the range not covered by atomic, for example:
// for (op0_index_t id0 = tid0; id0<range0; id0+=tnum0) {
// for (op1_index_t id1 = tid1; id1<range1; id1+=tnum1) {
// for (op2_index_t id2 = tid2; id2<range2; id2+=tnum2) {
// for (op3_index_t id3 = tid3; id3<range3; id3+=tnum3) {
// ...
// }
// }
// atomicAdd(...);
// }
// }
// The allocation order of "tn" will be: tn1, tn0, tn3, tn2
for (uint j=0;j<sfunc.size();j++)
for (uint i=0; i<ir->children.size(); i++) {
auto& func_call = ir->children[i];
int bo=0;
for (uint k=0; k<func_call->children.size(); k++){
auto& save = func_call->children[k];
if (save->has_attr("loop_func") && save->attrs["loop_func"]==sfunc[j]){
bo=1;
break;
}
}
if (!bo) continue;
uint k;
for (k=0; k<func_call->children.size(); k++){
auto& save = func_call->children[k];
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn")==0) break;
}
for (uint l=0;l<sorder[j].size();l++){
for (uint p=0; p<func_call->children.size(); p++){
auto& save = func_call->children[p];
if (save->has_attr("lvalue") && save->attrs["lvalue"].find("tn"+S(sorder[j][l]))==0){
func_call->children[p]->swap(*func_call->children[k++]);
break;
}
}
}
}
ir->remove_all_unused();
}

View File

@ -264,12 +264,9 @@ void ParallelPass::run() {
string nums = rvalues.at(0);
for (int i=1; i<rvalues.size(); i++)
nums+="*"+rvalues[i];
if (fix_thread_num)
new_block.push_back("int thread_num=" + S(thread_num) + ");");
else
new_block.push_back("int thread_num=min(1<<(NanoVector::get_nbits("+nums+")-2)," + S(thread_num) + ");");
new_block.push_back("int thread_num=" + S(thread_num) + ";");
new_block.push_back("int thread_num_left=thread_num;");
for (int j=ncs.size()-1; j>=0; j--) {
auto& rv = rvalues[j];
new_block.push_back("int tn"+S(j)+
@ -344,6 +341,15 @@ void ParallelPass::run() {
new_func_def->insert(0, new_tid_def.children);
new_func_def->swap(*func_def, true);
new_block.swap(*func_call, true);
auto code = func_def->to_string();
bool has_atomic = code.find("atomic") != string::npos;
if (!fix_thread_num) {
if (has_atomic) {
func_call->find_define("thread_num")->attrs["rvalue"] = "min(1<<max((NanoVector::get_nbits(" + nums + "/16)-2),0)," + S(thread_num) + ")";
} else {
func_call->find_define("thread_num")->attrs["rvalue"] = "min(1<<max((NanoVector::get_nbits(" + nums + ")-2),0)," + S(thread_num) + ")";
}
}
}
ir->remove_all_unused();
}