add single_process_scope for val

This commit is contained in:
zhowuy19 2020-06-03 11:04:31 +08:00
parent 8aa5974fef
commit bc1dca9d8e
2 changed files with 82 additions and 0 deletions

View File

@ -125,6 +125,38 @@ class profile_scope(_call_no_record_scope):
profiler.stop()
self.report.extend(profiler.report())
import jittor as jt
from jittor.dataset import dataset
class single_process_scope(_call_no_record_scope):
""" single_process_scope
Code in this scope will only be executed by single
process.
example::
with jt.single_process_scope(root=0):
......
"""
def __init__(self, rank=0):
self.rank = rank
def __enter__(self):
self.mpi_backup = jt.mpi
jt.mpi = dataset.mpi = None
def __exit__(self, *exc):
jt.mpi = dataset.mpi = self.mpi_backup
def __call__(self, func):
def inner(*args, **kw):
if jt.mpi and jt.mpi.world_rank() != self.rank:
return
with self:
ret = func(*args, **kw)
return ret
return inner
def clean():
import gc
# make sure python do a full collection

View File

@ -0,0 +1,50 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Wenyang Zhou <576825820@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
import numpy as np
mpi = jt.compile_extern.mpi
from jittor.dataset.mnist import MNIST
dataloader = MNIST(train=False).set_attrs(batch_size=16)
def val1():
for i, (imgs, labels) in enumerate(dataloader):
assert(imgs.shape[0]==8)
if i == 5:
break
@jt.single_process_scope(rank=0)
def val2():
for i, (imgs, labels) in enumerate(dataloader):
assert(imgs.shape[0]==16)
if i == 5:
break
@unittest.skipIf(mpi is None, "no inside mpirun")
class TestSingleProcessScope(unittest.TestCase):
def test_single_process_scope(self):
val1()
val2()
def run_single_process_scope_test(num_procs, name):
if not jt.compile_extern.inside_mpi():
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v"
print("run cmd:", cmd)
assert os.system(cmd)==0, "run cmd failed: "+cmd
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestSingleProcessScopeEntry(unittest.TestCase):
def test_entry(self):
run_single_process_scope_test(2, "test_single_process_scope")
if __name__ == "__main__":
unittest.main()