diff --git a/tests/experiments/test_math_ppo.py b/tests/experiments/test_math_ppo.py index 65beb0a..95f48b9 100644 --- a/tests/experiments/test_math_ppo.py +++ b/tests/experiments/test_math_ppo.py @@ -53,10 +53,10 @@ def math_code_dataset(request, save_path): @pytest.mark.parametrize( "dp,pp,mp", [ - (2, 2, 1), - # (2, 1, 2), - # (1, 2, 1), - # (1, 1, 2), + (1, 1, 1), + (2, 1, 2), + (1, 2, 1), + (1, 1, 2), ], ) def test_ppo_symm( @@ -120,7 +120,6 @@ def test_ppo_symm( run_test_exp(exp_cfg) -@pytest.mark.skip("") # The global resharding strategy, where all MFCs # occupy the same device mesh but with different # parallelization strategies. @@ -242,7 +241,6 @@ def test_ppo_global_reshard( run_test_exp(exp_cfg) -@pytest.mark.skip("") # Actor/critic train and ref_inf/rew_inf are on disjoint # device meshes and executed concurrently. @pytest.mark.parametrize("actor_gen", [(2, 2, 1)]) @@ -359,7 +357,6 @@ def test_ppo_param_realloc_sub_device_mesh( run_test_exp(exp_cfg) -@pytest.mark.skip("") @pytest.mark.parametrize("freq_step", [3, 4, 7]) @pytest.mark.parametrize("freq_epoch", [1, 2, 3]) @pytest.mark.parametrize("bs", [30, 80, 100])