mirror of https://github.com/inclusionAI/AReaL
29 lines
753 B
Python
29 lines
753 B
Python
# Copyright 2025 Ant Group Inc.
|
|
# Copyright 2024 Wei Fu & Zhiyu Mei
|
|
# Licensed under the Apache License, Version 2.0 (the "License").
|
|
|
|
from realhf.api.core.model_api import ReaLModelConfig
|
|
|
|
|
|
def find_factors(n):
    """Return all positive divisors of ``n`` in ascending order.

    Divisors come in pairs ``(i, n // i)`` with ``i <= isqrt(n)``, so only
    O(sqrt(n)) candidates are tested instead of every integer up to ``n``.

    Args:
        n: The integer whose divisors are wanted.

    Returns:
        A list of the divisors of ``n`` sorted ascending; empty when ``n < 1``.
    """
    import math

    if n < 1:
        # Same result as the naive ``range(1, n + 1)`` scan, which is empty here.
        return []
    small, large = [], []
    for i in range(1, math.isqrt(n) + 1):
        if n % i == 0:
            small.append(i)
            partner = n // i
            # Skip the duplicate when n is a perfect square (i == n // i).
            if partner != i:
                large.append(partner)
    # ``large`` was collected in descending order; reverse to keep the result sorted.
    return small + large[::-1]
|
|
|
|
|
|
def make_stats_key(rpc_name, bs, seq_len):
    """Build a single string key from an RPC name, batch size and sequence length.

    Each field is rendered with ``format()`` and the three results are joined
    with ``"|"``, giving ``"<rpc_name>|<bs>|<seq_len>"``.
    """
    fields = (rpc_name, bs, seq_len)
    return "|".join(format(field) for field in fields)
|
|
|
|
|
|
def parse_stats_key(key):
    """Split a ``"<rpc_name>|<bs>|<seq_len>"`` key back into its fields.

    The split is taken from the right, and only the last two fields are
    numeric. An RPC name that itself contains ``"|"`` is therefore still
    recovered intact rather than causing an unpacking error.

    Args:
        key: String of the form ``"<rpc_name>|<bs>|<seq_len>"``.

    Returns:
        Tuple ``(rpc_name, bs, seq_len)`` with ``bs`` and ``seq_len`` as ints.

    Raises:
        ValueError: If ``key`` has fewer than three ``"|"``-separated fields,
            or if ``bs``/``seq_len`` are not valid integers.
    """
    parts = key.rsplit("|", 2)
    if len(parts) != 3:
        raise ValueError(
            f"Malformed stats key {key!r}: expected 'rpc_name|bs|seq_len'."
        )
    rpc_name, bs, seq_len = parts
    return rpc_name, int(bs), int(seq_len)
|
|
|
|
|
|
def load_model_config(model_class: str, model_path: str) -> ReaLModelConfig:
    """Build a ``ReaLModelConfig`` by dispatching to a per-family converter.

    Looks up the attribute ``config_from_<model_class>`` on ``ReaLModel`` and
    calls it with ``model_path`` as a keyword argument. An unknown
    ``model_class`` surfaces as the ``AttributeError`` raised by ``getattr``.

    Args:
        model_class: Suffix selecting the converter, e.g. the part after
            ``config_from_``.
        model_path: Forwarded unchanged to the converter as ``model_path=``.

    Returns:
        Whatever the selected converter returns (annotated as ``ReaLModelConfig``).
    """
    # NOTE(review): imported inside the function, presumably to avoid an
    # import cycle or a heavy import at module load time — confirm.
    from realhf.impl.model.nn.real_llm_api import ReaLModel

    converter_name = f"config_from_{model_class}"
    converter = getattr(ReaLModel, converter_name)
    return converter(model_path=model_path)
|