mirror of https://github.com/Jittor/Jittor
mpi dataset
This commit is contained in:
parent
7d2e1b73fa
commit
69e7d367cd
|
@ -37,4 +37,9 @@ int _mpi_world_rank();
|
|||
// @pyjt(local_rank)
|
||||
int _mpi_local_rank();
|
||||
|
||||
struct ArrayArgs;
|
||||
|
||||
// @pyjt(broadcast)
|
||||
void _mpi_broadcast(ArrayArgs&& args, int i);
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "mpi_warper.h"
|
||||
#include "common.h"
|
||||
#include "ops/array_op.h"
|
||||
|
||||
char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
|
||||
|
||||
|
@ -44,7 +45,12 @@ int _mpi_local_rank() {
|
|||
return mpi_local_rank;
|
||||
}
|
||||
|
||||
|
||||
/**
 * Broadcast the raw bytes of an array buffer from one MPI rank to every rank
 * in MPI_COMM_WORLD. The buffer behind args.ptr is the send buffer on the
 * root and is overwritten in place on all other ranks.
 *
 * @param args  array descriptor; args.ptr is the buffer, and its length in
 *              bytes is args.dtype.dsize() times the product of args.shape.
 * @param i     rank of the broadcast root (the sender).
 */
void _mpi_broadcast(ArrayArgs&& args, int i) {
    // Payload length in bytes: element size multiplied by every dimension.
    int64 size = args.dtype.dsize();
    // The loop variable is deliberately not called `i`: shadowing the
    // parameter previously hid the fact that the root rank was never used.
    for (auto dim : args.shape)
        size *= dim;
    // Honour the caller-supplied root instead of a hard-coded rank 0.
    // NOTE: MPI_Bcast takes an `int` count, so `size` is narrowed here.
    MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, i, MPI_COMM_WORLD));
}
|
||||
|
||||
static uint64_t getHostHash(const char* string) {
|
||||
// Based on DJB2, result = result * 33 + char
|
||||
|
|
|
@ -23,6 +23,7 @@ import jittor as jt
|
|||
|
||||
# Dataset directory under the current user's home: ~/.cache/jittor/dataset.
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
|
||||
# Taken from the "mp_log_v" environment variable. Note the mixed type:
# a str when the variable is set, the int default 0 when it is not.
mp_log_v = os.environ.get("mp_log_v", 0)
|
||||
# MPI handle exposed by jt.compile_extern; code further down checks its
# truthiness (`if mpi:`) before calling into it.
mpi = jt.compile_extern.mpi
|
||||
|
||||
class Worker:
|
||||
def __init__(self, target, args, buffer_size):
|
||||
|
@ -186,6 +187,17 @@ class Dataset(object):
|
|||
index_list = get_order_list(self.total_len)
|
||||
else:
|
||||
index_list = get_random_list(self.total_len)
|
||||
|
||||
# scatter index_list across all mpi processes
|
||||
# scatter rule:
|
||||
# [000000000111111111222222222]
|
||||
if mpi:
|
||||
index_list = np.int32(index_list)
|
||||
mpi.broadcast(index_list, 0)
|
||||
new_len = (self.total_len - 1) // mpi.world_size() + 1
|
||||
offset = mpi.world_rank() * new_len
|
||||
index_list = index_list[offset:offset+new_len]
|
||||
self.total_len = len(index_list)
|
||||
|
||||
self.batch_len = len(self)
|
||||
if "batch_len" in os.environ:
|
||||
|
|
|
@ -10,24 +10,49 @@ import unittest
|
|||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
mpi = jt.compile_extern.mpi
|
||||
|
||||
def main():
    """Check the mpi_test op and, when NCCL ops are present, the nccl_test op.

    Both ops are expected to yield the value 123. The NCCL check runs inside
    a ``use_cuda=1`` flag scope and is skipped when
    ``jt.compile_extern.nccl_ops`` is falsy.
    """
    print("test mpi_test")
    assert jt.compile_extern.mpi_ops.mpi_test("").data == 123

    # Nothing more to verify when the NCCL ops are unavailable.
    if not jt.compile_extern.nccl_ops:
        return

    print("test test_with_mpi")
    with jt.flag_scope(use_cuda=1):
        assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.has_mpi is None, "no mpi found")
|
||||
@unittest.skipIf(mpi is None, "no inside mpirun")
|
||||
class TestMpi(unittest.TestCase):
|
||||
def test(self):
|
||||
def test_mpi_test_op(self):
    """The mpi_test op, called with an empty string, must yield 123."""
    result = jt.compile_extern.mpi_ops.mpi_test("")
    assert result.data == 123
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.nccl_ops is None, "no inccl")
@jt.flag_scope(use_cuda=1)
def test_nccl_with_mpi(self):
    """The nccl_test op must yield 123; skipped when nccl_ops is None."""
    result = jt.compile_extern.nccl_ops.nccl_test("test_with_mpi")
    assert result.data == 123
|
||||
|
||||
def test_mpi_broadcast(self):
    """After a broadcast with root 0 every rank must hold rank 0's zeros."""
    # Fill the buffer with this process's rank so that ranks start out
    # with different contents.
    buf = np.zeros(100)
    buf = buf + mpi.world_rank()
    mpi.broadcast(buf, 0)
    everything_zero = (buf == 0).all()
    assert everything_zero
|
||||
|
||||
def test_mpi_dataset(self):
    """Each rank must iterate over its own contiguous slice of the indices.

    A 1024-item toy dataset returns ``(index, index*index)``. Every batch is
    checked for the square relation and for containing exactly the
    consecutive indices that start at this rank's offset.
    """
    from jittor.dataset.dataset import Dataset

    class ToyDataset(Dataset):
        """Dataset whose item at ``index`` is ``(index, index*index)``."""

        def __init__(self):
            super().__init__()
            self.set_attrs(total_len=1024)

        def __getitem__(self, index):
            return index, index*index

    toy = ToyDataset()
    # Ceil-divide the total length by the number of processes; this is
    # evaluated before iteration starts.
    per_rank = (toy.total_len-1) // mpi.world_size() + 1
    offset = per_rank * mpi.world_rank()

    for batch_no, (indices, squares) in enumerate(toy):
        assert (indices.data*indices.data == squares.data).all()
        first = offset+batch_no*toy.batch_size
        last = offset+(batch_no+1)*toy.batch_size
        expected = np.array(range(first, last))
        assert (expected==indices.data).all()
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiEntry(unittest.TestCase):
    """Entry point that re-runs this test module under ``mpirun``."""

    def test_entry(self):
        """Launch the module with two MPI processes, or run main() in place.

        If the process is already inside MPI, ``main()`` is called directly.
        Otherwise the ``mpirun`` path is derived from the configured mpicc
        path and the module is started with ``-np 2``; a non-zero exit
        status fails the test.
        """
        if jt.compile_extern.inside_mpi():
            main()
            return

        launcher = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
        command = f"{launcher} -np 2 {sys.executable} -m jittor.test.test_mpi"
        print("run cmd:", command)
        status = os.system(command)
        assert status==0, "run cmd failed: "+command
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue