#!/usr/bin/env python3
"""
Installation Validation Script for AReaL

This script validates that all critical dependencies are properly installed
and can be imported successfully. It's designed to be run in CI to validate
the installation procedure described in docs/tutorial/installation.md.
"""
# Mirror of https://github.com/inclusionAI/AReaL
|
import importlib
import sys
import traceback
import warnings
from importlib.metadata import version as get_version
from typing import Any, Callable, Dict, List, Optional

from packaging.version import Version
|
class InstallationValidator:
    """Validates the AReaL installation by importing critical dependencies.

    Every import attempt is recorded in ``self.results``; failure messages
    are collected in ``self.critical_failures`` (required modules) or
    ``self.warnings`` (optional modules). ``run_validation`` runs the full
    suite, prints a summary, and returns overall success.
    """

    def __init__(self):
        # Maps module name -> {"status": "SUCCESS" | "FAILED" | "ERROR",
        #                      "error": str | None}
        self.results: Dict[str, Dict[str, Optional[str]]] = {}
        self.critical_failures: List[str] = []  # messages for required modules
        self.warnings: List[str] = []  # messages for optional modules

    def test_import(self, module_name: str, required: bool = True,
                    test_func: Optional[Callable[[Any], None]] = None) -> bool:
        """Test importing a module and optionally run additional tests.

        Args:
            module_name: Dotted import path of the module to check.
            required: Whether a failure counts as critical (True) or is
                merely recorded as a warning (False).
            test_func: Optional callable invoked with the imported module to
                perform deeper functionality checks; any exception it raises
                marks the module as failed.

        Returns:
            True if the import (and optional functionality check) succeeded.
        """
        try:
            module = importlib.import_module(module_name)
            # Run additional test if provided
            if test_func:
                test_func(module)
        except Exception as e:
            # An ImportError means the package is absent ("FAILED"); any
            # other exception means it is installed but broken ("ERROR").
            missing = isinstance(e, ImportError)
            self.results[module_name] = {
                "status": "FAILED" if missing else "ERROR",
                "error": str(e),
            }
            suffix = "" if missing else " ERROR"
            if required:
                self.critical_failures.append(f"{module_name}: {e}")
                print(f"✗ {module_name} (CRITICAL{suffix}): {e}")
            else:
                self.warnings.append(f"{module_name}: {e}")
                print(f"⚠ {module_name} (OPTIONAL{suffix}): {e}")
            return False

        self.results[module_name] = {"status": "SUCCESS", "error": None}
        print(f"✓ {module_name}")
        return True

    def test_torch_cuda(self, torch_module):
        """Test PyTorch CUDA availability.

        Raises:
            RuntimeError: If CUDA is not available in this PyTorch build.
        """
        if not torch_module.cuda.is_available():
            raise RuntimeError("CUDA is not available in PyTorch")
        print(f"  - CUDA devices: {torch_module.cuda.device_count()}")
        print(f"  - CUDA version: {torch_module.version.cuda}")

    def test_flash_attn_functionality(self, flash_attn_module):
        """Test flash attention functionality.

        The imports themselves are the test: they fail if the compiled
        CUDA extension or the public API functions are missing.
        """
        import flash_attn_2_cuda  # noqa: F401
        from flash_attn import (  # noqa: F401
            flash_attn_func,
            flash_attn_varlen_func,
        )
        print("  - Flash attention functions imported successfully")

    def test_vllm_functionality(self, vllm_module):
        """Test vLLM basic functionality by importing its core classes."""
        from vllm import LLM, SamplingParams  # noqa: F401
        print("  - vLLM core classes imported successfully")

    def test_sglang_functionality(self, sglang_module):
        """Test SGLang basic functionality and the pinned version."""
        # Basic import test is sufficient for CI
        import sgl_kernel  # noqa: F401
        from sglang import Engine, launch_server  # noqa: F401
        assert Version(get_version("sglang")) == Version("0.4.9.post2"), \
            "SGLang version should be v0.4.9.post2"
        print("  - SGLang imported successfully")

    def test_transformers(self, transformers_module):
        """Check that the pinned transformers version is installed."""
        assert Version(get_version("transformers")) == Version("4.53.1"), \
            "transformers version should be 4.53.1"
        print("  - transformers imported successfully")

    def validate_critical_dependencies(self):
        """Validate critical dependencies that must be present."""
        print("\n=== Testing Critical Dependencies ===")

        # Core ML frameworks
        self.test_import("torch", required=True, test_func=self.test_torch_cuda)
        self.test_import("transformers", required=True, test_func=self.test_transformers)

        # Flash attention - critical for performance
        self.test_import("flash_attn", required=True, test_func=self.test_flash_attn_functionality)
        self.test_import("cugae", required=True)

        # Inference engines
        self.test_import("sglang", required=True, test_func=self.test_sglang_functionality)

        # Distributed computing
        self.test_import("ray", required=True)

        # Scientific computing
        self.test_import("numpy", required=True)
        self.test_import("scipy", required=True)

        # Configuration management
        self.test_import("hydra", required=True)
        self.test_import("omegaconf", required=True)

        # Data processing
        self.test_import("datasets", required=True)
        self.test_import("pandas", required=True)
        self.test_import("einops", required=True)

        # Monitoring and logging
        self.test_import("wandb", required=True)
        self.test_import("pynvml", required=True)

        # Networking
        self.test_import("aiohttp", required=True)
        self.test_import("fastapi", required=True)
        self.test_import("uvicorn", required=True)

        # Math libraries (for evaluation)
        self.test_import("sympy", required=True)
        self.test_import("latex2sympy2", required=True)

    def validate_optional_dependencies(self):
        """Validate optional dependencies (failures are warnings only)."""
        print("\n=== Testing Optional Dependencies ===")

        # CUDA extensions (may not be available in all environments)
        self.test_import("vllm", required=False, test_func=self.test_vllm_functionality)
        self.test_import("grouped_gemm", required=False)
        self.test_import("flashattn_hopper", required=False)

        # Optional utilities
        self.test_import("tensorboardX", required=False)
        self.test_import("swanlab", required=False)
        self.test_import("matplotlib", required=False)
        self.test_import("seaborn", required=False)
        self.test_import("numba", required=False)
        self.test_import("nltk", required=False)

    def validate_cuda_extensions(self):
        """Validate CUDA-specific functionality with small smoke tests.

        Results are printed only; failures here do not affect the
        pass/fail outcome of the validation run.
        """
        print("\n=== Testing CUDA Extensions ===")

        try:
            import torch
            if torch.cuda.is_available():
                # Test basic CUDA tensor operations
                device = torch.device("cuda:0")
                x = torch.randn(10, device=device)
                y = torch.randn(10, device=device)
                _ = x + y  # exercise a kernel launch
                print("✓ Basic CUDA operations working")

                # Test flash attention if available
                try:
                    from flash_attn import flash_attn_func

                    # Create small test tensors (fp16 is required by flash-attn)
                    batch_size, seq_len, num_heads, head_dim = 1, 32, 4, 64
                    q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                                    device=device, dtype=torch.float16)
                    k = torch.randn(batch_size, seq_len, num_heads, head_dim,
                                    device=device, dtype=torch.float16)
                    v = torch.randn(batch_size, seq_len, num_heads, head_dim,
                                    device=device, dtype=torch.float16)

                    # Test flash attention call
                    _ = flash_attn_func(q, k, v)
                    print("✓ Flash attention CUDA operations working")

                except Exception as e:
                    print(f"⚠ Flash attention CUDA test failed: {e}")

            else:
                print("⚠ CUDA not available - skipping CUDA extension tests")

        except Exception as e:
            print(f"✗ CUDA extension validation failed: {e}")

    def run_validation(self) -> bool:
        """Run the complete validation suite and print a summary.

        Returns:
            True when no critical failures occurred; optional-dependency
            warnings alone do not fail the run.
        """
        print("AReaL Installation Validation")
        print("=" * 50)

        self.validate_critical_dependencies()
        self.validate_optional_dependencies()
        self.validate_cuda_extensions()

        # Print summary
        print("\n" + "=" * 50)
        print("VALIDATION SUMMARY")
        print("=" * 50)

        total_tests = len(self.results)
        successful_tests = sum(1 for r in self.results.values() if r["status"] == "SUCCESS")
        failed_tests = total_tests - successful_tests

        print(f"Total tests: {total_tests}")
        print(f"Successful: {successful_tests}")
        print(f"Failed: {failed_tests}")

        if self.critical_failures:
            print(f"\n🚨 CRITICAL FAILURES ({len(self.critical_failures)}):")
            for failure in self.critical_failures:
                print(f"  - {failure}")

        if self.warnings:
            print(f"\n⚠️  WARNINGS ({len(self.warnings)}):")
            for warning in self.warnings:
                print(f"  - {warning}")

        # Determine overall result
        if self.critical_failures:
            print("\n❌ INSTALLATION VALIDATION FAILED")
            print("Please check the critical failures above and ensure all required")
            print("dependencies are properly installed according to the installation guide.")
            return False
        else:
            print("\n✅ INSTALLATION VALIDATION PASSED")
            if self.warnings:
                print("Note: Some optional dependencies failed but this won't affect")
                print("core functionality.")
            return True
|
def main():
    """Entry point: run the validation suite and exit accordingly.

    Exits with status 0 when validation passes, 1 otherwise.
    """
    passed = InstallationValidator().run_validation()
    sys.exit(0 if passed else 1)


if __name__ == "__main__":
    main()