mpi dataset

Dun Liang 2020-04-09 22:50:46 +08:00
parent 7d2e1b73fa
commit 69e7d367cd
4 changed files with 61 additions and 13 deletions

View File

@@ -37,4 +37,9 @@ int _mpi_world_rank();
// @pyjt(local_rank)
int _mpi_local_rank();
struct ArrayArgs;
// @pyjt(broadcast)
void _mpi_broadcast(ArrayArgs&& args, int root);
} // jittor
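
The @pyjt(broadcast) annotation exposes this function to Python as jittor.compile_extern.mpi.broadcast. A minimal usage sketch, mirroring the test added later in this commit (only meaningful when launched under mpirun with two or more processes):

    import jittor as jt
    import numpy as np

    mpi = jt.compile_extern.mpi
    a = np.zeros(100) + mpi.world_rank()  # each rank starts with a different buffer
    mpi.broadcast(a, 0)                   # overwrite every rank's buffer with rank 0's
    assert (a == 0).all()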

View File

@@ -12,6 +12,7 @@
#include "mpi_warper.h"
#include "common.h"
#include "ops/array_op.h"
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
@@ -44,7 +45,12 @@ int _mpi_local_rank() {
    return mpi_local_rank;
}
void _mpi_broadcast(ArrayArgs&& args, int root) {
    // total size in bytes: bytes per element times the number of elements
    int64 size = args.dtype.dsize();
    for (auto d : args.shape)
        size *= d;
    // broadcast the raw buffer from rank `root` to all other ranks
    MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, root, MPI_COMM_WORLD));
}
static uint64_t getHostHash(const char* string) {
    // Based on DJB2, result = result * 33 + char
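
The byte count handed to MPI_Bcast above is the element size times the product of the shape, which is exactly numpy's nbytes. A small illustrative check of the same arithmetic in Python (plain numpy here, not jittor internals):

    import numpy as np

    a = np.zeros((4, 25), dtype=np.int32)
    size = a.dtype.itemsize         # dsize(): bytes per element (4 for int32)
    for d in a.shape:               # multiply by each dimension
        size *= d
    assert size == a.nbytes == 400  # the count passed to MPI_Bcast as MPI_BYTE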

View File

@@ -23,6 +23,7 @@ import jittor as jt
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0)
mpi = jt.compile_extern.mpi
class Worker:
    def __init__(self, target, args, buffer_size):
@@ -187,6 +188,17 @@ class Dataset(object):
        else:
            index_list = get_random_list(self.total_len)
        # scatter index_list across all MPI processes
        # scatter rule (one contiguous chunk per rank):
        #   [000000000111111111222222222]
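        # e.g. total_len=10, world_size=3 -> new_len=(10-1)//3+1=4:
        #   rank 0 -> index_list[0:4], rank 1 -> index_list[4:8],
        #   rank 2 -> index_list[8:10] (the last rank may get a shorter chunk)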
        if mpi:
            index_list = np.int32(index_list)
            # broadcast rank 0's permutation so every rank shares the same order
            mpi.broadcast(index_list, 0)
            new_len = (self.total_len - 1) // mpi.world_size() + 1
            offset = mpi.world_rank() * new_len
            index_list = index_list[offset:offset+new_len]
            self.total_len = len(index_list)
        self.batch_len = len(self)
        if "batch_len" in os.environ:
            self.batch_len = int(os.environ["batch_len"])

View File

@@ -10,24 +10,49 @@ import unittest
import os, sys
import jittor as jt
import numpy as np
mpi = jt.compile_extern.mpi
def main():
    print("test mpi_test")
    assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
    if jt.compile_extern.nccl_ops:
        print("test test_with_mpi")
        with jt.flag_scope(use_cuda=1):
            assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
@unittest.skipIf(jt.compile_extern.has_mpi is None, "no mpi found")
@unittest.skipIf(mpi is None, "not inside mpirun")
class TestMpi(unittest.TestCase):
    def test_mpi_test_op(self):
        assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
    @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl")
    @jt.flag_scope(use_cuda=1)
    def test_nccl_with_mpi(self):
        assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123

    def test_mpi_broadcast(self):
        a = np.zeros(100) + mpi.world_rank()  # each rank starts with its own value
        mpi.broadcast(a, 0)
        assert (a == 0).all()  # every rank now holds rank 0's zeros
    def test_mpi_dataset(self):
        from jittor.dataset.dataset import Dataset
        class ToyDataset(Dataset):
            def __init__(self):
                super().__init__()
                self.set_attrs(total_len=1024)

            def __getitem__(self, index):
                return index, index*index

        toy = ToyDataset()
        # each rank iterates only its own contiguous chunk of the index list
        offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()
        for i, (a, b) in enumerate(toy):
            assert (a.data*a.data == b.data).all()
            c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
            assert (c == a.data).all()
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiEntry(unittest.TestCase):
    def test_entry(self):
        if not jt.compile_extern.inside_mpi():
            # not yet under MPI: re-launch this module under mpirun
            mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
            cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
            print("run cmd:", cmd)
            assert os.system(cmd)==0, "run cmd failed: "+cmd
        else:
            main()

if __name__ == "__main__":
    unittest.main()
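
Note the entry-point trick in TestMpiEntry: run this module outside of mpirun and it re-launches itself as mpirun -np 2 {sys.executable} -m jittor.test.test_mpi, so the MPI-only tests always execute inside an MPI world, while a plain invocation still exercises the whole suite.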