MPI ops documents

This commit is contained in:
guoye 2020-06-02 16:14:29 +08:00
parent 7db591ef0a
commit 626ca2c272
9 changed files with 64 additions and 5 deletions

View File

@ -24,6 +24,7 @@
jittor.contrib
jittor.dataset
jittor.transform
jittor.compile_extern.mpi
.. toctree::

View File

@ -0,0 +1,13 @@
jittor.compile_extern.mpi
=========================
这里是Jittor的MPI模块的API文档，您可以通过`from jittor.compile_extern import mpi`来获取该模块。
```eval_rst
.. automodule:: jittor_mpi_core
:members:
:undoc-members:
.. automodule:: jittor_mpi_core.ops
:members:
:undoc-members:
```

View File

@ -29,18 +29,32 @@ extern int mpi_world_rank;
extern int mpi_local_rank;
extern bool inside_mpi;
/**
Return number of MPI nodes.
*/
// @pyjt(world_size)
int _mpi_world_size();
/**
Return global ID of this MPI node.
*/
// @pyjt(world_rank)
int _mpi_world_rank();
/**
Return local ID of this MPI node.
*/
// @pyjt(local_rank)
int _mpi_local_rank();
struct ArrayArgs;
/**
Use jt.Module.mpi_param_broadcast(root=0) to broadcast all parameters of this module from the [root] MPI node to all MPI nodes.
*/
// @pyjt(broadcast)
void _mpi_broadcast(ArrayArgs&& args, int i);
void _mpi_broadcast(ArrayArgs&& args, int root);
} // jittor

View File

@ -15,6 +15,15 @@ struct MpiAllReduceOp : Op {
Var* x, * y;
NanoString op;
/**
Mpi All Reduce Operator uses the operator [op] to reduce variable [x] in all MPI nodes and broadcast to all MPI nodes.
Args:
* x: variable to be all reduced.
* op: 'sum' or 'add' means sum all [x], 'mean' means average all [x].
*/
MpiAllReduceOp(Var* x, NanoString op=ns_add);
void infer_shape() override;

View File

@ -15,6 +15,15 @@ struct MpiBroadcastOp : Op {
Var* x, * y;
int root;
/**
Mpi Broadcast Operator broadcasts variable [x] from the [root] MPI node to all MPI nodes.
Args:
* x: variable to be broadcast.
* root: ID of the MPI node that serves as the source of the broadcast.
*/
MpiBroadcastOp(Var* x, int root=0);
void infer_shape() override;

View File

@ -16,6 +16,16 @@ struct MpiReduceOp : Op {
NanoString op;
int root;
/**
Mpi Reduce Operator uses the operator [op] to reduce variable [x] in all MPI nodes and send to the [root] MPI node.
Args:
* x: variable to be reduced.
* op: 'sum' or 'add' means sum all [x], 'mean' means average all [x].
* root: ID of MPI node to output.
*/
MpiReduceOp(Var* x, NanoString op=ns_add, int root=0);
void infer_shape() override;

View File

@ -44,11 +44,11 @@ int _mpi_local_rank() {
return mpi_local_rank;
}
void _mpi_broadcast(ArrayArgs&& args, int i) {
void _mpi_broadcast(ArrayArgs&& args, int root) {
int64 size = args.dtype.dsize();
for (auto j : args.shape)
size *= j;
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, root, MPI_COMM_WORLD));
}
static uint64_t getHostHash(const char* string) {

View File

@ -380,7 +380,7 @@ def setup_mpi():
# share the 'environ' symbol.
mpi = compile_custom_ops(mpi_src_files,
extra_flags=f" {mpi_flags} ", return_module=True,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW)
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW, gen_name_="jittor_mpi_core")
mpi_ops = mpi.ops
LOG.vv("Get mpi: "+str(mpi.__dict__.keys()))
LOG.vv("Get mpi_ops: "+str(mpi_ops.__dict__.keys()))

View File

@ -561,7 +561,8 @@ def compile_custom_ops(
filenames,
extra_flags="",
return_module=False,
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND):
dlopen_flags=os.RTLD_GLOBAL | os.RTLD_NOW | os.RTLD_DEEPBIND,
gen_name_ = ""):
"""Compile custom ops
filenames: path of op source files, filenames must be
pairs of xxx_xxx_op.cc and xxx_xxx_op.h, and the
@ -599,6 +600,8 @@ def compile_custom_ops(
for name in srcs:
assert name in headers, f"Header of op {name} not found"
gen_name = "gen_ops_" + "_".join(headers.keys())
if gen_name_ != "":
gen_name = gen_name_
if len(gen_name) > 100:
gen_name = gen_name[:80] + "___hash" + str(hash(gen_name))