PullRequest: 18 Support TensorBoard logging

Merge branch ranghou-math of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/18

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* support specifying number of gpus and mems for actors
* PR fix
* support tensorboard
* bug fix
* add doc
This commit is contained in:
穰侯 2025-03-06 15:18:12 +08:00
parent bc6567dd39
commit ad5baa74ee
5 changed files with 30 additions and 32 deletions

View File

@ -2,16 +2,9 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import collections
import copy
import dataclasses
import enum
import getpass
import itertools
import math
import os
import sys
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
import realhf.api.core.dfg as dfg
from realhf.api.core.config import (
@ -277,10 +270,16 @@ class WandBConfig:
config: Optional[Dict] = None
@dataclasses.dataclass
class TensorBoardConfig:
path: Optional[str] = None
@dataclasses.dataclass
class ExperimentConfig:
exp_ctrl: ExperimentSaveEvalControl
wandb: WandBConfig
tensorboard: TensorBoardConfig
# dataflow
model_rpcs: List[dfg.MFCDef]
model_worker: List[ModelWorker] = dataclasses.field(default_factory=list)

View File

@ -2,20 +2,14 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import contextlib
import dataclasses
import functools
import itertools
import os
import pprint
import re
from collections import defaultdict
from typing import *
import numpy as np
import transformers
from omegaconf import MISSING, OmegaConf
from transformers.utils import is_accelerate_available
import realhf.base.logging as logging
from realhf.api.core.config import (
@ -26,7 +20,7 @@ from realhf.api.core.config import (
ModelShardID,
StandaloneModelShardAbstraction,
)
from realhf.api.core.dfg import MFCDef, ModelInterfaceType, build_graph
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
from realhf.api.core.system_api import (
AutomaticEvaluator,
@ -38,6 +32,7 @@ from realhf.api.core.system_api import (
Scheduling,
TasksGroup,
WandBConfig,
TensorBoardConfig,
)
from realhf.api.quickstart.device_mesh import (
DeviceMesh,
@ -138,6 +133,9 @@ class CommonExperimentConfig(Experiment):
:param wandb: The WandB initialization config.
See https://docs.wandb.ai/ref/python/init/ for details.
:type wandb: WandbConfig
:param tensorboard: The TensorBoard initialization config.
Only the `path` field needs to be set; it specifies the directory where the TensorBoard event files are saved.
:type tensorboard: TensorBoardConfig
:param image_name: The name of the Docker image used by the controller.
This parameter is only used in SLURM mode.
:type image_name: str or None
@ -205,6 +203,7 @@ class CommonExperimentConfig(Experiment):
partition: str = "dev"
schedule_strategy: str = "empty_first"
wandb: WandBConfig = dataclasses.field(default_factory=WandBConfig)
tensorboard: TensorBoardConfig = dataclasses.field(default_factory=TensorBoardConfig)
image_name: Optional[str] = None
recover_mode: str = "disabled"
recover_retries: int = 1
@ -727,6 +726,7 @@ class CommonExperimentConfig(Experiment):
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
tensorboard=self.tensorboard,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,
auto_eval=self.auto_eval,

View File

@ -3,12 +3,10 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import copy
import dataclasses
import math
import os
import pprint
from typing import *
import numpy as np
from omegaconf import DictConfig, OmegaConf
import realhf.base.logging as logging
@ -21,9 +19,9 @@ from realhf.api.core.dfg import MFCDef, ParamReallocHook
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.core.system_api import ExperimentConfig
from realhf.api.quickstart.dataset import PromptOnlyDatasetConfig
from realhf.api.quickstart.device_mesh import DeviceMesh, MFCConfig, RPCAllocation
from realhf.api.quickstart.device_mesh import MFCConfig
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.api.quickstart.model import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.quickstart.model import ModelTrainEvalConfig
from realhf.experiments.common.common import CommonExperimentConfig
from realhf.experiments.common.utils import resolve_replica_ids, resolve_rpc_hooks
@ -549,6 +547,7 @@ class PPOMATHConfig(CommonExperimentConfig):
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
tensorboard=self.tensorboard,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,
auto_eval=self.auto_eval,

View File

@ -7,23 +7,18 @@ import collections
import copy
import dataclasses
import gc
import getpass
import itertools
import os
import pprint
import random
import re
import time
import uuid
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Set, Tuple
import colorama
import networkx as nx
import numpy as np
import torch
import torch.distributed
import wandb
from tensorboardX import SummaryWriter
import realhf.api.core.config as config_api
import realhf.api.core.data_api as data_api
@ -37,7 +32,6 @@ from realhf.api.core.config import ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base import (
constants,
datapack,
logging,
name_resolve,
names,
@ -456,6 +450,7 @@ async def model_rpc_reply_func(
buffer: AsyncIOSequenceBuffer,
model_topos: Dict[str, topology.PipeModelDataParallelTopology],
ctrl: RPCCorountineControl,
summary_writer: SummaryWriter,
):
topo = model_topos[rpc.model_name]
dp_size = topo.get_dim("data")
@ -494,6 +489,9 @@ async def model_rpc_reply_func(
if isinstance(res, Dict):
wandb.log(res, step=ctrl.step_info.global_step)
if summary_writer is not None:
for key, val in res.items():
summary_writer.add_scalar(f"{key}", val, ctrl.step_info.global_step)
logger.info(
f"Model rpc {rpc.name} finished. Run time {time.perf_counter() - tik:.4f}s."
@ -919,6 +917,9 @@ class MasterWorker(worker_base.Worker):
)
self.__data_owner = {}
self.__summary_writer = None
if self.tensorboard_config.path is not None:
self.__summary_writer = SummaryWriter(log_dir=self.tensorboard_config.path)
logger.info(f"Creating asyncio coroutines...")
@ -946,6 +947,7 @@ class MasterWorker(worker_base.Worker):
buffer=self.__seqbuffer,
model_topos=self.__model_topos,
ctrl=self.__rpc_ctrl,
summary_writer=self.__summary_writer,
)
)
coroutine_tasks += [request_task, reply_task]
@ -1312,6 +1314,8 @@ class MasterWorker(worker_base.Worker):
)
wandb.finish()
if self.__summary_writer is not None:
self.__summary_writer.close()
gc.collect()
self.__initialized = False
self.pause()

View File

@ -4,7 +4,6 @@
import dataclasses
import enum
import getpass
import os
import queue
import re
@ -15,13 +14,9 @@ from typing import Any, Dict, List, Optional, Tuple
import realhf.api.core.system_api as system_api
from realhf.base import (
cluster,
logging,
monitor,
name_resolve,
names,
network,
timeutil,
)
from realhf.base.gpu_utils import set_cuda_device
@ -580,6 +575,7 @@ class Worker:
expr_config.lazy_init()
self.wandb_config = expr_config.wandb
os.environ["WANDB_MODE"] = self.wandb_config.mode
self.tensorboard_config = expr_config.tensorboard
config = expr_config.resolve_worker_config(
self.__worker_type, self.__worker_index
)