diff --git a/.github/workflows/installation-validation.yml b/.github/workflows/installation-validation.yml new file mode 100644 index 0000000..9a82dd8 --- /dev/null +++ b/.github/workflows/installation-validation.yml @@ -0,0 +1,128 @@ +name: Installation Validation + +on: + push: + branches: [ none ] + paths: + - 'examples/env/scripts/setup-pip-deps.sh' + - 'docs/tutorial/installation.md' + - 'examples/env/validate_installation.py' + - 'setup.py' + - 'requirements*.txt' + - '.github/workflows/installation-validation.yml' + pull_request: + branches: [ none ] + paths: + - 'examples/env/scripts/setup-pip-deps.sh' + - 'docs/tutorial/installation.md' + - 'examples/env/validate_installation.py' + - 'setup.py' + - 'requirements*.txt' + - '.github/workflows/installation-validation.yml' + workflow_dispatch: + +jobs: + validate-installation: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + lfs: true + + - name: Set up SSH key + run: | + mkdir -p ~/.ssh + echo "${{ secrets.REMOTE_SSH_KEY }}" > ~/.ssh/id_rsa + chmod 600 ~/.ssh/id_rsa + ssh-keyscan -p 8107 101.6.96.205 >> ~/.ssh/known_hosts + + - name: Synchronize repository to remote machine + run: | + # Use rsync to synchronize repository to remote machine + rsync -avz --delete \ + --exclude='.git' \ + --exclude='__pycache__' \ + --exclude='*.pyc' \ + --exclude='*.pyo' \ + --exclude='*.egg-info' \ + --exclude='build/' \ + --exclude='dist/' \ + --exclude='.pytest_cache' \ + --exclude='.coverage' \ + --exclude='*.so' \ + --exclude='*.dylib' \ + --exclude='node_modules/' \ + --exclude='.env' \ + --exclude='.venv' \ + -e 'ssh -p 8107' . fuwei@101.6.96.205:/tmp/areal-validation/ + + - name: Run installation validation on remote machine + run: | + ssh -p 8107 fuwei@101.6.96.205 << 'EOF' + set -e + + # Navigate to the synchronized repository + cd /tmp/areal-validation + + # Create persistent pip cache directory + mkdir -p /tmp/pip-cache + + # Generate a unique container name + CONTAINER_NAME="areal-validation-$(date +%s)" + + # Stop and remove any existing container with the same name + docker stop $CONTAINER_NAME 2>/dev/null || true + docker rm $CONTAINER_NAME 2>/dev/null || true + + echo "=== Starting Docker container ===" + # Launch Docker container with NVIDIA PyTorch image + docker run -d \ + --name $CONTAINER_NAME \ + --gpus all \ + --shm-size=8g \ + -v $(pwd):/workspace \ + -v /tmp/pip-cache:/root/.cache/pip \ + -w /workspace \ + nvcr.io/nvidia/pytorch:25.01-py3 \ + sleep infinity + + echo "=== Verifying CUDA environment in container ===" + docker exec $CONTAINER_NAME nvidia-smi + docker exec $CONTAINER_NAME nvcc --version + + echo "=== Verifying workspace contents ===" + docker exec $CONTAINER_NAME pwd + docker exec $CONTAINER_NAME ls -la /workspace + docker exec $CONTAINER_NAME ls -la /workspace/examples/env/ || echo "examples/env directory not found" + + echo "=== Checking pip cache before installation ===" + du -sh /tmp/pip-cache 2>/dev/null || echo "Cache directory empty" + + echo "=== Installing dependencies ===" + docker exec $CONTAINER_NAME bash -c " + python -m pip install --upgrade pip + pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple + pip config unset global.extra-index-url + # Run the installation script + bash examples/env/scripts/setup-pip-deps.sh + python examples/env/validate_installation.py + " + + echo "=== Checking pip cache after installation ===" + du -sh /tmp/pip-cache 2>/dev/null || echo "Cache directory still empty" + + echo "=== 
Installation validation completed successfully ===" + + # Cleanup + docker stop $CONTAINER_NAME + docker rm $CONTAINER_NAME + cd ~ + rm -rf /tmp/areal-validation + EOF + + - name: Cleanup SSH key + if: always() + run: | + rm -f ~/.ssh/id_rsa \ No newline at end of file diff --git a/assets/wechat_qrcode.png b/assets/wechat_qrcode.png index 2a8c63b..13f69c7 100644 Binary files a/assets/wechat_qrcode.png and b/assets/wechat_qrcode.png differ diff --git a/docs/tutorial/quickstart.md b/docs/tutorial/quickstart.md index b871cd7..899b377 100644 --- a/docs/tutorial/quickstart.md +++ b/docs/tutorial/quickstart.md @@ -97,12 +97,15 @@ python3 training/main_sync_ppo.py --help ## Monitoring the Training Process -We recommend using Weights & Biases (wandb) for monitoring. Run `wandb login` or set the `WANDB_API_KEY` environment variable. Set `wandb.mode=online` in your configuration to upload training statistics. +We recommend using [Weights & Biases (wandb)](https://github.com/wandb/wandb) or [SwanLab](https://github.com/SwanHubX/SwanLab) for monitoring. Run `wandb login` or `swanlab login`, or set the corresponding API key environment variable (`WANDB_API_KEY` or `SWANLAB_API_KEY`). Set `wandb.mode="online"` or `swanlab.mode="cloud"` in your configuration to upload training statistics. If you cannot connect to the server, you can also use `wandb.mode="offline"` or `swanlab.mode="local"` to save data locally without uploading. + You can also use TensorBoard by setting the `tensorboard.path` parameter. The main log will be saved to `${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log` and contains the statistics uploaded to wandb. +If SwanLab is enabled, logs will be saved to the directory specified by `swanlab.logdir`. + ### Key Training Statistics - **`Epoch 1/5`**: Indicates the total epochs required and the current epoch being trained. diff --git a/evaluation/requirements.txt b/evaluation/requirements.txt index 3b211d4..7f8c6cb 100644 --- a/evaluation/requirements.txt +++ b/evaluation/requirements.txt @@ -15,3 +15,4 @@ prettytable timeout-decorator timeout_decorator wandb +swanlab[dashboard] \ No newline at end of file diff --git a/examples/env/scripts/setup-pip-deps.sh b/examples/env/scripts/setup-pip-deps.sh index 65ee55b..8fa203b 100644 --- a/examples/env/scripts/setup-pip-deps.sh +++ b/examples/env/scripts/setup-pip-deps.sh @@ -1,11 +1,12 @@ #!/bin/bash # basic dependencies pip install -U pip -pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y +pip uninstall torch deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y +pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 pip install "sglang[all]==0.4.6.post4" pip install megatron-core==0.11.0 nvidia-ml-py pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose -pip install flash-attn --no-build-isolation +pip install "flash-attn<=2.7.3" --no-build-isolation # Package used for calculating math reward pip install -e evaluation/latex2sympy diff --git a/examples/env/validate_installation.py b/examples/env/validate_installation.py new file mode 100644 index 0000000..61ef6f7 --- /dev/null +++ b/examples/env/validate_installation.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +""" +Installation Validation Script for AReaL + +This script validates that all critical dependencies are properly installed +and can be imported successfully.
It's designed to be run in CI to validate +the installation procedure described in docs/tutorial/installation.md. +""" + +import importlib +import sys +import traceback +import warnings +from importlib.metadata import version as get_version +from typing import Any, Dict, List, Optional + +from packaging.version import Version + + +class InstallationValidator: + def __init__(self): + self.results = {} + self.critical_failures = [] + self.warnings = [] + + def test_import(self, module_name: str, required: bool = True, + test_func: Optional[callable] = None) -> bool: + """Test importing a module and optionally run additional tests.""" + try: + module = importlib.import_module(module_name) + + # Run additional test if provided + if test_func: + test_func(module) + + self.results[module_name] = {"status": "SUCCESS", "error": None} + print(f"✓ {module_name}") + return True + + except ImportError as e: + self.results[module_name] = {"status": "FAILED", "error": str(e)} + if required: + self.critical_failures.append(f"{module_name}: {str(e)}") + print(f"✗ {module_name} (CRITICAL): {str(e)}") + else: + self.warnings.append(f"{module_name}: {str(e)}") + print(f"⚠ {module_name} (OPTIONAL): {str(e)}") + return False + + except Exception as e: + self.results[module_name] = {"status": "ERROR", "error": str(e)} + if required: + self.critical_failures.append(f"{module_name}: {str(e)}") + print(f"✗ {module_name} (CRITICAL ERROR): {str(e)}") + else: + self.warnings.append(f"{module_name}: {str(e)}") + print(f"⚠ {module_name} (OPTIONAL ERROR): {str(e)}") + return False + + def test_torch_cuda(self, torch_module): + """Test PyTorch CUDA availability.""" + if not torch_module.cuda.is_available(): + raise RuntimeError("CUDA is not available in PyTorch") + print(f" - CUDA devices: {torch_module.cuda.device_count()}") + print(f" - CUDA version: {torch_module.version.cuda}") + + def test_flash_attn_functionality(self, flash_attn_module): + """Test flash attention functionality.""" + # Try to import key functions + from flash_attn import flash_attn_func, flash_attn_varlen_func + print(" - Flash attention functions imported successfully") + + def test_vllm_functionality(self, vllm_module): + """Test vLLM basic functionality.""" + from vllm import LLM, SamplingParams + print(" - vLLM core classes imported successfully") + + def test_sglang_functionality(self, sglang_module): + """Test SGLang basic functionality.""" + # Basic import test is sufficient for CI + import sgl_kernel + from sglang import launch_server + assert Version(get_version("sglang")) == Version("0.4.6.post4") + print(" - SGLang imported successfully") + + def test_transformers(self, transformers_module): + assert Version(get_version("transformers")) == Version("4.51.1") + print(" - transformers imported successfully") + + def validate_critical_dependencies(self): + """Validate critical dependencies that must be present.""" + print("\n=== Testing Critical Dependencies ===") + + # Core ML frameworks + self.test_import("torch", required=True, test_func=self.test_torch_cuda) + self.test_import("transformers", required=True, test_func=self.test_transformers) + + # Flash attention - critical for performance + self.test_import("flash_attn", required=True, test_func=self.test_flash_attn_functionality) + self.test_import("cugae", required=True) + # Inference engines + self.test_import("sglang", required=True, test_func=self.test_sglang_functionality) + + # Distributed computing + self.test_import("ray", required=True) + + # Scientific computing + 
self.test_import("numpy", required=True) + self.test_import("scipy", required=True) + + # Configuration management + self.test_import("hydra", required=True) + self.test_import("omegaconf", required=True) + + # Data processing + self.test_import("datasets", required=True) + self.test_import("pandas", required=True) + self.test_import("einops", required=True) + + # Monitoring and logging + self.test_import("wandb", required=True) + self.test_import("pynvml", required=True) + + # Networking + self.test_import("aiohttp", required=True) + self.test_import("fastapi", required=True) + self.test_import("uvicorn", required=True) + + # Math libraries (for evaluation) + self.test_import("sympy", required=True) + self.test_import("latex2sympy2", required=True) + + def validate_optional_dependencies(self): + """Validate optional dependencies.""" + print("\n=== Testing Optional Dependencies ===") + + # CUDA extensions (may not be available in all environments) + self.test_import("vllm", required=False, test_func=self.test_vllm_functionality) + self.test_import("grouped_gemm", required=False) + self.test_import("flashattn_hopper", required=False) + + # Optional utilities + self.test_import("tensorboardx", required=False) + self.test_import("swanlab", required=False) + self.test_import("matplotlib", required=False) + self.test_import("seaborn", required=False) + self.test_import("numba", required=False) + self.test_import("nltk", required=False) + + def validate_cuda_extensions(self): + """Validate CUDA-specific functionality.""" + print("\n=== Testing CUDA Extensions ===") + + try: + import torch + if torch.cuda.is_available(): + # Test basic CUDA tensor operations + device = torch.device("cuda:0") + x = torch.randn(10, device=device) + y = torch.randn(10, device=device) + z = x + y + print("✓ Basic CUDA operations working") + + # Test flash attention if available + try: + from flash_attn import flash_attn_func + + # Create small test tensors + batch_size, seq_len, num_heads, head_dim = 1, 32, 4, 64 + q = torch.randn(batch_size, seq_len, num_heads, head_dim, + device=device, dtype=torch.float16) + k = torch.randn(batch_size, seq_len, num_heads, head_dim, + device=device, dtype=torch.float16) + v = torch.randn(batch_size, seq_len, num_heads, head_dim, + device=device, dtype=torch.float16) + + # Test flash attention call + out = flash_attn_func(q, k, v) + print("✓ Flash attention CUDA operations working") + + except Exception as e: + print(f"⚠ Flash attention CUDA test failed: {e}") + + else: + print("⚠ CUDA not available - skipping CUDA extension tests") + + except Exception as e: + print(f"✗ CUDA extension validation failed: {e}") + + def run_validation(self): + """Run complete validation suite.""" + print("AReaL Installation Validation") + print("=" * 50) + + self.validate_critical_dependencies() + self.validate_optional_dependencies() + self.validate_cuda_extensions() + + # Print summary + print("\n" + "=" * 50) + print("VALIDATION SUMMARY") + print("=" * 50) + + total_tests = len(self.results) + successful_tests = sum(1 for r in self.results.values() if r["status"] == "SUCCESS") + failed_tests = total_tests - successful_tests + + print(f"Total tests: {total_tests}") + print(f"Successful: {successful_tests}") + print(f"Failed: {failed_tests}") + + if self.critical_failures: + print(f"\n🚨 CRITICAL FAILURES ({len(self.critical_failures)}):") + for failure in self.critical_failures: + print(f" - {failure}") + + if self.warnings: + print(f"\n⚠️ WARNINGS ({len(self.warnings)}):") + for warning in 
self.warnings: + print(f" - {warning}") + + # Determine overall result + if self.critical_failures: + print(f"\n❌ INSTALLATION VALIDATION FAILED") + print("Please check the critical failures above and ensure all required") + print("dependencies are properly installed according to the installation guide.") + return False + else: + print(f"\n✅ INSTALLATION VALIDATION PASSED") + if self.warnings: + print("Note: Some optional dependencies failed but this won't affect") + print("core functionality.") + return True + +def main(): + """Main entry point.""" + validator = InstallationValidator() + success = validator.run_validation() + sys.exit(0 if success else 1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9e4f28c..de1cb0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ dependencies = [ "colorlog", "psutil", "pynvml", + "swanlab[dashboard]", # Performance and compression "ninja", diff --git a/realhf/api/cli_args.py b/realhf/api/cli_args.py index 92f0a1b..d599d1f 100644 --- a/realhf/api/cli_args.py +++ b/realhf/api/cli_args.py @@ -855,6 +855,16 @@ class WandBConfig: config: Optional[Dict] = None +@dataclass +class SwanlabConfig: + project: Optional[str] = None + name: Optional[str] = None + config: Optional[Dict] = None + logdir: Optional[str] = None + mode: Optional[str] = "local" + api_key: Optional[str] = os.getenv("SWANLAB_API_KEY", None) + + @dataclass class TensorBoardConfig: path: Optional[str] = None @@ -986,6 +996,10 @@ class BaseExperimentConfig: default_factory=WandBConfig, metadata={"help": "Weights & Biases configuration."}, ) + swanlab: SwanlabConfig = field( + default_factory=SwanlabConfig, + metadata={"help": "SwanLab configuration."}, + ) tensorboard: TensorBoardConfig = field( default_factory=TensorBoardConfig, metadata={"help": "TensorBoard configuration. Only 'path' field required."}, @@ -1061,7 +1075,7 @@ class BaseExperimentConfig: default=False, metadata={ "help": "Enable automatic evaluation during training. " - "Results logged to disk and WandB (if active)." + "Results logged to disk and WandB or SwanLab (if active)."
}, ) auto_eval_config: AutomaticEvaluator = field( diff --git a/realhf/api/core/system_api.py b/realhf/api/core/system_api.py index d409740..ea30213 100644 --- a/realhf/api/core/system_api.py +++ b/realhf/api/core/system_api.py @@ -11,6 +11,7 @@ import realhf.api.core.dfg as dfg from realhf.api.cli_args import ( AutomaticEvaluator, ExperimentSaveEvalControl, + SwanlabConfig, TensorBoardConfig, WandBConfig, ) @@ -189,6 +190,7 @@ class ExperimentScheduling: class ExperimentConfig: exp_ctrl: ExperimentSaveEvalControl wandb: WandBConfig + swanlab: SwanlabConfig tensorboard: TensorBoardConfig # dataflow model_rpcs: List[dfg.MFCDef] diff --git a/realhf/apps/main.py b/realhf/apps/main.py index e87f8d6..e1ca087 100644 --- a/realhf/apps/main.py +++ b/realhf/apps/main.py @@ -94,7 +94,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0): raise RuntimeError("Experiment initial setup failed.") from e evaluator = ( - AutomaticEvaluator(exp_cfg, exp_cfg.evaluator, exp_cfg.wandb) + AutomaticEvaluator(exp_cfg, exp_cfg.evaluator, exp_cfg.wandb, exp_cfg.swanlab) if exp_cfg.auto_eval else None ) diff --git a/realhf/base/logging.py b/realhf/base/logging.py index 7b46ea9..4f21e47 100644 --- a/realhf/base/logging.py +++ b/realhf/base/logging.py @@ -141,19 +141,29 @@ def getLogger( return logging.getLogger(name) -_LATEST_WANDB_STEP = 0 +_LATEST_LOG_STEP = 0 -def log_wandb_tensorboard(data, step=None, summary_writer=None): +def log_swanlab_wandb_tensorboard(data, step=None, summary_writer=None): + # Logs data to SwanLab, wandb, and TensorBoard. + + global _LATEST_LOG_STEP + if step is None: + step = _LATEST_LOG_STEP + else: + _LATEST_LOG_STEP = max(_LATEST_LOG_STEP, step) + + # swanlab + import swanlab + + swanlab.log(data, step=step) + + # wandb import wandb - global _LATEST_WANDB_STEP - if step is None: - step = _LATEST_WANDB_STEP - else: - _LATEST_WANDB_STEP = max(_LATEST_WANDB_STEP, step) - wandb.log(data, step=step) + + wandb.log(data, step=step) + + # tensorboard if summary_writer is not None: for key, val in data.items(): summary_writer.add_scalar(f"{key}", val, step) diff --git a/realhf/experiments/async_exp/async_rl_exp.py b/realhf/experiments/async_exp/async_rl_exp.py index b7cc61a..3c23f13 100755 --- a/realhf/experiments/async_exp/async_rl_exp.py +++ b/realhf/experiments/async_exp/async_rl_exp.py @@ -341,6 +341,7 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): return ExperimentConfig( exp_ctrl=self.exp_ctrl, wandb=self.wandb, + swanlab=self.swanlab, tensorboard=self.tensorboard, # NOTE: master and model worker only see RPCs without generation model_rpcs=[ diff --git a/realhf/experiments/common/common.py b/realhf/experiments/common/common.py index dfbc860..370f16f 100644 --- a/realhf/experiments/common/common.py +++ b/realhf/experiments/common/common.py @@ -567,6 +567,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment): return ExperimentConfig( exp_ctrl=self.exp_ctrl, wandb=self.wandb, + swanlab=self.swanlab, tensorboard=self.tensorboard, model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs], model_worker=model_worker, diff --git a/realhf/experiments/common/ppo_math_exp.py b/realhf/experiments/common/ppo_math_exp.py index 01d3070..48fff91 100644 --- a/realhf/experiments/common/ppo_math_exp.py +++ b/realhf/experiments/common/ppo_math_exp.py @@ -373,6 +373,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions): return ExperimentConfig( exp_ctrl=self.exp_ctrl, wandb=self.wandb, + swanlab=self.swanlab, tensorboard=self.tensorboard, model_rpcs=[rpc_alloc.rpc
for rpc_alloc in rpc_allocs], model_worker=model_worker, diff --git a/realhf/scheduler/evaluator.py b/realhf/scheduler/evaluator.py index 439a41c..597a75b 100644 --- a/realhf/scheduler/evaluator.py +++ b/realhf/scheduler/evaluator.py @@ -8,6 +8,7 @@ import subprocess import time from typing import Any, Dict, Optional +import swanlab import wandb import realhf.api.core.system_api as config_pkg @@ -126,13 +127,15 @@ class EvaluationStep: self.status = EvaluationStepStatus.FAILED return False - wandb_data = {} + log_data = {} for data_name, d in data.items(): for k, v in d.items(): - wandb_data[f"{data_name}_{k}"] = v - wandb.log(wandb_data, step=self.global_step) + log_data[f"{data_name}_{k}"] = v + wandb.log(log_data, step=self.global_step) + swanlab.log(log_data, step=self.global_step) self.status = EvaluationStepStatus.LOGGED - logger.info(f"Logging eval result {wandb_data} to step {self.global_step}") + logger.info(f"Logging eval result {log_data} to step {self.global_step}") + return True def check(self): @@ -156,14 +159,16 @@ class AutomaticEvaluator: args: BaseExperimentConfig, config: config_pkg.AutomaticEvaluator, wandb_config: config_pkg.WandBConfig, + swanlab_config: config_pkg.SwanlabConfig, ): self.args = args self.__eval_steps: Dict[int, EvaluationStep] = {} self.__max_concurrent_jobs = config.max_concurrent_jobs self.__wandb_config = wandb_config + self.__swanlab_config = swanlab_config self.__config = config self.__wandb_initialized = False - + self.__swanlab_initialized = False # Check evaluated checkpoints by logs in recover # NOTE: All previous evaluation steps with output will be marked # as logged, even if it is not really logged in wandb. @@ -228,6 +233,40 @@ class AutomaticEvaluator: settings=wandb.Settings(start_method="fork"), ) + def __lazy_swanlab_init(self): + if self.__swanlab_config.api_key: + swanlab.login(self.__swanlab_config.api_key) + if self.__swanlab_config.config is None: + import yaml + + with open( + os.path.join( + constants.LOG_ROOT, + constants.experiment_name(), + constants.trial_name(), + "config.yaml", + ), + "r", + ) as f: + __config = yaml.safe_load(f) + else: + __config = self.__swanlab_config.config + __config["FRAMEWORK"] = "AReaL" + swanlab.init( + project=self.__swanlab_config.project or constants.experiment_name(), + experiment_name=self.__swanlab_config.name + or f"{constants.trial_name()}_eval", + config=__config, + logdir=self.__swanlab_config.logdir + or os.path.join( + constants.LOG_ROOT, + constants.experiment_name(), + constants.trial_name(), + "swanlab", + ), + mode=self.__swanlab_config.mode, + ) + def step(self): # Check whether a new evaluation step should be created ckpt_parent = os.path.join( @@ -290,6 +329,9 @@ class AutomaticEvaluator: if not self.__wandb_initialized: self.__lazy_wandb_init() self.__wandb_initialized = True + if not self.__swanlab_initialized: + self.__lazy_swanlab_init() + self.__swanlab_initialized = True self.__eval_steps[log_step].log(self.__config) @property diff --git a/realhf/system/master_worker.py b/realhf/system/master_worker.py index 80ab519..cabed4c 100644 --- a/realhf/system/master_worker.py +++ b/realhf/system/master_worker.py @@ -12,6 +12,7 @@ from typing import Dict import colorama import networkx as nx import numpy as np +import swanlab import wandb from tensorboardX import SummaryWriter @@ -303,6 +304,40 @@ class MasterWorker(worker_base.AsyncWorker): resume="allow", settings=wandb.Settings(start_method="fork"), ) + + # swanlab init, connect to remote or local swanlab host + if
self.swanlab_config.mode != "disabled" and self.swanlab_config.api_key: + swanlab.login(self.swanlab_config.api_key) + if self.swanlab_config.config is None: + import yaml + + with open( + os.path.join( + constants.LOG_ROOT, + constants.experiment_name(), + constants.trial_name(), + "config.yaml", + ), + "r", + ) as f: + __config = yaml.safe_load(f) + else: + __config = self.swanlab_config.config + __config["FRAMEWORK"] = "AReaL" + swanlab.init( + project=self.swanlab_config.project or constants.experiment_name(), + experiment_name=self.swanlab_config.name + or f"{constants.trial_name()}_train", + config=__config, + logdir=self.swanlab_config.logdir + or os.path.join( + constants.LOG_ROOT, + constants.experiment_name(), + constants.trial_name(), + "swanlab", + ), + mode=self.swanlab_config.mode, + ) # tensorboard logging self.__summary_writer = None if self.tensorboard_config.path is not None: @@ -479,7 +514,7 @@ class MasterWorker(worker_base.AsyncWorker): s += f"(global step {global_step}) finishes. " s += f"#End to end# execution time: *{e2e_time:.3f}*s. " s += f"Total time consumption: {time_since_configure:.3f}s. " - logging.log_wandb_tensorboard({"timeperf/e2e": e2e_time}) + logging.log_swanlab_wandb_tensorboard({"timeperf/e2e": e2e_time}) if len(self.e2e_time_history) > 2: remaining_steps = self._steps_per_epoch - epoch_step remaining_epochs = self.__total_train_epochs - epoch @@ -532,6 +567,7 @@ class MasterWorker(worker_base.AsyncWorker): ) wandb.finish() + swanlab.finish() if self.__summary_writer is not None: self.__summary_writer.close() gc.collect() diff --git a/realhf/system/model_function_call.py b/realhf/system/model_function_call.py index 86e21f7..87fb6e4 100644 --- a/realhf/system/model_function_call.py +++ b/realhf/system/model_function_call.py @@ -10,6 +10,7 @@ import uuid from collections import defaultdict from typing import Dict, Hashable, List, Set, Tuple +import swanlab import wandb from tensorboardX import SummaryWriter @@ -439,7 +440,7 @@ class ModelFunctionCall: logger.info( f"RPC name {rpc.name} returns\n{data_api.tabulate_stats(res)}" ) - logging.log_wandb_tensorboard( + logging.log_swanlab_wandb_tensorboard( res, step=ctrl.step_info.global_step, summary_writer=self.summary_writer, @@ -450,7 +451,7 @@ class ModelFunctionCall: f"RPC name {rpc.name} returns ({j + 1}/{len(res)})\n{data_api.tabulate_stats(r)}" ) offset = len(res) * ctrl.step_info.global_step - logging.log_wandb_tensorboard( + logging.log_swanlab_wandb_tensorboard( r, step=offset + j, summary_writer=self.summary_writer, @@ -462,11 +463,10 @@ class ModelFunctionCall: for time_record in time_records: stats_tracker.scalar(**time_record) time_stats = stats_tracker.export() - logging.log_wandb_tensorboard( + logging.log_swanlab_wandb_tensorboard( time_stats, summary_writer=self.summary_writer, ) - logger.info( f"Model rpc {rpc.name} finished. " f"Request-reply time {time.perf_counter() - tik:.4f}s. 
" diff --git a/realhf/system/worker_base.py b/realhf/system/worker_base.py index a7a96f1..de41aa3 100644 --- a/realhf/system/worker_base.py +++ b/realhf/system/worker_base.py @@ -581,7 +581,9 @@ class Worker: ) expr_config.lazy_init() self.wandb_config = expr_config.wandb + self.swanlab_config = expr_config.swanlab os.environ["WANDB_MODE"] = self.wandb_config.mode + os.environ["SWANLAB_MODE"] = self.swanlab_config.mode self.tensorboard_config = expr_config.tensorboard config = expr_config.resolve_worker_config( self.__worker_type, self.__worker_index diff --git a/requirements.txt b/requirements.txt index 70f6a73..bf7552b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -70,4 +70,5 @@ Pebble timeout-decorator prettytable gymnasium>=1.1.1 -torchdata \ No newline at end of file +torchdata +swanlab[dashboard] diff --git a/training/configs/async-ppo.yaml b/training/configs/async-ppo.yaml index bb4cd6d..7c6e609 100644 --- a/training/configs/async-ppo.yaml +++ b/training/configs/async-ppo.yaml @@ -14,6 +14,13 @@ wandb: notes: null tags: null config: null +swanlab: + mode: disabled + api_key: null + project: null + name: null + config: null + logdir: null tensorboard: path: null recover_mode: auto diff --git a/training/configs/sft.yaml b/training/configs/sft.yaml index 822369b..109ce97 100644 --- a/training/configs/sft.yaml +++ b/training/configs/sft.yaml @@ -14,6 +14,13 @@ wandb: notes: null tags: null config: null +swanlab: + mode: disabled + api_key: null + project: null + name: null + config: null + logdir: null tensorboard: path: null recover_mode: auto diff --git a/training/configs/sync-ppo.yaml b/training/configs/sync-ppo.yaml index 88ae35f..cef7523 100644 --- a/training/configs/sync-ppo.yaml +++ b/training/configs/sync-ppo.yaml @@ -14,6 +14,13 @@ wandb: notes: null tags: null config: null +swanlab: + mode: disabled + api_key: null + project: null + name: null + config: null + logdir: null tensorboard: path: null recover_mode: auto diff --git a/training/utils.py b/training/utils.py index 4d5a8e1..e5d409b 100644 --- a/training/utils.py +++ b/training/utils.py @@ -93,6 +93,7 @@ class RayWorker: worker_info.experiment_name, worker_info.trial_name ) self.worker.wandb_config = expr_config.wandb + self.worker.swanlab_config = expr_config.swanlab self.worker.tensorboard_config = expr_config.tensorboard self.worker.args = self.args self.logger = logging.getLogger(f"{self.worker_type} {idx}", "benchmark") @@ -130,6 +131,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name): env_vars = constants.get_env_vars( exp_cfg, WADNB_MODE=exp_cfg.wandb.mode, + SWANLAB_MODE=exp_cfg.swanlab.mode, REAL_MODE="ray", REAL_RECOVER_RUN="0", REAL_SAVE_RECOVER_STATES="1",