mirror of https://github.com/Jittor/Jittor
mpi dataset
parent 7d2e1b73fa
commit 69e7d367cd
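In summary, the diff adds a raw-byte broadcast op to the C++ MPI wrapper (mpi_warper.h / mpi_warper.cc), exposes it to Python as mpi.broadcast, makes Dataset scatter its shuffled index list across MPI ranks, and reorganizes test_mpi.py into unittest cases covering the new behavior.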
mpi_warper.h
@@ -37,4 +37,9 @@ int _mpi_world_rank();
 // @pyjt(local_rank)
 int _mpi_local_rank();
 
+struct ArrayArgs;
+
+// @pyjt(broadcast)
+void _mpi_broadcast(ArrayArgs&& args, int i);
+
 } // jittor
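The // @pyjt(broadcast) annotation is what binds _mpi_broadcast into the Python module as mpi.broadcast. A minimal usage sketch, mirroring the new test_mpi_broadcast test below and assuming Jittor was built with MPI and the script runs under mpirun:

import jittor as jt
import numpy as np

mpi = jt.compile_extern.mpi  # None when not launched via mpirun
if mpi:
    # Each rank starts with different data; after the call every rank
    # holds rank 0's array (the second argument selects the root rank).
    a = np.zeros(10) + mpi.world_rank()
    mpi.broadcast(a, 0)
    assert (a == 0).all()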
mpi_warper.cc
@@ -12,6 +12,7 @@
 
 #include "mpi_warper.h"
 #include "common.h"
+#include "ops/array_op.h"
 
 char jt_mpi_err_buffer[MPI_MAX_ERROR_STRING];
 
@@ -44,7 +45,12 @@ int _mpi_local_rank() {
     return mpi_local_rank;
 }
 
+void _mpi_broadcast(ArrayArgs&& args, int i) {
+    int64 size = args.dtype.dsize();
+    for (auto i : args.shape)
+        size *= i;
+    MPI_CHECK(MPI_Bcast((void *)args.ptr, size, MPI_BYTE, 0, MPI_COMM_WORLD));
+}
+
 static uint64_t getHostHash(const char* string) {
     // Based on DJB2, result = result * 33 + char
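_mpi_broadcast is type-agnostic: it computes the buffer length as the element size (args.dtype.dsize()) times every shape dimension, then broadcasts that many raw MPI_BYTEs from rank 0. The same arithmetic, checked in NumPy (the array and values here are illustrative only):

import numpy as np

a = np.zeros((3, 4), dtype=np.int32)
size = a.dtype.itemsize          # counterpart of args.dtype.dsize(): 4 bytes
for dim in a.shape:              # same accumulation as the C++ loop above
    size *= dim
assert size == a.nbytes == 48    # the byte count handed to MPI_Bcast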
dataset.py
@@ -23,6 +23,7 @@ import jittor as jt
 
 dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
 mp_log_v = os.environ.get("mp_log_v", 0)
+mpi = jt.compile_extern.mpi
 
 class Worker:
     def __init__(self, target, args, buffer_size):
@@ -187,6 +188,17 @@ class Dataset(object):
         else:
             index_list = get_random_list(self.total_len)
 
+        # scatter index_list for all mpi process
+        # scatter rule:
+        #   [000000000111111111222222222]
+        if mpi:
+            index_list = np.int32(index_list)
+            mpi.broadcast(index_list, 0)
+            new_len = (self.total_len - 1) // mpi.world_size() + 1
+            offset = mpi.world_rank() * new_len
+            index_list = index_list[offset:offset+new_len]
+            self.total_len = len(index_list)
+
         self.batch_len = len(self)
         if "batch_len" in os.environ:
             self.batch_len = int(os.environ["batch_len"])
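The scatter rule works in two steps: rank 0's shuffled index_list is broadcast so every rank sees the same permutation, then each rank keeps one contiguous chunk of ceil(total_len / world_size) indices, as the [000...111...222...] comment depicts. A standalone sketch of the slicing (total_len and world_size are made-up values for illustration):

import numpy as np

total_len, world_size = 8, 3
index_list = np.arange(total_len, dtype=np.int32)  # identical on all ranks after broadcast
new_len = (total_len - 1) // world_size + 1        # ceil(8 / 3) = 3
for rank in range(world_size):
    offset = rank * new_len
    print(rank, index_list[offset:offset + new_len])
# rank 0 -> [0 1 2], rank 1 -> [3 4 5], rank 2 -> [6 7]

The last rank can receive a shorter chunk, which is why the code resets self.total_len to len(index_list) afterwards.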
test_mpi.py
@@ -10,24 +10,49 @@ import unittest
 import os, sys
 import jittor as jt
 import numpy as np
+mpi = jt.compile_extern.mpi
 
-def main():
-    print("test mpi_test")
-    assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
-    if jt.compile_extern.nccl_ops:
-        print("test test_with_mpi")
-        with jt.flag_scope(use_cuda=1):
-            assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
-
-@unittest.skipIf(jt.compile_extern.has_mpi is None, "no mpi found")
+@unittest.skipIf(mpi is None, "not inside mpirun")
 class TestMpi(unittest.TestCase):
-    def test(self):
+    def test_mpi_test_op(self):
+        assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
+
+    @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl found")
+    @jt.flag_scope(use_cuda=1)
+    def test_nccl_with_mpi(self):
+        assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
+
+    def test_mpi_broadcast(self):
+        a = np.zeros(100) + mpi.world_rank()
+        mpi.broadcast(a, 0)
+        assert (a == 0).all()
+
+    def test_mpi_dataset(self):
+        from jittor.dataset.dataset import Dataset
+        class ToyDataset(Dataset):
+            def __init__(self):
+                super().__init__()
+                self.set_attrs(total_len=1024)
+
+            def __getitem__(self, index):
+                return index, index*index
+
+        toy = ToyDataset()
+        offset = ((toy.total_len-1) // mpi.world_size() + 1) * mpi.world_rank()
+
+        for i,(a,b) in enumerate(toy):
+            assert (a.data*a.data == b.data).all()
+            c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
+            assert (c==a.data).all()
+
+@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
+class TestMpiEntry(unittest.TestCase):
+    def test_entry(self):
         if not jt.compile_extern.inside_mpi():
             mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
             cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi"
+            print("run cmd:", cmd)
             assert os.system(cmd)==0, "run cmd failed: "+cmd
-        else:
-            main()
 
 if __name__ == "__main__":
     unittest.main()
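TestMpiEntry implements a self-relaunch: run outside MPI, the test shells out to mpirun -np 2 with the same interpreter and module, so a plain `python -m jittor.test.test_mpi` exercises the MPI cases too. A generic sketch of that pattern; it is an assumption here that OpenMPI is the launcher (OpenMPI sets OMPI_COMM_WORLD_SIZE for processes it spawns, whereas the real test uses jt.compile_extern.inside_mpi() and derives mpirun from mpicc_path):

import os, sys

def relaunch_under_mpi(np_procs=2):
    # If we were not started by mpirun, re-execute this script under it.
    if os.environ.get("OMPI_COMM_WORLD_SIZE") is None:
        cmd = f"mpirun -np {np_procs} {sys.executable} {sys.argv[0]}"
        print("run cmd:", cmd)
        assert os.system(cmd) == 0, "run cmd failed: " + cmd
        return True   # parent process: the MPI children did the work
    return False      # already inside mpirun; continue normally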