mirror of https://github.com/inclusionAI/AReaL
PullRequest: 18 Support TensorBoard logging
Merge branch ranghou-math of git@code.alipay.com:inclusionAI/AReaL.git into main https://code.alipay.com/inclusionAI/AReaL/pull_requests/18 Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com> Signed-off-by: 博惟 <bowei.fw@antgroup.com> * support specifying number of gpus and mems for actors * PR fix * support tensorboard * bug fix * add doc
This commit is contained in:
parent
bc6567dd39
commit
ad5baa74ee
|
@ -2,16 +2,9 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import dataclasses
|
||||
import enum
|
||||
import getpass
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import realhf.api.core.dfg as dfg
|
||||
from realhf.api.core.config import (
|
||||
|
@ -277,10 +270,16 @@ class WandBConfig:
|
|||
config: Optional[Dict] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TensorBoardConfig:
|
||||
path: Optional[str] = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExperimentConfig:
|
||||
exp_ctrl: ExperimentSaveEvalControl
|
||||
wandb: WandBConfig
|
||||
tensorboard: TensorBoardConfig
|
||||
# dataflow
|
||||
model_rpcs: List[dfg.MFCDef]
|
||||
model_worker: List[ModelWorker] = dataclasses.field(default_factory=list)
|
||||
|
|
|
@ -2,20 +2,14 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools
|
||||
import os
|
||||
import pprint
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
import transformers
|
||||
from omegaconf import MISSING, OmegaConf
|
||||
from transformers.utils import is_accelerate_available
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.core.config import (
|
||||
|
@ -26,7 +20,7 @@ from realhf.api.core.config import (
|
|||
ModelShardID,
|
||||
StandaloneModelShardAbstraction,
|
||||
)
|
||||
from realhf.api.core.dfg import MFCDef, ModelInterfaceType, build_graph
|
||||
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
|
||||
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
|
||||
from realhf.api.core.system_api import (
|
||||
AutomaticEvaluator,
|
||||
|
@ -38,6 +32,7 @@ from realhf.api.core.system_api import (
|
|||
Scheduling,
|
||||
TasksGroup,
|
||||
WandBConfig,
|
||||
TensorBoardConfig,
|
||||
)
|
||||
from realhf.api.quickstart.device_mesh import (
|
||||
DeviceMesh,
|
||||
|
@ -138,6 +133,9 @@ class CommonExperimentConfig(Experiment):
|
|||
:param wandb: The WandB initialization config.
|
||||
See https://docs.wandb.ai/ref/python/init/ for details.
|
||||
:type wandb: WandbConfig
|
||||
:param tensorboard: The tensorboard initialization config.
|
||||
Only the field of `path` is needed to specify the directory of saving the tensorboard events.
|
||||
:type tensorboard: TensorBoardConfig
|
||||
:param image_name: The name of the Docker image used by the controller.
|
||||
This parameter is only used in SLURM mode.
|
||||
:type image_name: str or None
|
||||
|
@ -205,6 +203,7 @@ class CommonExperimentConfig(Experiment):
|
|||
partition: str = "dev"
|
||||
schedule_strategy: str = "empty_first"
|
||||
wandb: WandBConfig = dataclasses.field(default_factory=WandBConfig)
|
||||
tensorboard: TensorBoardConfig = dataclasses.field(default_factory=TensorBoardConfig)
|
||||
image_name: Optional[str] = None
|
||||
recover_mode: str = "disabled"
|
||||
recover_retries: int = 1
|
||||
|
@ -727,6 +726,7 @@ class CommonExperimentConfig(Experiment):
|
|||
return ExperimentConfig(
|
||||
exp_ctrl=self.exp_ctrl,
|
||||
wandb=self.wandb,
|
||||
tensorboard=self.tensorboard,
|
||||
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
|
||||
model_worker=model_worker,
|
||||
auto_eval=self.auto_eval,
|
||||
|
|
|
@ -3,12 +3,10 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
import copy
|
||||
import dataclasses
|
||||
import math
|
||||
import os
|
||||
import pprint
|
||||
from typing import *
|
||||
|
||||
import numpy as np
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
import realhf.base.logging as logging
|
||||
|
@ -21,9 +19,9 @@ from realhf.api.core.dfg import MFCDef, ParamReallocHook
|
|||
from realhf.api.core.model_api import GenerationHyperparameters
|
||||
from realhf.api.core.system_api import ExperimentConfig
|
||||
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
|
||||
from realhf.api.quickstart.device_mesh import DeviceMesh, MFCConfig, RPCAllocation
|
||||
from realhf.api.quickstart.device_mesh import MFCConfig
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.api.quickstart.model import ModelTrainEvalConfig
|
||||
from realhf.experiments.common.common import CommonExperimentConfig
|
||||
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
|
||||
|
||||
|
@ -549,6 +547,7 @@ class PPOMATHConfig(CommonExperimentConfig):
|
|||
return ExperimentConfig(
|
||||
exp_ctrl=self.exp_ctrl,
|
||||
wandb=self.wandb,
|
||||
tensorboard=self.tensorboard,
|
||||
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
|
||||
model_worker=model_worker,
|
||||
auto_eval=self.auto_eval,
|
||||
|
|
|
@ -7,23 +7,18 @@ import collections
|
|||
import copy
|
||||
import dataclasses
|
||||
import gc
|
||||
import getpass
|
||||
import itertools
|
||||
import os
|
||||
import pprint
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
import colorama
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed
|
||||
import wandb
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
import realhf.api.core.config as config_api
|
||||
import realhf.api.core.data_api as data_api
|
||||
|
@ -37,7 +32,6 @@ from realhf.api.core.config import ModelName
|
|||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base import (
|
||||
constants,
|
||||
datapack,
|
||||
logging,
|
||||
name_resolve,
|
||||
names,
|
||||
|
@ -456,6 +450,7 @@ async def model_rpc_reply_func(
|
|||
buffer: AsyncIOSequenceBuffer,
|
||||
model_topos: Dict[str, topology.PipeModelDataParallelTopology],
|
||||
ctrl: RPCCorountineControl,
|
||||
summary_writer: SummaryWriter,
|
||||
):
|
||||
topo = model_topos[rpc.model_name]
|
||||
dp_size = topo.get_dim("data")
|
||||
|
@ -494,6 +489,9 @@ async def model_rpc_reply_func(
|
|||
|
||||
if isinstance(res, Dict):
|
||||
wandb.log(res, step=ctrl.step_info.global_step)
|
||||
if summary_writer is not None:
|
||||
for key, val in res.items():
|
||||
summary_writer.add_scalar(f"{key}", val, ctrl.step_info.global_step)
|
||||
|
||||
logger.info(
|
||||
f"Model rpc {rpc.name} finished. Run time {time.perf_counter() - tik:.4f}s."
|
||||
|
@ -919,6 +917,9 @@ class MasterWorker(worker_base.Worker):
|
|||
)
|
||||
|
||||
self.__data_owner = {}
|
||||
self.__summary_writer = None
|
||||
if self.tensorboard_config.path is not None:
|
||||
self.__summary_writer = SummaryWriter(log_dir=self.tensorboard_config.path)
|
||||
|
||||
logger.info(f"Creating asyncio coroutines...")
|
||||
|
||||
|
@ -946,6 +947,7 @@ class MasterWorker(worker_base.Worker):
|
|||
buffer=self.__seqbuffer,
|
||||
model_topos=self.__model_topos,
|
||||
ctrl=self.__rpc_ctrl,
|
||||
summary_writer=self.__summary_writer,
|
||||
)
|
||||
)
|
||||
coroutine_tasks += [request_task, reply_task]
|
||||
|
@ -1312,6 +1314,8 @@ class MasterWorker(worker_base.Worker):
|
|||
)
|
||||
|
||||
wandb.finish()
|
||||
if self.__summary_writer is not None:
|
||||
self.__summary_writer.close()
|
||||
gc.collect()
|
||||
self.__initialized = False
|
||||
self.pause()
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
|
||||
import dataclasses
|
||||
import enum
|
||||
import getpass
|
||||
import os
|
||||
import queue
|
||||
import re
|
||||
|
@ -15,13 +14,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import realhf.api.core.system_api as system_api
|
||||
from realhf.base import (
|
||||
cluster,
|
||||
logging,
|
||||
monitor,
|
||||
name_resolve,
|
||||
names,
|
||||
network,
|
||||
timeutil,
|
||||
)
|
||||
from realhf.base.gpu_utils import set_cuda_device
|
||||
|
||||
|
@ -580,6 +575,7 @@ class Worker:
|
|||
expr_config.lazy_init()
|
||||
self.wandb_config = expr_config.wandb
|
||||
os.environ["WANDB_MODE"] = self.wandb_config.mode
|
||||
self.tensorboard_config = expr_config.tensorboard
|
||||
config = expr_config.resolve_worker_config(
|
||||
self.__worker_type, self.__worker_index
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue