Merge pull request #94 from Jittor/ygy2

Ygy2
This commit is contained in:
yang guo ye 2020-06-05 04:13:01 -05:00 committed by GitHub
commit f5d1caaa71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 127 additions and 6 deletions

View File

@ -24,6 +24,7 @@
jittor.contrib
jittor.dataset
jittor.transform
jittor.mpi
.. toctree::

View File

@ -6,8 +6,9 @@ jittor.models
```eval_rst
.. automodule:: jittor.models
:members:
:imported-members:
:undoc-members:
:exclude-members: ResNet,ShuffleNetV2,SqueezeNet,VGG
```

doc/source/jittor.mpi.md Normal file
View File

@ -0,0 +1,13 @@
jittor.mpi
=====================
This is the API documentation for Jittor's MPI module. You can access it via `from jittor import mpi`.
```eval_rst
.. automodule:: jittor_mpi_core
:members:
:undoc-members:
.. automodule:: jittor_mpi_core.ops
:members:
:undoc-members:
```
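A minimal usage sketch for the bindings documented above, assuming the process is launched with `mpirun` (e.g. `mpirun -np 2 python your_script.py`) and that Jittor was built with MPI support; otherwise the `mpi` module may be unavailable:

```python
# Hedged sketch: query the MPI bindings exposed by jittor_mpi_core.
# Assumes the process was started with mpirun and MPI support is compiled in.
from jittor import mpi

if mpi:
    print("world size:", mpi.world_size())   # number of MPI nodes
    print("world rank:", mpi.world_rank())   # global ID of this node
    print("local rank:", mpi.local_rank())   # ID of this node on its host
```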

View File

@ -29,18 +29,33 @@ extern int mpi_world_rank;
extern int mpi_local_rank;
extern bool inside_mpi;
/**
Return number of MPI nodes.
*/
// @pyjt(world_size)
int _mpi_world_size();
/**
Return global ID of this MPI node.
*/
// @pyjt(world_rank)
int _mpi_world_rank();
/**
Return local ID of this MPI node.
*/
// @pyjt(local_rank)
int _mpi_local_rank();
struct ArrayArgs;
/**
Use jt.Module.mpi_param_broadcast(root=0) to broadcast all parameters of a module from the [root] MPI node to all other MPI nodes.
This operation has no gradient, and the input parameter type is a numpy array.
*/
// @pyjt(broadcast)
void _mpi_broadcast(ArrayArgs&& args, int i);
void _mpi_broadcast(ArrayArgs&& args, int root);
} // jittor
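As the comment above notes, the usual entry point is `Module.mpi_param_broadcast(root=0)`; the raw `broadcast` binding takes a numpy array and a root rank. A hedged sketch, assuming the binding is reachable from Python as `mpi.broadcast`:

```python
# Hedged sketch: broadcast a raw numpy buffer from rank 0 to every rank.
# Assumes @pyjt(broadcast) accepts a numpy array and writes the result
# into that array in place (no gradient is recorded).
import numpy as np
from jittor import mpi

if mpi:
    if mpi.world_rank() == 0:
        buf = np.arange(4, dtype="float32")
    else:
        buf = np.zeros(4, dtype="float32")
    mpi.broadcast(buf, 0)
    print(mpi.world_rank(), buf)   # every rank should print [0. 1. 2. 3.]
```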

View File

@ -15,6 +15,15 @@ struct MpiAllReduceOp : Op {
Var* x, * y;
NanoString op;
/**
Mpi All Reduce Operator uses the operator [op] to reduce variable [x] across all MPI nodes and broadcasts the result to all MPI nodes.
Args:
* x: the variable to be all-reduced.
* op: 'sum' or 'add' sums [x] over all nodes; 'mean' averages it. Default: 'add'.
*/
MpiAllReduceOp(Var* x, NanoString op=ns_add);
void infer_shape() override;
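A hedged usage sketch; it assumes the op is exported, following Jittor's naming convention, as `mpi.ops.mpi_all_reduce`:

```python
# Hedged sketch: average a per-rank tensor across all MPI nodes.
import jittor as jt
from jittor import mpi

if mpi:
    x = jt.array([1.0, 2.0, 3.0]) * mpi.world_rank()
    y = mpi.ops.mpi_all_reduce(x, op="mean")   # every rank gets the same average
    print(y.numpy())
```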

View File

@ -15,6 +15,15 @@ struct MpiBroadcastOp : Op {
Var* x, * y;
int root;
/**
Mpi Broadcast Operator broadcasts variable [x] from the [root] MPI node to all MPI nodes.
Args:
* x: the variable to be broadcast.
* root: ID of the MPI node to broadcast from. Default: 0.
*/
MpiBroadcastOp(Var* x, int root=0);
void infer_shape() override;
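A hedged usage sketch, assuming the op is exported as `mpi.ops.mpi_broadcast`:

```python
# Hedged sketch: every rank ends up with rank 0's values.
import jittor as jt
from jittor import mpi

if mpi:
    x = jt.array([1.0, 2.0]) * mpi.world_rank()
    y = mpi.ops.mpi_broadcast(x, root=0)
    print(y.numpy())   # expected: [0. 0.] on every rank
```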

View File

@ -16,6 +16,16 @@ struct MpiReduceOp : Op {
NanoString op;
int root;
/**
Mpi Reduce Operator uses the operator [op] to reduce variable [x] across all MPI nodes and sends the result to the [root] MPI node.
Args:
* x: the variable to be reduced.
* op: 'sum' or 'add' sums [x] over all nodes; 'mean' averages it. Default: 'add'.
* root: ID of the MPI node that receives the result. Default: 0.
*/
MpiReduceOp(Var* x, NanoString op=ns_add, int root=0);
void infer_shape() override;
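A hedged usage sketch, assuming the op is exported as `mpi.ops.mpi_reduce`:

```python
# Hedged sketch: rank 0 receives the element-wise sum over all ranks.
import jittor as jt
from jittor import mpi

if mpi:
    x = jt.array([1.0, 1.0]) * (mpi.world_rank() + 1)
    y = mpi.ops.mpi_reduce(x, op="add", root=0)
    if mpi.world_rank() == 0:
        print(y.numpy())   # with 2 ranks: [3. 3.]
```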

View File

@ -44,11 +44,11 @@ int _mpi_local_rank() {
return mpi_local_rank;
}
void _mpi_broadcast(ArrayArgs&& args, int i) {
void _mpi_broadcast(ArrayArgs&& args, int root) {
int64 size = args.dtype.dsize();
for (auto j : args.shape)
size *= j;
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, root, MPI_COMM_WORLD));
}
static uint64_t getHostHash(const char* string) {

View File

@ -380,7 +380,7 @@ def setup_mpi():
# share the 'environ' symbol.
mpi = compile_custom_ops(mpi_src_files,
extra_flags=f" {mpi_flags} ", return_module=True,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW, gen_name_="jittor_mpi_core")
mpi_ops = mpi.ops
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))

View File

@ -561,7 +561,8 @@ def compile_custom_ops(
filenames,
extra_flags="",
return_module=False,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
gen_name_ = ""):
"""Compile custom ops
filenames: path of op source files, filenames must be
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
@ -599,6 +600,8 @@ def compile_custom_ops(
for name in srcs:
assert name in headers, f"Header of op {name} not found"
gen_name = "gen_ops_" + "_".join(headers.keys())
if gen_name_ != "":
gen_name = gen_name_
if len(gen_name) > 100:
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
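The new `gen_name_` argument lets a caller pin the generated module's name instead of using the default concatenation of op names; this is how `setup_mpi` above obtains a stable `jittor_mpi_core` module for the Sphinx `automodule` directives to reference. A standalone sketch of that naming rule (`pick_gen_name` is a hypothetical helper, not part of Jittor):

```python
# Hypothetical helper reproducing the naming logic of compile_custom_ops.
def pick_gen_name(op_names, gen_name_=""):
    gen_name = "gen_ops_" + "_".join(op_names)
    if gen_name_ != "":
        gen_name = gen_name_          # explicit override, e.g. "jittor_mpi_core"
    if len(gen_name) > 100:           # keep very long names manageable
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
    return gen_name

print(pick_gen_name(["mpi_all_reduce", "mpi_broadcast"]))              # gen_ops_mpi_all_reduce_mpi_broadcast
print(pick_gen_name(["mpi_all_reduce"], gen_name_="jittor_mpi_core"))  # jittor_mpi_core
```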

View File

@ -13,6 +13,19 @@ import jittor.nn as nn
__all__ = ['AlexNet', 'alexnet']
class AlexNet(nn.Module):
""" AlexNet model architecture.
Args:
* num_classes: Number of classes. Default: 1000.
Example::
model = jittor.models.AlexNet(500)
x = jittor.random([10,3,224,224])
y = model(x) # [10, 500]
"""
def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()

View File

@ -16,6 +16,15 @@ def googlenet(**kwargs):
return GoogLeNet(**kwargs)
class GoogLeNet(nn.Module):
""" GoogLeNet model architecture.
Args:
* num_classes: Number of classes. Default: 1000.
* aux_logits: If True, adds an auxiliary branch that can improve training. Default: True.
* init_weights: Default: True.
* blocks: List of three blocks, [conv_block, inception_block, inception_aux_block]. If None, will use [BasicConv2d, Inception, InceptionAux] instead. Default: None.
"""
def __init__(self, num_classes=1000, aux_logits=True, init_weights=True, blocks=None):
super(GoogLeNet, self).__init__()
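A hedged usage sketch, mirroring the AlexNet example added in this commit; it assumes GoogLeNet is reachable from jittor.models and that, with aux_logits=False, the forward pass returns a single tensor:

```python
# Hedged sketch: GoogLeNet without the auxiliary classifier.
import jittor as jt
from jittor import models

model = models.GoogLeNet(num_classes=10, aux_logits=False)
x = jt.random([2, 3, 224, 224])   # assuming NCHW input
y = model(x)                      # expected shape: [2, 10]
```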

View File

@ -7,7 +7,16 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
return Inception3(**kwargs)
class Inception3(nn.Module):
""" Inceptionv3 model architecture.
Args:
* num_classes: Number of classes. Default: 1000.
* aux_logits: If True, adds an auxiliary branch that can improve training. Default: True.
* inception_blocks: List of seven blocks, [conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux]. If None, will use [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] instead. Default: None.
* init_weights: Default: True.
"""
def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None, init_weights=True):
super(Inception3, self).__init__()
if (inception_blocks is None):

View File

@ -47,6 +47,14 @@ def _get_depths(alpha):
return [_round_to_multiple_of((depth * alpha), 8) for depth in depths]
class MNASNet(nn.Module):
""" MNASNet model architecture. version=2.
Args:
* alpha: Depth multiplier.
* num_classes: Number of classes. Default: 1000.
* dropout: Dropout probability of the dropout layer. Default: 0.2.
"""
_version = 2
def __init__(self, alpha, num_classes=1000, dropout=0.2):
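A hedged usage sketch, assuming MNASNet is reachable from jittor.models; alpha=1.0 keeps the reference channel widths, while smaller values shrink them:

```python
# Hedged sketch: MNASNet with the reference depth multiplier.
import jittor as jt
from jittor import models

model = models.MNASNet(alpha=1.0, num_classes=10)
x = jt.random([2, 3, 224, 224])   # assuming NCHW input
y = model(x)                      # expected shape: [2, 10]
```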

View File

@ -48,6 +48,17 @@ class InvertedResidual(nn.Module):
return self.conv(x)
class MobileNetV2(nn.Module):
""" MobileNetV2 model architecture.
Args:
* num_classes: Number of classes. Default: 1000.
* width_mult: Width multiplier - adjusts number of channels in each layer by this amount. Default: 1.0.
* init_weights: Default: True.
* inverted_residual_setting: Network structure, a list of [expansion factor, output channels, repeat count, stride] entries. If None, the default MobileNetV2 setting is used.
* round_nearest: Round the number of channels in each layer to be a multiple of this number. Set to 1 to turn off rounding. Default: 8.
* block: Module specifying inverted residual building block for mobilenet. If None, use InvertedResidual instead. Default: None.
"""
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None):
super(MobileNetV2, self).__init__()
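A hedged usage sketch, assuming MobileNetV2 is reachable from jittor.models; width_mult=0.5 halves the channel counts (rounded to a multiple of round_nearest):

```python
# Hedged sketch: a slimmed-down MobileNetV2.
import jittor as jt
from jittor import models

model = models.MobileNetV2(num_classes=10, width_mult=0.5)
x = jt.random([2, 3, 224, 224])   # assuming NCHW input
y = model(x)                      # expected shape: [2, 10]
```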

View File

@ -167,6 +167,16 @@ def Resnet50(**kwargs):
resnet50 = Resnet50
def Resnet101(**kwargs):
"""
ResNet-101 model architecture.
Example::
model = jittor.models.Resnet101()
x = jittor.random([10,3,224,224])
y = model(x) # [10, 1000]
"""
return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
resnet101 = Resnet101