Merge branch 'main' of https://github.com/inclusionAI/AReaL into fw/refactor

garrett4wade 2025-06-23 20:14:52 +08:00
commit 7f1397e37b
23 changed files with 533 additions and 24 deletions


@ -0,0 +1,128 @@
name: Installation Validation

on:
  push:
    branches: [ none ]
    paths:
      - 'examples/env/scripts/setup-pip-deps.sh'
      - 'docs/tutorial/installation.md'
      - 'examples/env/validate_installation.py'
      - 'setup.py'
      - 'requirements*.txt'
      - '.github/workflows/installation-validation.yml'
  pull_request:
    branches: [ none ]
    paths:
      - 'examples/env/scripts/setup-pip-deps.sh'
      - 'docs/tutorial/installation.md'
      - 'examples/env/validate_installation.py'
      - 'setup.py'
      - 'requirements*.txt'
      - '.github/workflows/installation-validation.yml'
  workflow_dispatch:

jobs:
  validate-installation:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          lfs: true

      - name: Set up SSH key
        run: |
          mkdir -p ~/.ssh
          echo "${{ secrets.REMOTE_SSH_KEY }}" > ~/.ssh/id_rsa
          chmod 600 ~/.ssh/id_rsa
          ssh-keyscan -p 8107 101.6.96.205 >> ~/.ssh/known_hosts

      - name: Synchronize repository to remote machine
        run: |
          # Use rsync to synchronize repository to remote machine
          rsync -avz --delete \
            --exclude='.git' \
            --exclude='__pycache__' \
            --exclude='*.pyc' \
            --exclude='*.pyo' \
            --exclude='*.egg-info' \
            --exclude='build/' \
            --exclude='dist/' \
            --exclude='.pytest_cache' \
            --exclude='.coverage' \
            --exclude='*.so' \
            --exclude='*.dylib' \
            --exclude='node_modules/' \
            --exclude='.env' \
            --exclude='.venv' \
            -e 'ssh -p 8107' . fuwei@101.6.96.205:/tmp/areal-validation/

      - name: Run installation validation on remote machine
        run: |
          ssh -p 8107 fuwei@101.6.96.205 << 'EOF'
          set -e
          # Navigate to the synchronized repository
          cd /tmp/areal-validation
          # Create persistent pip cache directory
          mkdir -p /tmp/pip-cache
          # Generate a unique container name
          CONTAINER_NAME="areal-validation-$(date +%s)"
          # Stop and remove any existing container with the same name
          docker stop $CONTAINER_NAME 2>/dev/null || true
          docker rm $CONTAINER_NAME 2>/dev/null || true

          echo "=== Starting Docker container ==="
          # Launch Docker container with NVIDIA PyTorch image
          docker run -d \
            --name $CONTAINER_NAME \
            --gpus all \
            --shm-size=8g \
            -v $(pwd):/workspace \
            -v /tmp/pip-cache:/root/.cache/pip \
            -w /workspace \
            nvcr.io/nvidia/pytorch:25.01-py3 \
            sleep infinity

          echo "=== Verifying CUDA environment in container ==="
          docker exec $CONTAINER_NAME nvidia-smi
          docker exec $CONTAINER_NAME nvcc --version

          echo "=== Verifying workspace contents ==="
          docker exec $CONTAINER_NAME pwd
          docker exec $CONTAINER_NAME ls -la /workspace
          docker exec $CONTAINER_NAME ls -la /workspace/examples/env/ || echo "examples/env directory not found"

          echo "=== Checking pip cache before installation ==="
          du -sh /tmp/pip-cache 2>/dev/null || echo "Cache directory empty"

          echo "=== Installing dependencies ==="
          docker exec $CONTAINER_NAME bash -c "
            python -m pip install --upgrade pip
            pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
            pip config unset global.extra-index-url
            # Run the installation script
            bash examples/env/scripts/setup-pip-deps.sh
            python examples/env/validate_installation.py
          "

          echo "=== Checking pip cache after installation ==="
          du -sh /tmp/pip-cache 2>/dev/null || echo "Cache directory still empty"

          echo "=== Installation validation completed successfully ==="

          # Cleanup
          docker stop $CONTAINER_NAME
          docker rm $CONTAINER_NAME
          cd ~
          rm -rf /tmp/areal-validation
          EOF

      - name: Cleanup SSH key
        if: always()
        run: |
          rm -f ~/.ssh/id_rsa

Binary image file changed (not shown): 253 KiB before, 162 KiB after.


@ -97,12 +97,15 @@ python3 training/main_sync_ppo.py --help
## Monitoring the Training Process
We recommend using Weights & Biases (wandb) for monitoring. Run `wandb login` or set the `WANDB_API_KEY` environment variable. Set `wandb.mode=online` in your configuration to upload training statistics.
+ We recommend using [Weights & Biases (wandb)](https://github.com/wandb/wandb) or [SwanLab](https://github.com/SwanHubX/SwanLab) for monitoring. Run `wandb login` or `swanlab login`, or set the corresponding API key environment variable (`WANDB_API_KEY` or `SWANLAB_API_KEY`). Set `wandb.mode="online"` or `swanlab.mode="cloud"` in your configuration to upload training statistics. If you cannot reach the server, you can also use `wandb.mode="offline"` or `swanlab.mode="local"` to save data locally without uploading.
You can also use TensorBoard by setting the `tensorboard.path` parameter.
The main log will be saved to `${fileroot}/logs/${USER}/${experiment_name}/${trial_name}/main.log` and contains the statistics uploaded to wandb.
If SwanLab is enabled, logs will be saved to the directory specified by `swanlab.logdir`.
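As a quick sanity check before enabling cloud uploads, the sketch below exercises the same SwanLab calls AReaL's workers use (`swanlab.init`, `swanlab.log`, `swanlab.finish`). It is not part of this commit; the project, experiment, and metric names are placeholders.

```python
# Minimal sketch: verify SwanLab logging works locally before setting
# swanlab.mode="cloud" in the experiment configuration.
import swanlab

swanlab.init(
    project="areal-smoke-test",       # hypothetical project name
    experiment_name="monitor-check",  # hypothetical experiment name
    mode="local",                     # "cloud" uploads; "local" keeps logs on disk
)
swanlab.log({"reward/mean": 0.42}, step=1)
swanlab.finish()
```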
### Key Training Statistics
- **`Epoch 1/5`**: Indicates the total epochs required and the current epoch being trained.


@ -15,3 +15,4 @@ prettytable
timeout-decorator
timeout_decorator
wandb
swanlab[dashboard]


@ -1,11 +1,12 @@
#!/bin/bash
# basic dependencies
pip install -U pip
pip uninstall deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip uninstall torch deepspeed flash-attn pynvml cugraph-dgl dask-cuda cugraph-service-server raft-dask cugraph cuml cugraph-pyg -y
pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0
pip install "sglang[all]==0.4.6.post4"
pip install megatron-core==0.11.0 nvidia-ml-py
pip install git+https://github.com/garrett4wade/cugae --no-build-isolation --verbose
pip install flash-attn --no-build-isolation
pip install "flash-attn<=2.7.3" --no-build-isolation
# Package used for calculating math reward
pip install -e evaluation/latex2sympy

examples/env/validate_installation.py (vendored, new file, 242 lines)

@ -0,0 +1,242 @@
#!/usr/bin/env python3
"""
Installation Validation Script for AReaL
This script validates that all critical dependencies are properly installed
and can be imported successfully. It's designed to be run in CI to validate
the installation procedure described in docs/tutorial/installation.md.
"""
import importlib
import sys
import traceback
import warnings
from importlib.metadata import version as get_version
from typing import Any, Dict, List, Optional
from packaging.version import Version
class InstallationValidator:
    def __init__(self):
        self.results = {}
        self.critical_failures = []
        self.warnings = []

    def test_import(self, module_name: str, required: bool = True,
                    test_func: Optional[callable] = None) -> bool:
        """Test importing a module and optionally run additional tests."""
        try:
            module = importlib.import_module(module_name)
            # Run additional test if provided
            if test_func:
                test_func(module)
            self.results[module_name] = {"status": "SUCCESS", "error": None}
            print(f"{module_name}")
            return True
        except ImportError as e:
            self.results[module_name] = {"status": "FAILED", "error": str(e)}
            if required:
                self.critical_failures.append(f"{module_name}: {str(e)}")
                print(f"{module_name} (CRITICAL): {str(e)}")
            else:
                self.warnings.append(f"{module_name}: {str(e)}")
                print(f"{module_name} (OPTIONAL): {str(e)}")
            return False
        except Exception as e:
            self.results[module_name] = {"status": "ERROR", "error": str(e)}
            if required:
                self.critical_failures.append(f"{module_name}: {str(e)}")
                print(f"{module_name} (CRITICAL ERROR): {str(e)}")
            else:
                self.warnings.append(f"{module_name}: {str(e)}")
                print(f"{module_name} (OPTIONAL ERROR): {str(e)}")
            return False

    def test_torch_cuda(self, torch_module):
        """Test PyTorch CUDA availability."""
        if not torch_module.cuda.is_available():
            raise RuntimeError("CUDA is not available in PyTorch")
        print(f" - CUDA devices: {torch_module.cuda.device_count()}")
        print(f" - CUDA version: {torch_module.version.cuda}")

    def test_flash_attn_functionality(self, flash_attn_module):
        """Test flash attention functionality."""
        # Try to import key functions
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        print(" - Flash attention functions imported successfully")

    def test_vllm_functionality(self, vllm_module):
        """Test vLLM basic functionality."""
        from vllm import LLM, SamplingParams
        print(" - vLLM core classes imported successfully")

    def test_sglang_functionality(self, sglang_module):
        """Test SGLang basic functionality."""
        # Basic import test is sufficient for CI
        import sgl_kernel
        from sglang import launch_server
        assert Version(get_version("sglang")) == Version("0.4.6.post4")
        print(" - SGLang imported successfully")

    def test_transformers(self, transformers_module):
        assert Version(get_version("transformers")) == Version("4.51.1")
        print(" - transformers imported successfully")

    def validate_critical_dependencies(self):
        """Validate critical dependencies that must be present."""
        print("\n=== Testing Critical Dependencies ===")
        # Core ML frameworks
        self.test_import("torch", required=True, test_func=self.test_torch_cuda)
        self.test_import("transformers", required=True, test_func=self.test_transformers)
        # Flash attention - critical for performance
        self.test_import("flash_attn", required=True, test_func=self.test_flash_attn_functionality)
        self.test_import("cugae", required=True)
        # Inference engines
        self.test_import("sglang", required=True, test_func=self.test_sglang_functionality)
        # Distributed computing
        self.test_import("ray", required=True)
        # Scientific computing
        self.test_import("numpy", required=True)
        self.test_import("scipy", required=True)
        # Configuration management
        self.test_import("hydra", required=True)
        self.test_import("omegaconf", required=True)
        # Data processing
        self.test_import("datasets", required=True)
        self.test_import("pandas", required=True)
        self.test_import("einops", required=True)
        # Monitoring and logging
        self.test_import("wandb", required=True)
        self.test_import("pynvml", required=True)
        # Networking
        self.test_import("aiohttp", required=True)
        self.test_import("fastapi", required=True)
        self.test_import("uvicorn", required=True)
        # Math libraries (for evaluation)
        self.test_import("sympy", required=True)
        self.test_import("latex2sympy2", required=True)

    def validate_optional_dependencies(self):
        """Validate optional dependencies."""
        print("\n=== Testing Optional Dependencies ===")
        # CUDA extensions (may not be available in all environments)
        self.test_import("vllm", required=False, test_func=self.test_vllm_functionality)
        self.test_import("grouped_gemm", required=False)
        self.test_import("flashattn_hopper", required=False)
        # Optional utilities
        self.test_import("tensorboardx", required=False)
        self.test_import("swanlab", required=False)
        self.test_import("matplotlib", required=False)
        self.test_import("seaborn", required=False)
        self.test_import("numba", required=False)
        self.test_import("nltk", required=False)

    def validate_cuda_extensions(self):
        """Validate CUDA-specific functionality."""
        print("\n=== Testing CUDA Extensions ===")
        try:
            import torch
            if torch.cuda.is_available():
                # Test basic CUDA tensor operations
                device = torch.device("cuda:0")
                x = torch.randn(10, device=device)
                y = torch.randn(10, device=device)
                z = x + y
                print("✓ Basic CUDA operations working")
                # Test flash attention if available
                try:
                    from flash_attn import flash_attn_func
                    # Create small test tensors
                    batch_size, seq_len, num_heads, head_dim = 1, 32, 4, 64
                    q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                                    device=device, dtype=torch.float16)
                    k = torch.randn(batch_size, seq_len, num_heads, head_dim,
                                    device=device, dtype=torch.float16)
                    v = torch.randn(batch_size, seq_len, num_heads, head_dim,
                                    device=device, dtype=torch.float16)
                    # Test flash attention call
                    out = flash_attn_func(q, k, v)
                    print("✓ Flash attention CUDA operations working")
                except Exception as e:
                    print(f"⚠ Flash attention CUDA test failed: {e}")
            else:
                print("⚠ CUDA not available - skipping CUDA extension tests")
        except Exception as e:
            print(f"✗ CUDA extension validation failed: {e}")

    def run_validation(self):
        """Run complete validation suite."""
        print("AReaL Installation Validation")
        print("=" * 50)
        self.validate_critical_dependencies()
        self.validate_optional_dependencies()
        self.validate_cuda_extensions()
        # Print summary
        print("\n" + "=" * 50)
        print("VALIDATION SUMMARY")
        print("=" * 50)
        total_tests = len(self.results)
        successful_tests = sum(1 for r in self.results.values() if r["status"] == "SUCCESS")
        failed_tests = total_tests - successful_tests
        print(f"Total tests: {total_tests}")
        print(f"Successful: {successful_tests}")
        print(f"Failed: {failed_tests}")
        if self.critical_failures:
            print(f"\n🚨 CRITICAL FAILURES ({len(self.critical_failures)}):")
            for failure in self.critical_failures:
                print(f" - {failure}")
        if self.warnings:
            print(f"\n⚠️ WARNINGS ({len(self.warnings)}):")
            for warning in self.warnings:
                print(f" - {warning}")
        # Determine overall result
        if self.critical_failures:
            print(f"\n❌ INSTALLATION VALIDATION FAILED")
            print("Please check the critical failures above and ensure all required")
            print("dependencies are properly installed according to the installation guide.")
            return False
        else:
            print(f"\n✅ INSTALLATION VALIDATION PASSED")
            if self.warnings:
                print("Note: Some optional dependencies failed but this won't affect")
                print("core functionality.")
            return True


def main():
    """Main entry point."""
    validator = InstallationValidator()
    success = validator.run_validation()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
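The script's only contract with CI is its exit code: 0 when every required import succeeds, 1 otherwise. A hedged sketch of invoking it programmatically (the path mirrors how the workflow above calls it; the wrapper itself is illustrative and not part of the commit):

```python
# Illustrative wrapper: run the validator and propagate its exit code,
# which is nonzero if any critical dependency failed to import.
import subprocess
import sys

result = subprocess.run(
    [sys.executable, "examples/env/validate_installation.py"],
    check=False,
)
sys.exit(result.returncode)
```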


@ -63,6 +63,7 @@ dependencies = [
"colorlog",
"psutil",
"pynvml",
"swanlab[dashboard]",
# Performance and compression
"ninja",


@ -855,6 +855,16 @@ class WandBConfig:
config: Optional[Dict] = None
@dataclass
class SwanlabConfig:
project: Optional[str] = None
name: Optional[str] = None
config: Optional[Dict] = None
logdir: Optional[str] = None
mode: Optional[str] = "local"
api_key: Optional[str] = os.getenv("SWANLAB_API_KEY", None)
@dataclass
class TensorBoardConfig:
path: Optional[str] = None
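For reference, a hedged sketch of populating the new dataclass directly; the project and experiment names are placeholders, while the `mode` values and the `logdir` fallback follow the worker code added later in this commit.

```python
# Example only: building a SwanlabConfig by hand mirrors what the YAML
# config files later in this commit express declaratively.
from realhf.api.cli_args import SwanlabConfig  # import path as used in system_api.py below

swanlab_cfg = SwanlabConfig(
    project="my-areal-project",  # hypothetical project name
    name="trial-0-train",        # hypothetical experiment name
    mode="cloud",                # "disabled", "local", or "cloud"
    logdir=None,                 # None -> workers fall back to ${LOG_ROOT}/<exp>/<trial>/swanlab
)
```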
@ -986,6 +996,10 @@ class BaseExperimentConfig:
default_factory=WandBConfig,
metadata={"help": "Weights & Biases configuration."},
)
swanlab: SwanlabConfig = field(
default_factory=SwanlabConfig,
metadata={"help": "SwanLab configuration."},
)
tensorboard: TensorBoardConfig = field(
default_factory=TensorBoardConfig,
metadata={"help": "TensorBoard configuration. Only 'path' field required."},
@ -1061,7 +1075,7 @@ class BaseExperimentConfig:
default=False,
metadata={
"help": "Enable automatic evaluation during training. "
"Results logged to disk and WandB (if active)."
"Results logged to disk and WandB or Swanlab(if active)."
},
)
auto_eval_config: AutomaticEvaluator = field(


@ -11,6 +11,7 @@ import realhf.api.core.dfg as dfg
from realhf.api.cli_args import (
AutomaticEvaluator,
ExperimentSaveEvalControl,
SwanlabConfig,
TensorBoardConfig,
WandBConfig,
)
@ -189,6 +190,7 @@ class ExperimentScheduling:
class ExperimentConfig:
exp_ctrl: ExperimentSaveEvalControl
wandb: WandBConfig
swanlab: SwanlabConfig
tensorboard: TensorBoardConfig
# dataflow
model_rpcs: List[dfg.MFCDef]


@ -94,7 +94,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
raise RuntimeError("Experiment initial setup failed.") from e
evaluator = (
AutomaticEvaluator(exp_cfg, exp_cfg.evaluator, exp_cfg.wandb)
AutomaticEvaluator(exp_cfg, exp_cfg.evaluator, exp_cfg.wandb, exp_cfg.swanlab)
if exp_cfg.auto_eval
else None
)


@ -141,19 +141,29 @@ def getLogger(
return logging.getLogger(name)
_LATEST_WANDB_STEP = 0
_LATEST_LOG_STEP = 0
def log_wandb_tensorboard(data, step=None, summary_writer=None):
def log_swanlab_wandb_tensorboard(data, step=None, summary_writer=None):
# Logs data to SwanLab, wandb, and TensorBoard.
global _LATEST_LOG_STEP
if step is None:
step = _LATEST_LOG_STEP
else:
_LATEST_LOG_STEP = max(_LATEST_LOG_STEP, step)
# swanlab
import swanlab
swanlab.log(data, step=step)
# wandb
import wandb
global _LATEST_WANDB_STEP
if step is None:
step = _LATEST_WANDB_STEP
else:
_LATEST_WANDB_STEP = max(_LATEST_WANDB_STEP, step)
wandb.log(data, step=step)
# tensorboard
if summary_writer is not None:
for key, val in data.items():
summary_writer.add_scalar(f"{key}", val, step)
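For illustration, a sketch of how a worker would use the renamed helper. The module path is assumed (not shown in this diff), and the call presumes `swanlab.init()` and `wandb.init()` have already run in the worker, as in `MasterWorker` below.

```python
# Sketch only: emit one set of scalar metrics to SwanLab, wandb, and
# TensorBoard in a single call.
from tensorboardX import SummaryWriter

from realhf.base import logging  # assumed module path for this helper

writer = SummaryWriter(logdir="/tmp/tb-example")  # hypothetical path; pass None to skip TensorBoard
logging.log_swanlab_wandb_tensorboard(
    {"timeperf/e2e": 12.3},  # scalar metrics, as logged by MasterWorker below
    step=100,                # omit to reuse the latest recorded step
    summary_writer=writer,
)
```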


@ -341,6 +341,7 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
swanlab=self.swanlab,
tensorboard=self.tensorboard,
# NOTE: master and model worker only see RPCs without generation
model_rpcs=[


@ -567,6 +567,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
swanlab=self.swanlab,
tensorboard=self.tensorboard,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,


@ -373,6 +373,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
return ExperimentConfig(
exp_ctrl=self.exp_ctrl,
wandb=self.wandb,
swanlab=self.swanlab,
tensorboard=self.tensorboard,
model_rpcs=[rpc_alloc.rpc for rpc_alloc in rpc_allocs],
model_worker=model_worker,


@ -8,6 +8,7 @@ import subprocess
import time
from typing import Any, Dict, Optional
import swanlab
import wandb
import realhf.api.core.system_api as config_pkg
@ -126,13 +127,15 @@ class EvaluationStep:
self.status = EvaluationStepStatus.FAILED
return False
wandb_data = {}
log_data = {}
for data_name, d in data.items():
for k, v in d.items():
wandb_data[f"{data_name}_{k}"] = v
wandb.log(wandb_data, step=self.global_step)
log_data[f"{data_name}_{k}"] = v
wandb.log(log_data, step=self.global_step)
swanlab.log(log_data, step=self.global_step)
self.status = EvaluationStepStatus.LOGGED
logger.info(f"Logging eval result {wandb_data} to step {self.global_step}")
logger.info(f"Logging eval result {log_data} to step {self.global_step}")
return True
def check(self):
@ -156,14 +159,16 @@ class AutomaticEvaluator:
args: BaseExperimentConfig,
config: config_pkg.AutomaticEvaluator,
wandb_config: config_pkg.WandBConfig,
swanlab_config: config_pkg.SwanlabConfig,
):
self.args = args
self.__eval_steps: Dict[int, EvaluationStep] = {}
self.__max_concurrent_jobs = config.max_concurrent_jobs
self.__wandb_config = wandb_config
self.__swanlab_config = swanlab_config
self.__config = config
self.__wandb_initialized = False
self.__swanlab_initialized = False
# Check evaluated checkpoints by logs in recover
# NOTE: All previous evaluation steps with output will be marked
# as logged, even if it is not really logged in wandb.
@ -228,6 +233,40 @@ class AutomaticEvaluator:
settings=wandb.Settings(start_method="fork"),
)
def __lazy_swanlab_init(self):
if self.__swanlab_config.api_key:
swanlab.login(self.__swanlab_config.api_key)
if self.__swanlab_config.config is None:
import yaml
with open(
os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"config.yaml",
),
"r",
) as f:
__config = yaml.safe_load(f)
else:
__config = self.__swanlab_config.config
__config["FRAMEWORK"] = "AReaL"
swanlab.init(
project=self.__swanlab_config.project or constants.experiment_name(),
experiment_name=self.__swanlab_config.name
or f"{constants.trial_name()}_eval",
config=__config,
logdir=self.__swanlab_config.logdir
or os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"swanlab",
),
mode=self.__swanlab_config.mode,
)
def step(self):
# Check whether a new evaluation step should be created
ckpt_parent = os.path.join(
@ -290,6 +329,9 @@ class AutomaticEvaluator:
if not self.__wandb_initialized:
self.__lazy_wandb_init()
self.__wandb_initialized = True
if not self.__swanlab_initialized:
self.__lazy_swanlab_init()
self.__swanlab_initialized = True
self.__eval_steps[log_step].log(self.__config)
@property


@ -12,6 +12,7 @@ from typing import Dict
import colorama
import networkx as nx
import numpy as np
import swanlab
import wandb
from tensorboardX import SummaryWriter
@ -303,6 +304,40 @@ class MasterWorker(worker_base.AsyncWorker):
resume="allow",
settings=wandb.Settings(start_method="fork"),
)
# swanlab init, connect to remote or local swanlab host
if self.swanlab_config.mode != "disabled" and self.swanlab_config.api_key:
swanlab.login(self.swanlab_config.api_key)
if self.swanlab_config.config is None:
import yaml
with open(
os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"config.yaml",
),
"r",
) as f:
__config = yaml.safe_load(f)
else:
__config = self.swanlab_config.config
__config["FRAMEWORK"] = "AReaL"
swanlab.init(
project=self.swanlab_config.project or constants.experiment_name(),
experiment_name=self.swanlab_config.name
or f"{constants.trial_name()}_train",
config=__config,
logdir=self.swanlab_config.logdir
or os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"swanlab",
),
mode=self.swanlab_config.mode,
)
# tensorboard logging
self.__summary_writer = None
if self.tensorboard_config.path is not None:
@ -479,7 +514,7 @@ class MasterWorker(worker_base.AsyncWorker):
s += f"(global step {global_step}) finishes. "
s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
s += f"Total time consumption: {time_since_configure:.3f}s. "
logging.log_wandb_tensorboard({"timeperf/e2e": e2e_time})
logging.log_swanlab_wandb_tensorboard({"timeperf/e2e": e2e_time})
if len(self.e2e_time_history) > 2:
remaining_steps = self._steps_per_epoch - epoch_step
remaining_epochs = self.__total_train_epochs - epoch
@ -532,6 +567,7 @@ class MasterWorker(worker_base.AsyncWorker):
)
wandb.finish()
swanlab.finish()
if self.__summary_writer is not None:
self.__summary_writer.close()
gc.collect()


@ -10,6 +10,7 @@ import uuid
from collections import defaultdict
from typing import Dict, Hashable, List, Set, Tuple
import swanlab
import wandb
from tensorboardX import SummaryWriter
@ -439,7 +440,7 @@ class ModelFunctionCall:
logger.info(
f"RPC name {rpc.name} returns\n{data_api.tabulate_stats(res)}"
)
logging.log_wandb_tensorboard(
logging.log_swanlab_wandb_tensorboard(
res,
step=ctrl.step_info.global_step,
summary_writer=self.summary_writer,
@ -450,7 +451,7 @@ class ModelFunctionCall:
f"RPC name {rpc.name} returns ({j + 1}/{len(res)})\n{data_api.tabulate_stats(r)}"
)
offset = len(res) * ctrl.step_info.global_step
logging.log_wandb_tensorboard(
logging.log_swanlab_wandb_tensorboard(
r,
step=offset + j,
summary_writer=self.summary_writer,
@ -462,11 +463,10 @@ class ModelFunctionCall:
for time_record in time_records:
stats_tracker.scalar(**time_record)
time_stats = stats_tracker.export()
logging.log_wandb_tensorboard(
logging.log_swanlab_wandb_tensorboard(
time_stats,
summary_writer=self.summary_writer,
)
logger.info(
f"Model rpc {rpc.name} finished. "
f"Request-reply time {time.perf_counter() - tik:.4f}s. "


@ -581,7 +581,9 @@ class Worker:
)
expr_config.lazy_init()
self.wandb_config = expr_config.wandb
self.swanlab_config = expr_config.swanlab
os.environ["WANDB_MODE"] = self.wandb_config.mode
os.environ["SWANLAB_MODE"] = self.swanlab_config.mode
self.tensorboard_config = expr_config.tensorboard
config = expr_config.resolve_worker_config(
self.__worker_type, self.__worker_index


@ -70,4 +70,5 @@ Pebble
timeout-decorator
prettytable
gymnasium>=1.1.1
torchdata
torchdata
swanlab[dashboard]


@ -14,6 +14,13 @@ wandb:
  notes: null
  tags: null
  config: null
swanlab:
  mode: disabled
  api_key: null
  project: null
  name: null
  config: null
  logdir: null
tensorboard:
  path: null
recover_mode: auto


@ -14,6 +14,13 @@ wandb:
  notes: null
  tags: null
  config: null
swanlab:
  mode: disabled
  api_key: null
  project: null
  name: null
  config: null
  logdir: null
tensorboard:
  path: null
recover_mode: auto


@ -14,6 +14,13 @@ wandb:
  notes: null
  tags: null
  config: null
swanlab:
  mode: disabled
  api_key: null
  project: null
  name: null
  config: null
  logdir: null
tensorboard:
  path: null
recover_mode: auto


@ -93,6 +93,7 @@ class RayWorker:
worker_info.experiment_name, worker_info.trial_name
)
self.worker.wandb_config = expr_config.wandb
self.worker.swanlab_config = expr_config.swanlab
self.worker.tensorboard_config = expr_config.tensorboard
self.worker.args = self.args
self.logger = logging.getLogger(f"{self.worker_type} {idx}", "benchmark")
@ -130,6 +131,7 @@ def _run_experiment(exp_cfg, expr_name, trial_name):
env_vars = constants.get_env_vars(
exp_cfg,
WADNB_MODE=exp_cfg.wandb.mode,
SWANLAB_MODE=exp_cfg.swanlab.mode,
REAL_MODE="ray",
REAL_RECOVER_RUN="0",
REAL_SAVE_RECOVER_STATES="1",