mirror of https://github.com/Jittor/Jittor
add single_process_scope for val
This commit is contained in:
parent
8aa5974fef
commit
bc1dca9d8e
|
@ -125,6 +125,38 @@ class profile_scope(_call_no_record_scope):
|
|||
profiler.stop()
|
||||
self.report.extend(profiler.report())
|
||||
|
||||
import jittor as jt
|
||||
from jittor.dataset import dataset
|
||||
class single_process_scope(_call_no_record_scope):
|
||||
""" single_process_scope
|
||||
|
||||
Code in this scope will only be executed by single
|
||||
process.
|
||||
|
||||
example::
|
||||
|
||||
with jt.single_process_scope(root=0):
|
||||
......
|
||||
"""
|
||||
def __init__(self, rank=0):
|
||||
self.rank = rank
|
||||
|
||||
def __enter__(self):
|
||||
self.mpi_backup = jt.mpi
|
||||
jt.mpi = dataset.mpi = None
|
||||
|
||||
def __exit__(self, *exc):
|
||||
jt.mpi = dataset.mpi = self.mpi_backup
|
||||
|
||||
def __call__(self, func):
|
||||
def inner(*args, **kw):
|
||||
if jt.mpi and jt.mpi.world_rank() != self.rank:
|
||||
return
|
||||
with self:
|
||||
ret = func(*args, **kw)
|
||||
return ret
|
||||
return inner
|
||||
|
||||
def clean():
|
||||
import gc
|
||||
# make sure python do a full collection
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Wenyang Zhou <576825820@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
mpi = jt.compile_extern.mpi
|
||||
|
||||
from jittor.dataset.mnist import MNIST
|
||||
dataloader = MNIST(train=False).set_attrs(batch_size=16)
|
||||
|
||||
def val1():
|
||||
for i, (imgs, labels) in enumerate(dataloader):
|
||||
assert(imgs.shape[0]==8)
|
||||
if i == 5:
|
||||
break
|
||||
|
||||
@jt.single_process_scope(rank=0)
|
||||
def val2():
|
||||
for i, (imgs, labels) in enumerate(dataloader):
|
||||
assert(imgs.shape[0]==16)
|
||||
if i == 5:
|
||||
break
|
||||
|
||||
@unittest.skipIf(mpi is None, "no inside mpirun")
|
||||
class TestSingleProcessScope(unittest.TestCase):
|
||||
def test_single_process_scope(self):
|
||||
val1()
|
||||
val2()
|
||||
|
||||
def run_single_process_scope_test(num_procs, name):
|
||||
if not jt.compile_extern.inside_mpi():
|
||||
mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
|
||||
cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v"
|
||||
print("run cmd:", cmd)
|
||||
assert os.system(cmd)==0, "run cmd failed: "+cmd
|
||||
|
||||
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
class TestSingleProcessScopeEntry(unittest.TestCase):
|
||||
def test_entry(self):
|
||||
run_single_process_scope_test(2, "test_single_process_scope")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue