mirror of https://github.com/Jittor/Jittor
commit f5d1caaa71
@ -24,6 +24,7 @@
   jittor.contrib
   jittor.dataset
   jittor.transform
   jittor.mpi

.. toctree::
@ -6,8 +6,9 @@ jittor.models
```eval_rst
.. automodule:: jittor.models
    :members:
    :imported-members:
    :undoc-members:
    :exclude-members: ResNet,ShuffleNetV2,SqueezeNet,VGG
```
@ -0,0 +1,13 @@
jittor.mpi
=====================

This is the API documentation for Jittor's MPI module. You can access it via `from jittor import mpi`.

```eval_rst
.. automodule:: jittor_mpi_core
    :members:
    :undoc-members:
.. automodule:: jittor_mpi_core.ops
    :members:
    :undoc-members:
```
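A minimal usage sketch (not part of this diff), assuming the `@pyjt(world_size/world_rank/local_rank)` bindings added below are exposed on the `mpi` module, and that `mpi` is falsy when the script is not launched via MPI:

```python
# Hedged sketch; run with e.g. `mpirun -np 4 python demo.py`.
import jittor as jt
from jittor import mpi

if mpi:
    print("world size:", mpi.world_size())   # total number of MPI nodes
    print("world rank:", mpi.world_rank())   # global ID of this node
    print("local rank:", mpi.local_rank())   # ID of this node on its host
```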
@ -29,18 +29,33 @@ extern int mpi_world_rank;
extern int mpi_local_rank;
extern bool inside_mpi;

/**
Return the number of MPI nodes.
*/
// @pyjt(world_size)
int _mpi_world_size();

/**
Return the global ID of this MPI node.
*/
// @pyjt(world_rank)
int _mpi_world_rank();

/**
Return the local ID of this MPI node.
*/
// @pyjt(local_rank)
int _mpi_local_rank();

struct ArrayArgs;

/**
Use jt.Module.mpi_param_broadcast(root=0) to broadcast all module parameters of this module from the [root] MPI node to all MPI nodes.

This operation has no gradient, and the input parameter type is a numpy array.
*/
// @pyjt(broadcast)
-void _mpi_broadcast(ArrayArgs&& args, int i);
+void _mpi_broadcast(ArrayArgs&& args, int root);

} // jittor
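A hedged illustration of the broadcast helper mentioned in the comment above (not part of this diff): synchronizing parameters once after model construction makes every MPI node start from the same weights. `Resnet101` is reused from the model examples later in this commit; the truthiness check on `jt.mpi` is an assumption.

```python
# Hedged sketch: parameters on the root node are copied to all other nodes, no gradient.
import jittor as jt

model = jt.models.Resnet101()
if jt.mpi:                             # only meaningful when launched under MPI
    model.mpi_param_broadcast(root=0)  # node 0's parameters overwrite the others
```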
@ -15,6 +15,15 @@ struct MpiAllReduceOp : Op {
Var* x, * y;
NanoString op;

/**
Mpi All Reduce Operator uses the operator [op] to reduce variable [x] across all MPI nodes and broadcasts the result to all MPI nodes.

Args:

* x: variable to be all-reduced.
* op: 'sum' or 'add' means sum over all [x]; 'mean' means average over all [x]. Default: 'add'.
*/
MpiAllReduceOp(Var* x, NanoString op=ns_add);
void infer_shape() override;
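A hedged usage sketch (not part of this diff), assuming the op is exposed through the generated `jittor_mpi_core.ops` module as `jt.mpi.ops.mpi_all_reduce`:

```python
# Hedged sketch: every node contributes its x and every node receives the sum.
import jittor as jt

x = jt.array([1.0, 2.0, 3.0]) * (jt.mpi.world_rank() + 1)
y = jt.mpi.ops.mpi_all_reduce(x, op='add')   # identical result on every node
```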
@ -15,6 +15,15 @@ struct MpiBroadcastOp : Op {
Var* x, * y;
int root;

/**
Mpi Broadcast Operator broadcasts variable [x] from the [root] MPI node to all MPI nodes.

Args:

* x: variable to be broadcast.
* root: ID of the MPI node to broadcast from. Default: 0.
*/
MpiBroadcastOp(Var* x, int root=0);
void infer_shape() override;
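A hedged usage sketch (not part of this diff), again assuming the generated binding is exposed as `jt.mpi.ops.mpi_broadcast`:

```python
# Hedged sketch: node 0's value of x overwrites x on every other node.
import jittor as jt

x = jt.array([float(jt.mpi.world_rank())])
y = jt.mpi.ops.mpi_broadcast(x, root=0)   # equals node 0's x on all nodes
```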
@ -16,6 +16,16 @@ struct MpiReduceOp : Op {
NanoString op;
int root;

/**
Mpi Reduce Operator uses the operator [op] to reduce variable [x] across all MPI nodes and sends the result to the [root] MPI node.

Args:

* x: variable to be reduced.
* op: 'sum' or 'add' means sum over all [x]; 'mean' means average over all [x]. Default: 'add'.
* root: ID of the MPI node that receives the result. Default: 0.
*/
MpiReduceOp(Var* x, NanoString op=ns_add, int root=0);
void infer_shape() override;
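A hedged usage sketch (not part of this diff), assuming the generated binding is exposed as `jt.mpi.ops.mpi_reduce`; the documentation above only describes the root node's output, so only that value is used here:

```python
# Hedged sketch: the reduced value is delivered to the root node.
import jittor as jt

x = jt.array([1.0]) * (jt.mpi.world_rank() + 1)
y = jt.mpi.ops.mpi_reduce(x, op='add', root=0)
if jt.mpi.world_rank() == 0:
    print(y)   # sum of x over all nodes
```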
@ -44,11 +44,11 @@ int _mpi_local_rank() {
    return mpi_local_rank;
}

-void _mpi_broadcast(ArrayArgs&& args, int i) {
+void _mpi_broadcast(ArrayArgs&& args, int root) {
    // total byte count = element size * product of shape dimensions
    int64 size = args.dtype.dsize();
    for (auto j : args.shape)
        size *= j;
-    MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
+    MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, root, MPI_COMM_WORLD));
}

static uint64_t getHostHash(const char* string) {
@ -380,7 +380,7 @@ def setup_mpi():
    # share the 'environ' symbol.
    mpi = compile_custom_ops(mpi_src_files,
        extra_flags=f" {mpi_flags} ", return_module=True,
-        dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
+        dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW, gen_name_="jittor_mpi_core")
    mpi_ops = mpi.ops
    LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
    LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))
@ -561,7 +561,8 @@ def compile_custom_ops(
        filenames,
        extra_flags="",
        return_module=False,
-        dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
+        dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
+        gen_name_ = ""):
    """Compile custom ops
    filenames: path of op source files, filenames must be
        pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
@ -599,6 +600,8 @@ def compile_custom_ops(
    for name in srcs:
        assert name in headers, f"Header of op {name} not found"
    gen_name = "gen_ops_" + "_".join(headers.keys())
    if gen_name_ != "":
        gen_name = gen_name_
    if len(gen_name) > 100:
        gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))
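A hedged sketch of how the new `gen_name_` argument changes the generated module name (not part of this diff): without it the module is named `gen_ops_<joined op names>`, while `setup_mpi()` above pins it to `jittor_mpi_core`. The source file names below are hypothetical placeholders, and reaching the function as `jittor.compiler.compile_custom_ops` is an assumption.

```python
# Hedged sketch; my_example_op.cc / .h are hypothetical files.
from jittor import compiler

mod = compiler.compile_custom_ops(
    ["my_example_op.cc", "my_example_op.h"],   # sources must come in .cc/.h pairs
    return_module=True,                        # get the generated module back
    gen_name_="my_example_core",               # overrides the "gen_ops_..." naming
)
print(mod.ops.__dict__.keys())                 # the compiled ops live under .ops
```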
@ -13,6 +13,19 @@ import jittor.nn as nn
__all__ = ['AlexNet', 'alexnet']

class AlexNet(nn.Module):
    """ AlexNet model architecture.

    Args:

    * num_classes: Number of classes. Default: 1000.

    Example::

        model = jittor.models.AlexNet(500)
        x = jittor.random([10, 3, 224, 224])
        y = model(x) # [10, 500]

    """

    def __init__(self, num_classes=1000):
        super(AlexNet, self).__init__()
@ -16,6 +16,15 @@ def googlenet(**kwargs):
    return GoogLeNet(**kwargs)

class GoogLeNet(nn.Module):
    """ GoogLeNet model architecture.

    Args:

    * num_classes: Number of classes. Default: 1000.
    * aux_logits: If True, add an auxiliary branch that can improve training. Default: True.
    * init_weights: Default: True.
    * blocks: List of three blocks, [conv_block, inception_block, inception_aux_block]. If None, will use [BasicConv2d, Inception, InceptionAux] instead. Default: None.
    """

    def __init__(self, num_classes=1000, aux_logits=True, init_weights=True, blocks=None):
        super(GoogLeNet, self).__init__()
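A hedged usage sketch (not part of this diff), following the pattern of the AlexNet example in this commit; disabling the auxiliary branch to obtain a single output tensor, and the expected output shape, are assumptions carried over from the reference implementation.

```python
# Hedged sketch.
import jittor as jt

model = jt.models.googlenet(num_classes=500, aux_logits=False)
x = jt.random([10, 3, 224, 224])
y = model(x)   # expected: [10, 500]
```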
@ -7,7 +7,16 @@ def inception_v3(pretrained=False, progress=True, **kwargs):
    return Inception3(**kwargs)

class Inception3(nn.Module):
    """ Inception v3 model architecture.

    Args:

    * num_classes: Number of classes. Default: 1000.
    * aux_logits: If True, add an auxiliary branch that can improve training. Default: True.
    * inception_blocks: List of seven blocks, [conv_block, inception_a, inception_b, inception_c, inception_d, inception_e, inception_aux]. If None, will use [BasicConv2d, InceptionA, InceptionB, InceptionC, InceptionD, InceptionE, InceptionAux] instead. Default: None.
    * init_weights: Default: True.
    """

    def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None, init_weights=True):
        super(Inception3, self).__init__()
        if (inception_blocks is None):
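A hedged usage sketch (not part of this diff); the 299x299 input size is the conventional one for Inception v3 and is an assumption here, as is disabling the auxiliary branch to get a single output tensor.

```python
# Hedged sketch.
import jittor as jt

model = jt.models.inception_v3(num_classes=500, aux_logits=False)
x = jt.random([10, 3, 299, 299])
y = model(x)   # expected: [10, 500]
```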
@ -47,6 +47,14 @@ def _get_depths(alpha):
    return [_round_to_multiple_of((depth * alpha), 8) for depth in depths]

class MNASNet(nn.Module):
    """ MNASNet model architecture. version=2.

    Args:

    * alpha: Depth multiplier.
    * num_classes: Number of classes. Default: 1000.
    * dropout: Dropout probability of the dropout layer. Default: 0.2.
    """
    _version = 2

    def __init__(self, alpha, num_classes=1000, dropout=0.2):
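A hedged usage sketch (not part of this diff); `alpha` is the required depth multiplier described above, and exposing `MNASNet` directly under `jittor.models` is an assumption.

```python
# Hedged sketch.
import jittor as jt

model = jt.models.MNASNet(alpha=1.0, num_classes=500, dropout=0.2)
x = jt.random([10, 3, 224, 224])
y = model(x)   # expected: [10, 500]
```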
@ -48,6 +48,17 @@ class InvertedResidual(nn.Module):
        return self.conv(x)

class MobileNetV2(nn.Module):
    """ MobileNetV2 model architecture.

    Args:

    * num_classes: Number of classes. Default: 1000.
    * width_mult: Width multiplier - adjusts the number of channels in each layer by this amount. Default: 1.0.
    * inverted_residual_setting: Network structure. Default: None.
    * round_nearest: Round the number of channels in each layer to be a multiple of this number. Set to 1 to turn off rounding. Default: 8.
    * block: Module specifying the inverted residual building block for MobileNet. If None, use InvertedResidual instead. Default: None.
    """

    def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8, block=None):
        super(MobileNetV2, self).__init__()
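A hedged usage sketch (not part of this diff), mirroring the other model examples in this commit; exposing `MobileNetV2` directly under `jittor.models` is an assumption.

```python
# Hedged sketch.
import jittor as jt

model = jt.models.MobileNetV2(num_classes=500, width_mult=1.0)
x = jt.random([10, 3, 224, 224])
y = model(x)   # expected: [10, 500]
```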
@ -167,6 +167,16 @@ def Resnet50(**kwargs):
resnet50 = Resnet50

def Resnet101(**kwargs):
    """
    ResNet-101 model architecture.

    Example::

        model = jittor.models.Resnet101()
        x = jittor.random([10, 3, 224, 224])
        y = model(x) # [10, 1000]

    """
    return _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
resnet101 = Resnet101