Support asynchronous RL training, Qwen3, and the latest SGLang (#47)

* feat: one buffer for each task

* feat: support "one buffer for each task" for async

* make kv_cache_dtype configurable

Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>

* style: use plural form

fix: use _seed_from_key to set different seeds for data loaders
fix: call load_data for one buffer each time

* PullRequest: 125 Support running async experiments in the 2407 image.

Merge branch fw/async2407 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/125

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* fix: handle multiple datasets in recover indices
fix: `isinstance(self.__datasets, PullerStreamDataset)`
feat: use the "spec" request to obtain the number of datasets
fix: revert rollout worker

* fix: revert async_rl_exp.py

* fix flag for list (cuda_graph_bs)

* format

* [FIX] fix async task reward [sglang bf16-> fp16]

* fix: define `self.__datasets` in advance

* PullRequest: 130 [Refactor] Remove deprecated search related code

Merge branch mzy/remove-search of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/130

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* remove search related

* PullRequest: 131 [Refactor] Change terminology "model parallel" into "tensor parallel" to align with megatron.

Merge branch mzy/mp-to-tp of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/131?tab=comment

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* change mp to tp
* .
* .

* PullRequest: 142 Fix an error for megatron backend destroy

Merge branch fw/fix-meagatron-destroy of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/142

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 143 Fix the port conflict issue of generation servers

Merge branch fw/fix-gen-port of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/143?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* somehow fix the port issue
* add clearance period
* .
* .

* PullRequest: 145 Add code environment

Merge branch fw/code-env of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/145?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* add code env
* somehow fix the port issue
* fix

* PullRequest: 144 Add decoupled PPO loss

Merge branch fw/decoupled-ppo-loss of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/144?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix ppo step logging, nan in stats tracker, and add decoupled loss
* .
* somehow fix the port issue
* fix typo
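
For context on the "decoupled loss" above: this presumably refers to a decoupled PPO objective in the style of Hilton et al. (2022), which separates the behavior policy that generated the (possibly stale) rollouts from the proximal policy that anchors the clipping. A rough sketch:

    L(\theta) = \mathbb{E}_{a \sim \pi_{\text{behav}}}\Big[ \frac{\pi_{\text{prox}}(a|s)}{\pi_{\text{behav}}(a|s)} \min\big( r(\theta)\hat{A},\ \operatorname{clip}(r(\theta),\, 1-\epsilon,\, 1+\epsilon)\hat{A} \big) \Big], \qquad r(\theta) = \frac{\pi_\theta(a|s)}{\pi_{\text{prox}}(a|s)}

so staleness from asynchronous generation only enters through the outer importance weight; the later "upper clip" commits suggest that weight is additionally clipped from above. Whether AReaL's implementation matches this exact form is not shown in this log.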

* PullRequest: 146 Merge SLURM logs and save experiment configs in yaml format.

Merge branch fw/better-logging of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/146

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* merge all slurm logs into one
* write config to yaml

* PullRequest: 141 Merge changes during NeurIPS submission

Merge branch fw/async-dev of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/141

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* .
* .
* .
* update script
* .
* .
* .
* .
* [ADD] add least-request scheduling
* fix test genreq
* .
* .
* fix stats tracker nan
* .
* .
* .
* .
* .
* .
* .
* upper clip decoupled objective
* add throughput exp script
* .
* remove behav upper clip param
* .
* .
* .
* plot curve
* update thpt script
* .
* master worker raises an error when exiting
* update script
* add gen throughput logging
* .
* .
* add decoupled wandb data
* .
* fix port issue and add no training option
* .
* enlarge ttl
* remove gserver manager await staled
* update weights in groups
* .
* .
* .
* add port clearance period
* .
* .
* .
* add plot script
* add sft throughput eval
* .
* log tokens in null interface
* ablation experiments and interruptible generation
* plotting scripts / run scripts / data results
* .
* remove scripts
* add port test
* remove force_sync_reward
* revert some changes
* .
* revert
* revert fix
* fix
* revert
* fix typo

* support Qwen3 training

* PullRequest: 147 Support interruption in SGLang and fix a KeyError in gather-scatter communication

Merge branch fw/sglang046-with-abort-request of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/147?tab=diff

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix ppo step logging, nan in stats tracker, and add decoupled loss
* .
* somehow fix the port issue
* initial commit
* add interrupt request
* fix data transfer issue
* max concurrent rollouts defaults to train batch size
* merge main
* add patch
* fix patch typo
* revert sglang
* fix typo
* fix minor typo
* .
* pip show editable sglang path

* PullRequest: 149 fix: code faas max_retries

Merge branch xss/fix_code_verifier of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/149

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix: code faas max_retries

* PullRequest: 150 [Bug Fix] Fix key errors in `_run_scatter` in data transfer

Merge branch mzy/fix-scatter-groups of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/150

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix scatter groups key error

* fix test

* .

* PullRequest: 151 Fix Qwen3 import error when using an older version of transformers

Merge branch fw/fix-qwen3 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/151

Reviewed-by: 温差 <xushusheng.xss@antgroup.com>


* merge all slurm logs into one
* write config to yaml
* .

* PullRequest: 152 Support SGLang 0.4.6 and fix master_worker import error

Merge branch adopt_sglang046 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/152

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* Support SGLang 0.4.6 and fix master_worker import error
* remove disable_mla option

---------

Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: wanghuaijie.whj <wanghuaijie.whj@antgroup.com>
Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: kira.gw <kira.gw@antgroup.com>
Co-authored-by: shenxujie.sxj <shenxujie.sxj@antgroup.com>
Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: sam.gjx <sam.gjx@antgroup.com>
Co-authored-by: 温差 <xushusheng.xss@antgroup.com>
Co-authored-by: 履渊 <yuhong.gyh@antgroup.com>
Wei Fu 2025-05-26 09:45:13 +08:00 committed by GitHub
parent f7ab31d050
commit c60d128b14
112 changed files with 2361 additions and 4559 deletions


@@ -1,87 +0,0 @@
#include <device_mesh.hpp>
#include <cassert>
#include <iostream>
// DeviceMesh::DeviceMesh()
// : device_mesh_name(""), n_nodes(0), n_gpus(0), node_names({}), gpu_ids({}) {
// };
DeviceMesh::DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector<std::vector<int>> mapping,
std::string global_mesh_name, std::string name)
: n_nodes(n_nodes),
n_gpus_per_node(n_gpus_per_node),
mapping(mapping),
global_mesh_name(global_mesh_name),
name(name) {
assert(n_nodes == static_cast<int>(mapping.size()));
for (int i = 0; i < n_nodes; i++) {
assert(n_gpus_per_node == static_cast<int>(mapping[i].size()));
}
};
bool is_all_overlap(std::vector<DeviceMesh *> device_meshes, DeviceMesh device_mesh) {
for (DeviceMesh *other : device_meshes) {
if (!device_mesh.overlap(*other)) return false;
}
return true;
};
bool is_all_overlap(std::unordered_set<DeviceMesh *> device_meshes, DeviceMesh device_mesh) {
for (DeviceMesh *other : device_meshes) {
if (!device_mesh.overlap(*other)) return false;
}
return true;
};
bool DeviceMesh::contain(const DeviceMesh &other) {
// check whether one device mapping is contained by another by
// checking 1. whether global_mesh_name is identical
// 2. whether mapping of one device mesh is contained by the other one
if (global_mesh_name != other.global_mesh_name) return false;
for (int i = 0; i < n_nodes; i++) {
for (int j = 0; j < n_gpus_per_node; j++) {
if (mapping[i][j] == 0 && other.mapping[i][j] == 1) return false;
}
}
return true;
};
bool DeviceMesh::contained_by(const DeviceMesh &other) {
if (global_mesh_name != other.global_mesh_name) return false;
for (int i = 0; i < n_nodes; i++) {
for (int j = 0; j < n_gpus_per_node; j++) {
if (mapping[i][j] == 1 && other.mapping[i][j] == 0) return false;
}
}
return true;
};
bool DeviceMesh::overlap(const DeviceMesh &other) {
if (global_mesh_name != other.global_mesh_name) return false;
for (int i = 0; i < n_nodes; i++) {
for (int j = 0; j < n_gpus_per_node; j++) {
if (mapping[i][j] == 1 && other.mapping[i][j] == 1) return true;
}
}
return false;
};
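// Illustrative example (added comment, not in the original file): on a global mesh with
// n_nodes = 1 and n_gpus_per_node = 8, take
//   A: mapping {{1,1,1,1,0,0,0,0}}, B: mapping {{0,0,1,1,1,1,0,0}}, C: mapping {{1,1,0,0,0,0,0,0}}.
// Then A.overlap(B) is true (GPUs 2-3 are shared) but neither contains the other, while
// C.contained_by(A) is true and C.overlap(B) is false.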
ModelParallelStrategy::ModelParallelStrategy(int num_pp, int num_dp, int num_mp)
: num_pp(num_pp), num_dp(num_dp), num_mp(num_mp) {};
bool ModelParallelStrategy::operator==(const ModelParallelStrategy &other) const {
return num_pp == other.num_pp && num_dp == other.num_dp && num_mp == other.num_mp;
};
bool DeviceMesh::operator==(const DeviceMesh &other) const {
return name == other.name && global_mesh_name == other.global_mesh_name;
};
std::string ModelParallelStrategy::to_string() {
return "num_pp:" + std::to_string(num_pp) + ";" + "num_dp:" + std::to_string(num_dp) + ";"
+ "num_mp:" + std::to_string(num_mp);
};
std::string ModelParallelStrategy::to_key() {
return std::to_string(num_pp) + "," + std::to_string(num_mp) + "," + std::to_string(num_dp);
}


@@ -1,49 +0,0 @@
#ifndef DEVICE_MESH_HPP
#define DEVICE_MESH_HPP
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
// #include <rpc.hpp>
class RPCInstance;
class DeviceMesh {
public:
int n_nodes;
int n_gpus_per_node;
std::vector<std::vector<int>> mapping;
std::string global_mesh_name;
std::string name;
RPCInstance *pre_task = nullptr;
// DeviceMesh();
DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector<std::vector<int>> mapping,
std::string global_mesh_name, std::string name);
bool overlap(const DeviceMesh &other);
bool contain(const DeviceMesh &other);
bool contained_by(const DeviceMesh &other);
bool operator==(const DeviceMesh &other) const;
};
bool is_all_overlap(std::vector<DeviceMesh *> device_meshes, DeviceMesh device_mesh);
bool is_all_overlap(std::unordered_set<DeviceMesh *> device_meshes, DeviceMesh device_mesh);
class ModelParallelStrategy {
public:
int num_pp, num_dp, num_mp;
ModelParallelStrategy(int num_pp, int num_dp, int num_mp);
bool operator==(const ModelParallelStrategy &other) const;
std::string to_string();
std::string to_key();
};
class ModelDeviceMapping {};
#endif // DEVICE_MESH_HPP


@@ -1,233 +0,0 @@
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <numeric>
#include <algorithm>
#include <iostream>
#include <iomanip>
RPC::RPC(std::string model_name, std::string rpc_name, std::string interface_type)
: model_name(model_name), rpc_name(rpc_name), interface_type(interface_type) {};
RPCExecution::RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh,
ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost,
uint64_t mem, uint64_t static_mem)
: rpc_ptr(rpc_ptr),
device_mesh(device_mesh),
model_parallel_strategy(model_parallel_strategy),
time_cost(time_cost),
mem(mem),
static_mem(static_mem) {};
bool OverlapGroup::maybe_add(RPCExecution *rpc_exe) {
if (rpc_executions.empty()) {
rpc_executions.insert(rpc_exe);
device_meshes.insert(&rpc_exe->device_mesh);
mem_static = rpc_exe->static_mem;
mem_active = rpc_exe->mem - rpc_exe->static_mem;
return true;
}
if (is_all_overlap(device_meshes, rpc_exe->device_mesh)) {
rpc_executions.insert(rpc_exe);
// bool dm_in_group = device_meshes.find(&rpc_exe -> device_mesh) != device_meshes.end();
device_meshes.insert(&rpc_exe->device_mesh);
mem_static += rpc_exe->static_mem;
mem_active = std::max(mem_active, rpc_exe->mem - rpc_exe->static_mem);
return true;
}
return false;
};
void DeviceMeshGroup::add_to_groups(RPCExecution *rpc_exe) {
if (overlap_groups.empty()) {
OverlapGroup *og = new OverlapGroup();
og->maybe_add(rpc_exe);
overlap_groups.push_back(og);
return;
}
std::vector<OverlapGroup *> tmp_new_ogs;
for (OverlapGroup *og : overlap_groups) {
// OverlapGroup og_copy = *og;
bool update = og->maybe_add(rpc_exe);
if (!update) {
tmp_new_ogs.push_back(new OverlapGroup());
tmp_new_ogs.back()->maybe_add(rpc_exe);
for (RPCExecution *rpc_exe : og->rpc_executions) { tmp_new_ogs.back()->maybe_add(rpc_exe); }
}
tmp_new_ogs.push_back(og);
}
overlap_groups.clear();
for (OverlapGroup *og : tmp_new_ogs) { overlap_groups.push_back(og); }
};
void GroupedRPCExecutions::resolve(RPCExecution *rpc_exe) {}
void GroupedRPCExecutions::add(RPCExecution *rpc_exe) { group.add_to_groups(rpc_exe); };
void GroupedRPCExecutions::offload(std::string model_name) {};
uint64_t GroupedRPCExecutions::total_mem_cost() {
uint64_t max_mem = 0;
for (auto &og : group.overlap_groups) {
// double og_ma = (og -> mem_active/(1024*1024))/1024.0;
// double og_ms = (og -> mem_static/(1024*1024))/1024.0;
// std::cout << "og size " << og -> rpc_executions.size()
// << " mem active " << og_ma << " GB"
// << " mem static " << og_ms << " GB" << std::endl;
if (og->mem_active + og->mem_static > max_mem) { max_mem = og->mem_active + og->mem_static; }
}
return max_mem;
};
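// Worked example (added comment; the numbers are made up): an overlap group holding two
// executions with (mem, static_mem) = (30 GB, 10 GB) and (50 GB, 20 GB) accumulates
// mem_static = 10 + 20 = 30 GB and mem_active = max(20, 30) = 30 GB, so it contributes a
// 60 GB peak estimate; total_mem_cost() returns the maximum such peak over all groups.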
RPCInstance::RPCInstance(RPC *rpc_ptr, int id, std::string name)
: rpc_ptr(rpc_ptr), id(id), name(name) {};
void RPCInstance::remove_parent(RPCInstance *parent) {
auto it = std::find(parents.begin(), parents.end(), parent);
if (it != parents.end()) { parents.erase(it); }
};
void RPCInstance::remove_child(RPCInstance *child) {
auto it = std::find(children.begin(), children.end(), child);
if (it != children.end()) { children.erase(it); }
};
void RPCInstance::add_parent(RPCInstance *parent) { parents.push_back(parent); };
void RPCInstance::add_child(RPCInstance *child) { children.push_back(child); };
void RPCInstance::remove_tmp_child(RPCInstance *child) {
auto it = std::find(tmp_children.begin(), tmp_children.end(), child);
if (it != tmp_children.end()) { tmp_children.erase(it); }
};
void RPCInstance::remove_tmp_parent(RPCInstance *parent) {
auto it = std::find(tmp_parents.begin(), tmp_parents.end(), parent);
if (it != tmp_parents.end()) { tmp_parents.erase(it); }
};
void RPCInstance::add_tmp_parent(RPCInstance *parent) { tmp_parents.push_back(parent); };
void RPCInstance::add_tmp_child(RPCInstance *child) { tmp_children.push_back(child); };
uint64_t parameter_sync_cost(uint64_t model_size, RPCExecution *src, RPCExecution *dst,
std::unordered_map<std::string, uint64_t> &cost_table) {
// 7b size 13738442752 Bytes
// double size_multiplier = double(model_size) / 13738442752.0;
std::string model_key = std::to_string(model_size);
std::string src_key = src->model_parallel_strategy.to_key();
std::string dst_key = dst->model_parallel_strategy.to_key();
if (src_key == dst_key) return 0;
std::string key = model_key + "," + src_key + "," + dst_key;
if (cost_table.find(key) == cost_table.end()) {
// std::cout << "key " << key << " not found" << std::endl;
return 0;
}
return cost_table[key];
}
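// Example of the lookup (added comment; the concrete numbers are illustrative): for a 7B model
// (model_size = 13738442752 bytes), a source strategy (pp=1, mp=2, dp=4) and a destination
// strategy (pp=2, mp=2, dp=2), the key becomes "13738442752,1,2,4,2,2,2" because
// ModelParallelStrategy::to_key() serializes "pp,mp,dp". Identical strategies or a missing
// key both yield a cost of 0.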
void RPCInstance::resolve_parameter_sync(std::vector<RPCInstance *> tmp_graph,
std::unordered_map<std::string, uint64_t> &cost_table) {
// add parameter synchronization edges
if (!param_sync) return;
// dst to train
uint64_t from_cost =
parameter_sync_cost(param_sync_size, param_sync_rpc_exe_ptr, rpc_exe_ptr, cost_table);
uint64_t to_cost =
parameter_sync_cost(param_sync_size, rpc_exe_ptr, param_sync_rpc_exe_ptr, cost_table);
// if (param_sync_cost > 0)
// std::cout << "Param sync cost " << param_sync_cost << " from "
// << param_sync_rpc_exe_ptr -> rpc_ptr -> rpc_name << " to "
// << rpc_exe_ptr -> rpc_ptr -> rpc_name << std::endl;
// add param sync from src to dst
RPCExecution *from_src_exe =
new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh,
param_sync_rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0);
RPCInstance *from_src = new RPCInstance(rpc_ptr, id, name + ":from_src");
from_src->rpc_exe_ptr = from_src_exe;
RPCExecution *from_dst_exe = new RPCExecution(
rpc_ptr, rpc_exe_ptr->device_mesh, rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0);
RPCInstance *from_dst = new RPCInstance(rpc_ptr, id, name + ":from_dst");
from_dst->rpc_exe_ptr = from_dst_exe;
// bool overlap = src -> rpc_exe_ptr -> device_mesh.overlap(
// dst -> rpc_exe_ptr -> device_mesh);
for (RPCInstance *parent : parents) {
parent->remove_tmp_child(this);
from_src->add_tmp_parent(parent);
from_dst->add_tmp_parent(parent);
parent->add_tmp_child(from_src);
parent->add_tmp_child(from_dst);
}
this->tmp_parents.clear();
from_src->add_tmp_child(this);
from_dst->add_tmp_child(this);
this->add_tmp_parent(from_src);
this->add_tmp_parent(from_dst);
tmp_graph.push_back(from_src);
tmp_graph.push_back(from_dst);
tmp_ris.push_back(from_src);
tmp_ris.push_back(from_dst);
tmp_exes.push_back(from_src_exe);
tmp_exes.push_back(from_dst_exe);
// add param sync from dst to src
RPCExecution *to_src_exe = new RPCExecution(rpc_ptr, rpc_exe_ptr->device_mesh,
rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0);
RPCInstance *to_src = new RPCInstance(rpc_ptr, id, name + ":to_src");
to_src->rpc_exe_ptr = to_src_exe;
RPCExecution *to_dst_exe =
new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh,
param_sync_rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0);
RPCInstance *to_dst = new RPCInstance(rpc_ptr, id, name + ":to_dst");
to_dst->rpc_exe_ptr = to_dst_exe;
for (RPCInstance *child : children) {
child->remove_tmp_parent(this);
to_src->add_tmp_child(child);
to_dst->add_tmp_child(child);
child->add_tmp_parent(to_src);
child->add_tmp_parent(to_dst);
}
this->tmp_children.clear();
to_src->add_tmp_parent(this);
to_dst->add_tmp_parent(this);
this->add_tmp_child(to_src);
this->add_tmp_child(to_dst);
tmp_graph.push_back(to_src);
tmp_graph.push_back(to_dst);
tmp_ris.push_back(to_src);
tmp_ris.push_back(to_dst);
tmp_exes.push_back(to_src_exe);
tmp_exes.push_back(to_dst_exe);
}
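// In short (added comment): when param_sync is set, this instance X is rewired in the temporary
// graph as  parents -> {X:from_src, X:from_dst} -> X -> {X:to_src, X:to_dst} -> children,
// where the inserted nodes carry the parameter_sync_cost between the two placements
// (param_sync_rpc_exe_ptr vs. rpc_exe_ptr) in each direction, so the simulator charges the
// weight-transfer time both before and after this RPC instance runs.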
CommStats::CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send,
uint64_t remote_recv, uint64_t offload_store, uint64_t offload_load)
: local_send(local_send),
local_recv(local_recv),
remote_send(remote_send),
remote_recv(remote_recv),
offload_store(offload_store),
offload_load(offload_load) {};
// ModelConfig::ModelConfig(std::string model_name,
// uint64_t param_size_bytes)
// : model_name(model_name), param_size_bytes(param_size_bytes) {
// };
std::string RPCExecution::to_string() {
return rpc_ptr->rpc_name + " on " + device_mesh.name
+ ", parallel strategy: " + model_parallel_strategy.to_string();
};


@@ -1,121 +0,0 @@
#ifndef RPC_HPP
#define RPC_HPP
#include <string>
#include <vector>
#include <unordered_map>
#include <device_mesh.hpp>
class CommStats {
public:
uint64_t local_send, local_recv, remote_send, remote_recv, offload_store, offload_load;
CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send, uint64_t remote_recv,
uint64_t offload_store, uint64_t offload_load);
};
class RPC {
public:
std::string model_name;
std::string rpc_name;
// interface_type: stringified Python enum value, e.g. "ModelInterfaceType.TRAIN_STEP"
std::string interface_type;
RPC(std::string model_name, std::string rpc_name, std::string interface_type);
};
class RPCExecution {
public:
RPC *rpc_ptr;
DeviceMesh &device_mesh;
ModelParallelStrategy &model_parallel_strategy;
uint64_t time_cost, mem, static_mem;
RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh,
ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost, uint64_t mem,
uint64_t static_mem);
std::string to_string();
};
class OverlapGroup {
public:
std::unordered_set<RPCExecution *> rpc_executions;
std::unordered_set<DeviceMesh *> device_meshes;
uint64_t mem_static;
uint64_t mem_active;
bool maybe_add(RPCExecution *rpc_exe);
};
class DeviceMeshGroup {
public:
// std::string device_mesh_name;
std::vector<OverlapGroup *> overlap_groups;
void add_to_groups(RPCExecution *rpc_exe);
};
class GroupedRPCExecutions {
public:
// std::unordered_map<std::string, DeviceMeshGroup> dn_to_group;
DeviceMeshGroup group;
void add(RPCExecution *rpc_exe);
void resolve(RPCExecution *rpc_exe);
void offload(std::string model_name);
uint64_t total_mem_cost();
};
class RPCInstance {
public:
RPC *rpc_ptr;
int id;
std::string name;
std::vector<RPCInstance *> children;
std::vector<RPCInstance *> parents;
std::vector<RPCInstance *> tmp_children;
std::vector<RPCInstance *> tmp_parents;
std::vector<RPCInstance *> tmp_ris; // pointers to tmp rpc instances
std::vector<RPCExecution *> tmp_exes; // pointers to tmp rpc executions
RPCExecution *rpc_exe_ptr = nullptr;
RPCExecution *param_sync_rpc_exe_ptr = nullptr;
bool param_sync = false;
uint64_t param_sync_size = 0;
bool offload = false;
uint64_t offload_size = 0;
RPCInstance(RPC *rpc_ptr, int id, std::string name);
uint64_t ready_time = 0, start_time = 0, end_time = 0;
void remove_parent(RPCInstance *parent);
void remove_child(RPCInstance *child);
void add_parent(RPCInstance *parent);
void add_child(RPCInstance *child);
void add_tmp_parent(RPCInstance *parent);
void add_tmp_child(RPCInstance *child);
void remove_tmp_parent(RPCInstance *parent);
void remove_tmp_child(RPCInstance *child);
void resolve_parameter_sync(std::vector<RPCInstance *> tmp_graph,
std::unordered_map<std::string, uint64_t> &cost_table);
// void resolve_offload(std::vector<RPCInstance*> tmp_graph,
// CommStats& comm_stats);
};
uint64_t parameter_sync_cost(uint64_t param_size_bytes, RPCExecution *src, RPCExecution *dst,
std::unordered_map<std::string, uint64_t> &cost_table);
uint64_t remote_param_sync_size(uint64_t size, RPCExecution *src, RPCExecution *dst);
// class ModelConfig {
// std::string model_name;
// uint64_t param_size_bytes;
// ModelConfig(std::string model_name, uint64_t param_size_bytes);
// };
#endif


@@ -1,827 +0,0 @@
#include <iostream>
#include <iomanip>
#include <algorithm>
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <simulate.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <chrono>
#include <limits>
#include <fstream>
#include <cmath>
#include <random>
#include <functional>
#include <thread>
namespace py = pybind11;
uint64_t VALID_COUNT_CAP = 25000000; // 25000000
size_t MAX_EXE_PER_RPC = 1000;
// std::unordered_map<std::string, DeviceMesh*> device_mesh_map;
void print_int_vector(std::vector<int> &vec) {
std::cout << "[";
for (int i = 0; i < static_cast<int>(vec.size()); i++) { std::cout << vec[i] << ", "; }
std::cout << "] ";
};
std::size_t vector_hash(const std::vector<int> &vec) {
std::size_t seed = vec.size();
for (const auto &i : vec) {
seed ^= std::hash<int>{}(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
long check_memory_bytes() {
std::ifstream statm_file("/proc/self/statm");
long memory_usage_bytes = -1;
if (!statm_file) {
std::cerr << "Failed to open /proc/self/statm\n";
} else {
long size, resident, share, text, lib, data, dt;
statm_file >> size >> resident >> share >> text >> lib >> data >> dt;
// size is in pages, to convert to bytes, multiply by the page size
long page_size = sysconf(_SC_PAGESIZE);
memory_usage_bytes = size * page_size;
}
return memory_usage_bytes;
};
void make_rpc_exe_table(std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::vector<RPCExecution *> &rpc_exes) {
for (auto &rpc_exe : rpc_exes) {
if (rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].size() < MAX_EXE_PER_RPC)
rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].push_back(rpc_exe);
}
for (auto &x : rpc_exe_table) {
std::string rpc_name = x.first;
std::vector<RPCExecution *> &rpc_exe_list = x.second;
// sort first
std::sort(rpc_exe_list.begin(), rpc_exe_list.end(),
[](const RPCExecution *a, const RPCExecution *b) {
if (a->time_cost == b->time_cost)
return a->device_mesh.name < b->device_mesh.name;
else
return a->time_cost < b->time_cost;
});
}
}
void make_sorted_rpc_names(
std::vector<std::string> &sorted_rpc_names,
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table) {
std::vector<std::pair<std::string, uint64_t>> average_time_cost;
for (auto &x : rpc_exe_table) {
std::string rpc_name = x.first;
std::vector<RPCExecution *> &rpc_exe_list = x.second;
uint64_t total_time_cost = 0;
int c = 0;
    for (auto &rpc_exe : rpc_exe_list) {
      total_time_cost += rpc_exe->time_cost;
      c += 1;
      if (c >= 10) break;
    }
    // Average over the first (up to 10) cheapest executions; guard against an empty list.
    average_time_cost.push_back(std::make_pair(rpc_name, total_time_cost / std::max(c, 1)));
}
std::sort(average_time_cost.begin(), average_time_cost.end(),
            [](const std::pair<std::string, uint64_t> &a, const std::pair<std::string, uint64_t> &b) {
return a.second > b.second;
});
for (auto &x : average_time_cost) { sorted_rpc_names.push_back(x.first); }
}
void prepare(std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::unordered_map<std::string, RPC *> &rpc_table,
std::vector<std::string> &sorted_rpc_names,
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph) {
std::vector<std::pair<std::string, uint64_t>> average_time_cost;
for (auto &rpc : rpcs) { rpc_table[rpc->rpc_name] = rpc; }
make_rpc_exe_table(rpc_exe_table, rpc_exes);
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
for (auto &rpc_instance : graph) {
ri_table[rpc_instance->rpc_ptr->rpc_name].push_back(rpc_instance);
}
for (auto &rpc_instance : graph) {
model_name_ri_table[rpc_instance->rpc_ptr->model_name].push_back(rpc_instance);
}
};
std::vector<SimulateResult> mcmc_search(std::vector<RPC *> rpcs,
std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph,
std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> model_sizes,
double beta, double time_limit,
MinEndTimeQueue &top_k_queue) {
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
std::unordered_map<std::string, RPC *> rpc_table;
std::vector<std::string> sorted_rpc_names;
std::unordered_map<std::string, std::vector<RPCInstance *>> ri_table;
std::unordered_map<std::string, std::vector<RPCInstance *>> model_name_ri_table;
std::chrono::duration<double> time_limit_duration(time_limit);
std::vector<SimulateResult> time_cost_cache;
prepare(rpc_exe_table, rpc_table, sorted_rpc_names, ri_table, model_name_ri_table, rpcs, rpc_exes,
graph);
std::vector<int> index;
std::vector<int> min_index;
std::vector<int> max_index;
uint64_t min_index_mem = 0;
uint64_t valid_count = 0;
uint64_t oom_count = 0;
int num_rpcs = static_cast<int>(sorted_rpc_names.size());
uint64_t min_time_cost = std::numeric_limits<uint64_t>::max();
uint64_t max_time_cost = 0;
double avg = 0;
// [6, 2, 3, 2, 12, 15, ]
// [0, 0, 0, 7, 8, 13, ]
// index = { 23, 7, 30, 154, 173, 265 };
// SimulateResult sr1 = simulate(graph, cost_table, model_sizes,
// rpc_table, rpc_exe_table, ri_table,
// model_name_ri_table, sorted_rpc_names,
// index);
// std::cout << "index 1 end time " << sr1.end_time << " mem cost " << sr1.mem_cost << std::endl;
// std::cout << "************************" << std::endl;
// index = { 1, 0, 0, 11, 20, 3 };
// SimulateResult sr2 = simulate(graph, cost_table, model_sizes,
// rpc_table, rpc_exe_table, ri_table,
// model_name_ri_table, sorted_rpc_names,
// index);
// std::cout << "index 2 end time " << sr2.end_time << " mem cost " << sr2.mem_cost << std::endl;
// exit(0);
// initial value
index.resize(sorted_rpc_names.size(), 0);
// index 1
SimulateResult first_sr = simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table,
ri_table, model_name_ri_table, sorted_rpc_names, index);
SimulateResult final_sr = first_sr;
uint64_t current_cost = first_sr.end_time;
time_cost_cache.push_back(first_sr);
if (first_sr.oom)
oom_count += 1;
else
top_k_queue.insert(first_sr);
// std::cout << "initial cost " << current_cost << " oom "
// << first_sr.oom << std::endl;
// index = {0, 1, 3, 34, 4, 10};
// std::unordered_map<std::size_t, uint64_t> time_cost_cache;
auto start = std::chrono::high_resolution_clock::now();
// bool outer_loop_break_flag = false;
while (valid_count < VALID_COUNT_CAP) {
// only change one model execution in each iteration
// std::vector<SimulateResult> sr_vector;
std::unordered_map<int, std::pair<int, int>> flatten_to_pair;
std::vector<double> weight;
// double beta = 0.0075;
int max_step_range = 10000;
int current = 0;
std::vector<int> new_index(index);
for (int i = 0; i < num_rpcs; i++) {
std::string rpc_name = sorted_rpc_names[i];
int c_i = index[i];
int min_i = std::max(0, c_i - max_step_range);
int max_i =
std::min(static_cast<int>(rpc_exe_table[rpc_name].size()), c_i + max_step_range + 1);
for (int j = min_i; j < max_i; j++) {
if (j == c_i) continue;
// int tmp = new_index[i];
// new_index[i] = j;
// SimulateResult sr = simulate(graph, cost_table, model_sizes,
// rpc_table, rpc_exe_table, ri_table,
// model_name_ri_table, sorted_rpc_names,
// new_index);
// sr_vector.push_back(sr);
// new_index[i] = tmp;
flatten_to_pair[current] = std::make_pair(i, j);
current++;
}
}
// if (time_cost_cache.size() > 10000000) {
// time_cost_cache.clear();
// }
// std::cout << "sr vector size " << sr_vector.size() << std::endl;
// for (int i = 0; i < static_cast<int>(sr_vector.size()); i++) {
// weight.push_back(std::exp(-beta * (sr_vector[i].end_time/100000)));
// }
// exit(0);
std::random_device rd;
std::mt19937 gen(rd());
// std::discrete_distribution<int> d(weight.begin(), weight.end());
std::uniform_int_distribution<int> d(0, static_cast<int>(flatten_to_pair.size() - 1));
int selected = d(gen);
// assign new index
int selected_i = flatten_to_pair[selected].first;
int selected_j = flatten_to_pair[selected].second;
new_index[selected_i] = selected_j;
SimulateResult selected_sr =
simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table, ri_table,
model_name_ri_table, sorted_rpc_names, new_index);
uint64_t selected_cost = selected_sr.end_time;
// if (selected_sr.oom) {
// std::cout << "oom max end time " << selected_cost << std::endl;
// } else {
// std::cout << "max end time " << selected_cost << std::endl;
// }
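    // Acceptance rule (added comment): this is a Metropolis / simulated-annealing style step.
    // A proposal that does not increase the end time is always accepted; a worse proposal is
    // accepted with probability exp(-beta * ((selected_cost - current_cost) / 100000)), so a
    // larger beta makes the search greedier.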
if (selected_cost < std::numeric_limits<uint64_t>::max()) {
bool accepted = true;
if (current_cost < selected_cost) {
double accept_prob = std::exp(-beta * ((selected_cost - current_cost) / 100000));
// double accept_prob = ap * ap;
std::bernoulli_distribution accept_dist(accept_prob);
accepted = accept_dist(gen);
}
if (accepted) {
// std::cout << "accepted" << std::endl;
index = new_index;
current_cost = selected_cost;
valid_count++;
if (!selected_sr.oom) {
// min_time_cost = std::min(min_time_cost, selected_cost);
// max_time_cost = std::max(max_time_cost, selected_cost);
avg = (selected_cost + avg * (valid_count - 1)) / valid_count;
if (min_time_cost > selected_cost) {
min_time_cost = selected_cost;
min_index = index;
min_index_mem = selected_sr.mem_cost;
// final_sr = selected_sr;
auto now = std::chrono::high_resolution_clock::now();
double diff =
std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
selected_sr.used_time = diff;
time_cost_cache.push_back(selected_sr);
}
if (min_time_cost == selected_cost) {
if (min_index_mem > selected_sr.mem_cost) {
min_index = index;
min_index_mem = selected_sr.mem_cost;
// final_sr = selected_sr;
}
}
top_k_queue.insert(selected_sr);
if (max_time_cost < selected_cost) {
max_time_cost = selected_cost;
max_index = index;
}
} else {
oom_count += 1;
}
// if (min_time_cost <= 100000000) break; // DEBUG
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
if (valid_count % 1000 == 0) {
std::cout << " valid_count " << valid_count << " oom count " << oom_count << " time "
<< diff.count() << " min time cost " << min_time_cost << " min index mem cost "
<< min_index_mem << " max time cost " << max_time_cost << " current cost "
<< current_cost << " avg " << avg << " mem usage " << check_memory_bytes()
<< std::endl;
std::cout << "min index : ";
print_int_vector(min_index);
std::cout << std::endl;
std::cout << "max index : ";
print_int_vector(max_index);
std::cout << std::endl;
std::cout << "current index : ";
print_int_vector(index);
std::cout << std::endl;
}
if (diff > time_limit_duration) break;
}
}
}
// std::cout << "break" << std::endl;
// int rpc_index = 0;
// for (int index : final_sr.index) {
// // std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
// // std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
// RPCExecution* re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
// final_sr.rpc_exe_list.push_back(re_ptr);
// rpc_index ++;
// }
// std::cout << "final_sr.rpc_exe_list size " << final_sr.rpc_exe_list.size() << std::endl;
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
std::cout << "MCMC Search finished, beta " << beta << " time limit " << time_limit << " seconds"
<< " valid_count " << valid_count << " oom_count " << oom_count << " time "
<< diff.count() << " min time cost " << min_time_cost << " max time cost "
<< max_time_cost << " avg " << avg << std::endl;
std::cout << "RPCExecutions: " << std::endl;
for (auto &re_ptr : final_sr.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
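  // Note (added comment): time_cost_cache holds the initial sample plus every strictly better
  // non-OOM result found during the search, each stamped with the elapsed milliseconds in
  // used_time, so callers can reconstruct the best-end-time-vs-search-time curve.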
return time_cost_cache;
// return final_sr;
}
void multi_mcmc_search(std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph,
std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> model_sizes, double beta_min,
double beta_max, double beta_step, MinEndTimeQueue &res_queue,
                       int top_k = 10, double time_limit = 60.0, int repeat = 1) {
SimulateResult sr;
std::vector<MinEndTimeQueue *> queues;
// std::vector<std::thread> ts;
for (int i = 0; i < repeat; i++) {
for (double beta = beta_min; beta < beta_max; beta += beta_step) {
MinEndTimeQueue *q = new MinEndTimeQueue(10);
queues.push_back(q);
      // Run mcmc_search for this beta (sequentially; the threaded variant is disabled)
std::vector<SimulateResult> r =
mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta, time_limit, *q);
}
}
for (auto &q : queues) {
mergeMinEndTimeQueues(res_queue, *q);
// delete q;
}
// std::cout << "Best result: " << sr.end_time << std::endl;
// for (auto& re_ptr : sr.rpc_exe_list) {
// std::cout << re_ptr -> to_string() << std::endl;
// }
// std::cout << "Index: ";
// for (int i : sr.index) {
// std::cout << i << " ";
// }
// std::cout << std::endl;
// std::cout << "Time cost: " << sr.end_time
// << ", mem cost: " << sr.mem_cost << std::endl;
// std::cout << "sr.rpc_exe_list size " << sr.rpc_exe_list.size() << std::endl;
// return sr;
}
void input_check(std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph, CommStats &comm_stats) {
std::cout << "rpcs: " << rpcs.size() << std::endl;
std::cout << "rpc_exes: " << rpc_exes.size() << std::endl;
std::cout << "graph: " << graph.size() << std::endl;
for (auto &rpc_instance : graph) {
std::cout << "==================" << std::endl;
std::cout << "rpc instance: " << rpc_instance->name << std::endl;
std::cout << "parents" << std::endl;
for (auto &parent : rpc_instance->parents) { std::cout << parent->name << " "; }
std::cout << std::endl;
std::cout << "children" << std::endl;
for (auto &child : rpc_instance->children) { std::cout << child->name << " "; }
std::cout << std::endl;
// std::cout << "parents: " << rpc_instance -> parents.size()
// << " children: " << rpc_instance -> children.size() << std::endl;
}
std::cout << "comm_stats: " << std::endl;
std::cout << "local_send: " << comm_stats.local_send << std::endl;
std::cout << "local_recv: " << comm_stats.local_recv << std::endl;
std::cout << "remote_send: " << comm_stats.remote_send << std::endl;
std::cout << "remote_recv: " << comm_stats.remote_recv << std::endl;
std::cout << "offload_store: " << comm_stats.offload_store << std::endl;
std::cout << "offload_load: " << comm_stats.offload_load << std::endl;
}
RPC *cast_rpc(py::handle rpc_py) {
return new RPC(py::str(rpc_py.attr("model_name")).cast<std::string>(),
rpc_py.attr("name").cast<std::string>(),
py::str(rpc_py.attr("interface_type")).cast<std::string>());
}
DeviceMesh *cast_device_mesh(py::handle device_mesh_py,
std::unordered_map<std::string, DeviceMesh *> &device_mesh_map) {
std::string name = device_mesh_py.attr("name").cast<std::string>();
if (device_mesh_map.find(name) == device_mesh_map.end()) {
py::array_t<int32_t> mapping_array =
device_mesh_py.attr("mapping").cast<pybind11::array_t<int32_t>>();
py::buffer_info buf_info = mapping_array.request();
auto rows = buf_info.shape[0];
auto cols = buf_info.shape[1];
std::vector<std::vector<int>> mapping(rows, std::vector<int>(cols));
// Get a pointer to the data
int32_t *data = static_cast<int32_t *>(buf_info.ptr);
// Fill the 2D vector with data from the numpy array
for (size_t i = 0; i < static_cast<size_t>(rows); ++i) {
for (size_t j = 0; j < static_cast<size_t>(cols); ++j) { mapping[i][j] = data[i * cols + j]; }
}
DeviceMesh *device_mesh =
new DeviceMesh(device_mesh_py.attr("n_nodes").cast<int>(),
device_mesh_py.attr("n_gpus_per_node").cast<int>(), mapping,
device_mesh_py.attr("global_mesh_name").cast<std::string>(),
device_mesh_py.attr("name").cast<std::string>());
device_mesh_map[name] = device_mesh;
return device_mesh;
} else {
return device_mesh_map[name];
}
}
ModelParallelStrategy *cast_model_parallel_strategy(py::handle model_parallel_strategy_py) {
return new ModelParallelStrategy(
model_parallel_strategy_py.attr("pipeline_parallel_size").cast<int>(),
model_parallel_strategy_py.attr("data_parallel_size").cast<int>(),
model_parallel_strategy_py.attr("model_parallel_size").cast<int>());
}
RPCExecution *cast_rpc_execution(py::handle rpc_exe_py, std::unordered_map<std::string, RPC *> &tmp,
std::unordered_map<std::string, DeviceMesh *> &device_mesh_map) {
DeviceMesh *device_mesh = cast_device_mesh(rpc_exe_py.attr("device_mesh"), device_mesh_map);
ModelParallelStrategy *model_parallel_strategy =
cast_model_parallel_strategy(rpc_exe_py.attr("parallel_strategy"));
// RPC* rpc = cast_rpc(rpc_exe_py.attr("rpc"));
return new RPCExecution(
tmp[rpc_exe_py.attr("rpc").attr("name").cast<std::string>()], *device_mesh,
*model_parallel_strategy, rpc_exe_py.attr("time_cost").cast<uint64_t>(),
rpc_exe_py.attr("mem").cast<uint64_t>(), rpc_exe_py.attr("static_mem").cast<uint64_t>());
}
RPCInstance *cast_rpc_instance_wo_dependency(py::handle rpc_instance_py,
std::unordered_map<std::string, RPC *> &tmp) {
return new RPCInstance(tmp[rpc_instance_py.attr("rpc").attr("name").cast<std::string>()],
rpc_instance_py.attr("iteration_id").cast<int>(),
rpc_instance_py.attr("name").cast<std::string>());
}
void cast_rpc_instance_dependency(py::handle rpc_instance_py, RPCInstance *ri_ptr,
std::unordered_map<std::string, RPCInstance *> &tmp_graph) {
for (py::handle parent_py : rpc_instance_py.attr("parents"))
ri_ptr->parents.push_back(tmp_graph[parent_py.attr("name").cast<std::string>()]);
for (py::handle child_py : rpc_instance_py.attr("children"))
ri_ptr->children.push_back(tmp_graph[child_py.attr("name").cast<std::string>()]);
}
py::list py_single_mcmc_search_time_profile(py::list rpcs_py, py::list rpc_exes_py,
py::list graph_py, py::dict cost_table_py,
py::dict model_sizes_py, py::object beta,
py::object time_limit) {
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::vector<RPCExecution *> rpc_exes;
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
for (py::handle rpc_exe_py : rpc_exes_py) {
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
rpc_exes.push_back(rpc_exe_ptr);
}
std::vector<RPCInstance *> graph;
std::unordered_map<std::string, RPCInstance *> tmp_graph;
for (py::handle ri_py : graph_py) {
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
// std::cout << "cast " << ri_ptr -> name << std::endl;
tmp_graph[ri_ptr->name] = ri_ptr;
}
// build dependency
for (py::handle ri_py : graph_py) {
std::string ri_name = ri_py.attr("name").cast<std::string>();
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
graph.push_back(tmp_graph[ri_name]);
}
std::unordered_map<std::string, uint64_t> cost_table =
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
std::unordered_map<std::string, uint64_t> model_sizes =
model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
MinEndTimeQueue res_queue(10);
std::vector<SimulateResult> rlist =
mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta.cast<double>(),
time_limit.cast<double>(), res_queue);
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
std::vector<std::string> sorted_rpc_names;
make_rpc_exe_table(rpc_exe_table, rpc_exes);
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
py::list result;
std::cout << "rlist.size " << rlist.size() << std::endl;
for (auto &r : rlist) {
// SimulateResult r = res_queue.getQueue().top();
// res_queue.getQueue().pop();
std::cout << "End time: " << r.end_time << std::endl;
for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
std::cout << "Index: ";
for (int i : r.index) { std::cout << i << " "; }
std::cout << std::endl;
std::cout << "Time cost: " << r.end_time << ", mem cost: " << r.mem_cost << std::endl;
int rpc_index = 0;
for (int index : r.index) {
// std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
r.rpc_exe_list.push_back(re_ptr);
rpc_index++;
}
py::dict rdict;
for (auto &re_ptr : r.rpc_exe_list) {
py::dict rpc_exe_info;
std::string rpc_name = re_ptr->rpc_ptr->rpc_name;
py::object rpc_name_obj = py::str(rpc_name);
// rpc_exe_info.append(re_ptr -> device_mesh.device_mesh_name);
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_dp);
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_mp);
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_pp);
rpc_exe_info["device_mesh"] = re_ptr->device_mesh.name;
rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp;
rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp;
rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp;
rdict[rpc_name_obj] = rpc_exe_info;
// std::cout << "append key " << rpc_name_obj << std::endl;
}
rdict["end_time"] = r.end_time;
rdict["mem_cost"] = r.mem_cost;
rdict["used_time"] = r.used_time;
result.append(rdict);
}
return result;
};
py::list py_multi_mcmc_search(py::list rpcs_py, py::list rpc_exes_py, py::list graph_py,
py::dict cost_table_py, py::dict model_sizes_py,
py::object beta_min_py, py::object beta_max_py,
py::object beta_step_py, py::object time_limit_py,
py::object repeat) {
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::vector<RPCExecution *> rpc_exes;
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
for (py::handle rpc_exe_py : rpc_exes_py) {
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
rpc_exes.push_back(rpc_exe_ptr);
}
std::vector<RPCInstance *> graph;
std::unordered_map<std::string, RPCInstance *> tmp_graph;
for (py::handle ri_py : graph_py) {
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
// std::cout << "cast " << ri_ptr -> name << std::endl;
tmp_graph[ri_ptr->name] = ri_ptr;
}
// build dependency
for (py::handle ri_py : graph_py) {
std::string ri_name = ri_py.attr("name").cast<std::string>();
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
graph.push_back(tmp_graph[ri_name]);
}
std::unordered_map<std::string, uint64_t> cost_table =
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
std::unordered_map<std::string, uint64_t> model_sizes =
model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
double beta_min = beta_min_py.cast<double>();
double beta_max = beta_max_py.cast<double>();
double beta_step = beta_step_py.cast<double>();
double time_limit = time_limit_py.cast<double>();
int rp = repeat.cast<int>();
MinEndTimeQueue res_queue(10);
multi_mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta_min, beta_max, beta_step,
res_queue, 10, time_limit, rp);
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
std::vector<std::string> sorted_rpc_names;
make_rpc_exe_table(rpc_exe_table, rpc_exes);
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
// std::cout << "r.rpc_exe_list size " << r.rpc_exe_list.size() << std::endl;
// for (int rpc_index = 0; rpc_index < 6; rpc_index++) {
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
// int index = 0;
// for (auto& re_ptr : rpc_exe_table[sorted_rpc_names[rpc_index]]) {
// std::cout << "index " << index << " " << re_ptr -> to_string() << std::endl;
// index ++;
// }
// }
py::list result;
std::cout << "res_queue.getQueue().size " << res_queue.getQueue().size() << std::endl;
while (!res_queue.getQueue().empty()) {
SimulateResult r = res_queue.getQueue().top();
res_queue.getQueue().pop();
std::cout << "End time: " << r.end_time << std::endl;
for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
std::cout << "Index: ";
for (int i : r.index) { std::cout << i << " "; }
std::cout << std::endl;
std::cout << "Time cost: " << r.end_time << ", mem cost: " << r.mem_cost << std::endl;
int rpc_index = 0;
for (int index : r.index) {
// std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
r.rpc_exe_list.push_back(re_ptr);
rpc_index++;
}
py::dict rdict;
for (auto &re_ptr : r.rpc_exe_list) {
py::dict rpc_exe_info;
std::string rpc_name = re_ptr->rpc_ptr->rpc_name;
py::object rpc_name_obj = py::str(rpc_name);
// convert device mesh mapping into py::array_t<int>
std::vector<std::vector<int>> mapping = re_ptr->device_mesh.mapping;
int rows = mapping.size();
int cols = mapping[0].size();
py::array_t<int> numpy_array({rows, cols});
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
*numpy_array.mutable_data(i, j) = mapping[i][j];
// std::cout << i << j << mapping[i][j] << std::endl;
}
}
// store in py::dict
rpc_exe_info["device_mesh_mapping"] = numpy_array;
rpc_exe_info["device_mesh_name"] = re_ptr->device_mesh.name;
rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp;
rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp;
rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp;
rdict[rpc_name_obj] = rpc_exe_info;
// std::cout << "append key " << rpc_name_obj << std::endl;
}
rdict["end_time"] = r.end_time;
rdict["mem_cost"] = r.mem_cost;
result.append(rdict);
}
return result;
};
PYBIND11_MODULE(mdm_search, m) {
m.doc() = "model device mapping search module";
// for debug
// m.def("mcmc_search", [](py::list rpcs_py, py::list rpc_exes_py,
// py::list graph_py, py::object comm_stats_py,
// py::dict model_sizes_py, py::object beta_py,
// py::object time_limit_py) {
// std::vector<RPC*> rpcs;
// std::unordered_map<std::string, RPC*> tmp;
// for (py::handle rpc_py : rpcs_py) {
// RPC* rpc_ptr = cast_rpc(rpc_py);
// rpcs.push_back(rpc_ptr);
// tmp[rpc_ptr -> rpc_name] = rpc_ptr;
// }
// std::vector<RPCExecution*> rpc_exes;
// for (py::handle rpc_exe_py : rpc_exes_py) {
// RPCExecution* rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp);
// rpc_exes.push_back(rpc_exe_ptr);
// }
// std::vector<RPCInstance*> graph;
// std::unordered_map<std::string, RPCInstance*> tmp_graph;
// for (py::handle ri_py : graph_py) {
// RPCInstance* ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
// std::cout << "cast " << ri_ptr -> name << std::endl;
// tmp_graph[ri_ptr -> name] = ri_ptr;
// }
// // build dependecny
// for (py::handle ri_py : graph_py) {
// std::string ri_name = ri_py.attr("name").cast<std::string>();
// cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
// graph.push_back(tmp_graph[ri_name]);
// }
// CommStats comm_stats(
// comm_stats_py.attr("local_send").cast<uint64_t>(),
// comm_stats_py.attr("local_recv").cast<uint64_t>(),
// comm_stats_py.attr("remote_send").cast<uint64_t>(),
// comm_stats_py.attr("remote_recv").cast<uint64_t>(),
// comm_stats_py.attr("offload_load").cast<uint64_t>(),
// comm_stats_py.attr("offload_store").cast<uint64_t>()
// );
// std::unordered_map<std::string, uint64_t> model_sizes
// = model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
// double beta = beta_py.cast<double>();
// double time_limit = time_limit_py.cast<double>();
// mcmc_search(rpcs, rpc_exes, graph,
// comm_stats, model_sizes, beta,
// time_limit);
// });
m.def("input_check",
[](py::list rpcs_py, py::list rpc_exes_py, py::list graph_py, py::object comm_stats_py) {
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::vector<RPCExecution *> rpc_exes;
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
for (py::handle rpc_exe_py : rpc_exes_py) {
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
rpc_exes.push_back(rpc_exe_ptr);
}
std::vector<RPCInstance *> graph;
std::unordered_map<std::string, RPCInstance *> tmp_graph;
for (py::handle ri_py : graph_py) {
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
std::cout << "cast " << ri_ptr->name << std::endl;
tmp_graph[ri_ptr->name] = ri_ptr;
}
        // build dependency
for (py::handle ri_py : graph_py) {
std::string ri_name = ri_py.attr("name").cast<std::string>();
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
graph.push_back(tmp_graph[ri_name]);
}
CommStats comm_stats(comm_stats_py.attr("local_send").cast<uint64_t>(),
comm_stats_py.attr("local_recv").cast<uint64_t>(),
comm_stats_py.attr("remote_send").cast<uint64_t>(),
comm_stats_py.attr("remote_recv").cast<uint64_t>(),
comm_stats_py.attr("offload_load").cast<uint64_t>(),
comm_stats_py.attr("offload_store").cast<uint64_t>());
input_check(rpcs, rpc_exes, graph, comm_stats);
});
// mcmc search to py result
m.def("multi_mcmc_search", &py_multi_mcmc_search);
m.def("parameter_sync_cost", [](py::object rpcs_py, py::object param_size_bytes_py,
py::dict cost_table_py, py::object src_py, py::object dst_py) {
uint64_t param_size_bytes = param_size_bytes_py.cast<uint64_t>();
std::unordered_map<std::string, uint64_t> cost_table =
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
RPCExecution *src = cast_rpc_execution(src_py, tmp, tmp_device_mesh);
RPCExecution *dst = cast_rpc_execution(dst_py, tmp, tmp_device_mesh);
return parameter_sync_cost(param_size_bytes, src, dst, cost_table);
});
m.def("mcmc_search_time_profile", &py_single_mcmc_search_time_profile);
};


@@ -1,244 +0,0 @@
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <simulate.hpp>
#include <iostream>
#include <queue>
#include <algorithm>
#include <iomanip>
#include <limits>
#include <chrono>
#define MAX(a, b) ((a) > (b) ? (a) : (b))
uint64_t SOFT_GPU_MEM_CAP = 85899345920; // 80G
SimulateResult::SimulateResult()
: end_time(std::numeric_limits<uint64_t>::max()), oom(true), mem_cost(0) {}
SimulateResult::SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost,
std::vector<int> &index)
: end_time(end_time), oom(oom), mem_cost(mem_cost), index(index) {}
SimulateResult simulate(
std::vector<RPCInstance *> &graph, std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> &model_sizes,
std::unordered_map<std::string, RPC *> &rpc_table,
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
std::vector<std::string> &sorted_rpc_names, std::vector<int> &index) {
auto start = std::chrono::high_resolution_clock::now();
GroupedRPCExecutions grouped_rpc_exe;
// std::unordered_map<std::string, RPCExecution*> param_dst; // model_name -> rpc_exe_ptr
std::unordered_set<std::string> offloaded;
uint64_t oom_penalty = 3;
int num_rpcs = static_cast<int>(sorted_rpc_names.size());
for (int i = 0; i < num_rpcs; i++) {
std::string rpc_name = sorted_rpc_names[i];
RPC *rpc = rpc_table[rpc_name];
RPCExecution *rpc_exe = rpc_exe_table[rpc_name][index[i]];
for (auto &ri : ri_table[rpc_name]) {
ri->rpc_exe_ptr = rpc_exe; // assign rpc_exe to rpc_instance
}
// dirty implementation, check whether instance is hooked with param syncs
for (auto &rpc_instance : model_name_ri_table[rpc->model_name]) {
if (rpc_instance->rpc_ptr->interface_type == "ModelInterfaceType.TRAIN_STEP"
&& rpc->interface_type != "ModelInterfaceType.TRAIN_STEP") {
// param_dst[rpc -> model_name] = rpc_exe;
rpc_instance->param_sync = true;
rpc_instance->param_sync_size = model_sizes[rpc_instance->rpc_ptr->model_name];
rpc_instance->param_sync_rpc_exe_ptr = rpc_exe;
}
}
grouped_rpc_exe.add(rpc_exe);
std::string model_name = rpc->model_name;
}
// rpc_instances: list of rpc instances, graph
std::priority_queue<RPCInstance *, std::vector<RPCInstance *>, CompareReadyTime> ready_queue;
// std::vector<RPCInstance*> executed; // for debug, remove later
std::unordered_map<std::string, size_t> parent_executed;
std::unordered_set<DeviceMesh *> device_meshes;
// for offload and parameter sync RPC instances
std::vector<RPCInstance *> tmp_graph;
// resolve parameter sync
for (RPCInstance *node : graph) {
tmp_graph.push_back(node);
node->tmp_children = node->children;
node->tmp_parents = node->parents;
// std::cout << "Resolve parameter sync: " << node -> name
// << " " << node -> param_sync << std::endl;
node->resolve_parameter_sync(tmp_graph, cost_table);
// node -> resolve_offload(tmp_graph, comm_stats);
}
uint64_t max_end_time = 0;
for (RPCInstance *node : tmp_graph) {
// std::cout << "Node: " << node -> name << " parents: "
// << node -> parents.size() << std::endl;
if (node->parents.size() == 0) ready_queue.push(node);
// init device meshes
RPCExecution *rpc_exe = node->rpc_exe_ptr;
device_meshes.insert(&rpc_exe->device_mesh);
}
std::vector<RPCInstance *> executed;
// simulate
while (!ready_queue.empty()) {
RPCInstance *t = ready_queue.top();
RPCExecution *rpc_exe = t->rpc_exe_ptr;
uint64_t exec_time = rpc_exe->time_cost;
DeviceMesh *device_mesh = &rpc_exe->device_mesh;
ready_queue.pop();
if (device_mesh->pre_task == nullptr) {
t->start_time = t->ready_time;
} else {
t->start_time = MAX(t->ready_time, device_mesh->pre_task->end_time);
}
t->end_time = t->start_time + exec_time;
max_end_time = MAX(t->end_time, max_end_time);
for (DeviceMesh *mesh : device_meshes) {
if (device_mesh->overlap(*mesh)) {
if (mesh->pre_task == nullptr || mesh->pre_task->end_time <= t->end_time) {
mesh->pre_task = t;
}
// mesh -> pre_task = t;
}
}
executed.push_back(t);
for (RPCInstance *child : t->tmp_children) {
child->ready_time = MAX(t->end_time, child->ready_time);
// std::cout << "parent: " << t -> name
// << " child: " << child -> name << std::endl;
parent_executed[child->name] += 1;
// child -> remove_parent(t);
if (child->tmp_parents.size() == parent_executed[child->name]) {
ready_queue.push(child);
// std::cout << "Ready: " << child -> name
// << " ready time " << child -> ready_time << std::endl;
}
}
// std::cout << "ready_queue size " << ready_queue.size()
// << " executed size " << executed.size() << std::endl;
}
// 110045999
// if (max_end_time < 100000000) { // DEBUG
// std::cout << "INDEX: [";
// for (int i : index) {
// std::cout << i << ", ";
// }
// std::cout << "]" << std::endl;
// for (auto& x : rpc_exe_table){
// std::string rpc_name = x.first;
// std::vector<RPCExecution*>& rpc_exe_list = x.second;
// int count = 0;
// for (RPCExecution* rpc_exe : rpc_exe_list) {
// std::cout << "RPC: " << rpc_name
// << " device mesh " << rpc_exe -> device_mesh.device_mesh_name
// << " time cost " << rpc_exe -> time_cost << std::endl;
// count ++;
// if (count > 10) break;
// }
// }
// for (RPCInstance* ri : executed) {
// for (RPCInstance* parent : ri -> tmp_parents) {
// std::cout << "Parent: " << parent -> name << " of " << ri -> name
// << " start time " << parent -> start_time
// << " end time " << parent -> end_time << std::endl;
// }
// std::cout << "Executed: " << ri -> name
// << " start time " << ri -> start_time
// << " end time " << ri -> end_time
// << " rpc name " << ri -> rpc_ptr -> rpc_name
// << " device mesh "
// << ri -> rpc_exe_ptr -> device_mesh.device_mesh_name
// << " rpc exe time cost " << ri -> rpc_exe_ptr -> time_cost
// << std::endl;
// }
// }
// clear device mesh pre tasks
for (DeviceMesh *mesh : device_meshes) { mesh->pre_task = nullptr; }
// clear rpc instance times
for (RPCInstance *node : graph) {
node->tmp_children.clear();
node->tmp_parents.clear();
node->ready_time = 0;
node->start_time = 0;
node->end_time = 0;
tmp_graph.clear();
for (RPCInstance *ptr : node->tmp_ris) { delete ptr; }
node->tmp_ris.clear();
for (RPCExecution *ptr : node->tmp_exes) { delete ptr; }
node->tmp_exes.clear();
}
uint64_t current_mem = grouped_rpc_exe.total_mem_cost();
if (current_mem > SOFT_GPU_MEM_CAP) { max_end_time *= oom_penalty; }
// std::cout << "Max end time: " << max_end_time
// << " executed size " << executed.size() << std::endl;
std::chrono::duration<double> elapsed = std::chrono::high_resolution_clock::now() - start;
// std::cout << "Elapsed time (micro seconds): "
// << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count()
// << std::endl;
for (OverlapGroup *ptr : grouped_rpc_exe.group.overlap_groups) { delete ptr; }
grouped_rpc_exe.group.overlap_groups.clear();
return SimulateResult(max_end_time, current_mem > SOFT_GPU_MEM_CAP, current_mem, index);
};
SimulateResult &SimulateResult::operator=(const SimulateResult &other) {
if (this != &other) {
end_time = other.end_time;
oom = other.oom;
mem_cost = other.mem_cost;
index = other.index;
rpc_exe_list = other.rpc_exe_list;
}
return *this;
};
bool isPresent(MinEndTimeQueue &q, SimulateResult element) {
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> pq =
q.getQueue();
std::queue<SimulateResult> tmp;
while (!pq.empty()) {
if (pq.top().end_time == element.end_time) { return true; }
tmp.push(pq.top());
pq.pop();
}
while (!tmp.empty()) {
pq.push(tmp.front());
tmp.pop();
}
return false;
}
void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1) {
// Get the underlying priority queues
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> pq1 =
q1.getQueue();
// Insert all elements from q1 into the merged queue
while (!pq1.empty()) {
if (!isPresent(target, pq1.top())) target.insert(pq1.top());
pq1.pop();
}
}


@@ -1,75 +0,0 @@
#ifndef SIMULATE_HPP
#define SIMULATE_HPP
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <queue>
#include <iostream>
class SimulateResult {
public:
uint64_t end_time;
bool oom;
uint64_t mem_cost;
std::vector<int> index;
std::vector<RPCExecution *> rpc_exe_list;
double used_time = 0;
SimulateResult();
SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost, std::vector<int> &index);
SimulateResult &operator=(const SimulateResult &other);
};
SimulateResult simulate(
std::vector<RPCInstance *> &graph, std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> &model_sizes,
std::unordered_map<std::string, RPC *> &rpc_table,
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
std::vector<std::string> &sorted_rpc_names, std::vector<int> &index);
// Comparator for priority queue
struct CompareEndTime {
bool operator()(SimulateResult const &r1, SimulateResult const &r2) {
// We want largest end_time at the top of the queue, so we reverse the comparison
return r1.end_time < r2.end_time;
}
};
class MinEndTimeQueue {
public:
MinEndTimeQueue(int capacity) : k(capacity) {}
void insert(SimulateResult r) {
if (queue.size() < k) {
// std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() <<
// std::endl;
queue.push(r);
} else if (r.end_time < queue.top().end_time) {
// std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() <<
// std::endl;
queue.pop();
queue.push(r);
}
}
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> &getQueue() {
return queue;
}
private:
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> queue;
int k;
};
void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1);
class CompareReadyTime {
public:
bool operator()(RPCInstance *r1, RPCInstance *r2) { return r1->ready_time > r2->ready_time; }
};
#endif // SIMULATE_HPP


@@ -83,7 +83,7 @@ async def async_invoke_function(
url: str,
timeout: aiohttp.ClientTimeout,
payload: Dict[str, Any] = None,
-max_retries: int = 100,
+max_retries: int = 2,
initial_retry_interval: float = 0.5,
max_retry_interval: float = 10.0,
):
@@ -137,7 +137,7 @@ async def async_invoke_function(
)
retries += 1
-if retries > max_retries:
+if retries >= max_retries:
return {
"uid": payload.get("uid", ""),
"success": False,
@@ -189,12 +189,13 @@ async def batch_function_call_async(payload_list, url, timeout, concurrency=1500
data_list.append(data)
elapsed_times.append(elapsed)
-p50 = median(elapsed_times)
-p90 = calculate_percentile(elapsed_times, 90)
-p99 = calculate_percentile(elapsed_times, 99)
-logger.info(
-f"Longest functioncall took {max_elapsed:.4f} seconds, timeout: {timeout}, uid: {max_elapsed_uid}, Active connections: {len(connector._conns)}, p50: {p50}, p90: {p90}, p99: {p99}"
-)
+if len(elapsed_times) > 0:
+p50 = median(elapsed_times)
+p90 = calculate_percentile(elapsed_times, 90)
+p99 = calculate_percentile(elapsed_times, 99)
+logger.info(
+f"Longest functioncall took {max_elapsed:.4f} seconds, timeout: {timeout}, uid: {max_elapsed_uid}, Active connections: {len(connector._conns)}, p50: {p50}, p90: {p90}, p99: {p99}"
+)
return data_list


@@ -1,6 +1,7 @@
import ast
import faulthandler
import json
+import os
import platform
# to run the solution files we're using a timing based approach
@@ -38,6 +39,14 @@ def truncatefn(s, length=300):
return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]
+def load_from_path(l):
+outputs = []
+for x in l:
+with open(x) as f:
+outputs.append(f.read())
+return outputs
class CODE_TYPE(Enum):
call_based = 0
standard_input = 1
@@ -105,6 +114,13 @@ def run_test(sample, test=None, debug=False, timeout=6):
which_type = CODE_TYPE.call_based
if in_outs:
+assert "inputs" in in_outs
+assert "outputs" in in_outs
+if os.path.isfile(in_outs["inputs"][0]):
+assert os.path.isfile(in_outs["outputs"][0])
+in_outs["inputs"] = load_from_path(in_outs["inputs"])
+in_outs["outputs"] = load_from_path(in_outs["outputs"])
if in_outs.get("fn_name", "") == "":
which_type = CODE_TYPE.standard_input # Standard input
method_name = None
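For clarity (not part of the diff): the hunk above lets a test case's "inputs"/"outputs" entries reference files on disk instead of inline strings. A minimal, self-contained sketch of that loading path, with made-up file names:

import os
import tempfile

def load_from_path(paths):
    # Mirror of the helper added above: read each referenced file into memory.
    outputs = []
    for p in paths:
        with open(p) as f:
            outputs.append(f.read())
    return outputs

with tempfile.TemporaryDirectory() as d:
    inp_path = os.path.join(d, "case0.in")   # hypothetical test-case files
    out_path = os.path.join(d, "case0.out")
    with open(inp_path, "w") as f:
        f.write("1 2\n")
    with open(out_path, "w") as f:
        f.write("3\n")
    in_outs = {"inputs": [inp_path], "outputs": [out_path]}
    if os.path.isfile(in_outs["inputs"][0]):
        in_outs["inputs"] = load_from_path(in_outs["inputs"])
        in_outs["outputs"] = load_from_path(in_outs["outputs"])
    print(in_outs)  # {'inputs': ['1 2\n'], 'outputs': ['3\n']}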


@@ -59,9 +59,10 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
shell=True,
preexec_fn=os.setsid,
stdout=subprocess.DEVNULL,
+stderr=subprocess.DEVNULL,
)
try:
-pro.wait(600)
+pro.wait(200)
except Exception as e:
pass
try:


@@ -9,7 +9,7 @@ from functioncall.base.utils import construct_uid, load_jsonl, logger
SINGLE_CASE_EXEC_TIMEOUT = 6
TEST_CASE_BATCH_SIZE = 1
-FUNCTIONCALL_TIMEOUT = 1000
+FUNCTIONCALL_TIMEOUT = 100
def round_up_memory(memory):


@@ -0,0 +1,144 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 174656b2..33fe0a5f 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8891115c..843a8a82 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -419,6 +421,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1938,6 +1941,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 82709b09..bfab3ce7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -265,6 +267,9 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -294,6 +299,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -767,6 +776,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Set a break point to wait for the interrupt to finish
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -776,7 +792,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -849,6 +870,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):


@@ -0,0 +1,144 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -309,6 +314,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Set a break point to wait for the interrupt to finish
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
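A hedged client-side sketch (not part of the patch): with this patch applied to SGLang, the existing /update_weights_from_disk endpoint can carry the new allow_interrupt field, so in-flight generation is truncated before the weight reload instead of blocking it. The server address and checkpoint path below are placeholders.

import requests

SGLANG_SERVER = "http://localhost:30000"  # placeholder server address

resp = requests.post(
    f"{SGLANG_SERVER}/update_weights_from_disk",
    json={
        "model_path": "/path/to/new/checkpoint",  # placeholder checkpoint path
        "allow_interrupt": True,  # field added by the patch above
    },
    timeout=300,
)
print(resp.status_code, resp.json())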


@ -97,8 +97,7 @@ class PromptOnlyDatasetConfig:
class ModelFamily: class ModelFamily:
"""Identifier for HuggingFace model types (e.g., llama, gpt2). """Identifier for HuggingFace model types (e.g., llama, gpt2).
Used for model registration and allocation. The size parameter is specifically Used for model registration and allocation.
relevant for the 'search' allocation mode.
""" """
_class: str = field( _class: str = field(
@ -107,12 +106,6 @@ class ModelFamily:
"`realhf/api/from_hf` for supported models.", "`realhf/api/from_hf` for supported models.",
} }
) )
size: int = field(
default=0,
metadata={
"help": "Model size parameter. Only used by 'search' allocation mode, ignored otherwise",
},
)
is_critic: bool = field( is_critic: bool = field(
default=False, default=False,
metadata={ metadata={
@ -121,8 +114,8 @@ class ModelFamily:
) )
def __repr__(self): def __repr__(self):
"""Returns formatted string representation: '{class}-{size}[-critic]'.""" """Returns formatted string representation: '{class}[-critic]'."""
s = f"{self._class}-{self.size}" s = f"{self._class}"
if self.is_critic: if self.is_critic:
s += "-critic" s += "-critic"
return s return s
@ -136,7 +129,7 @@ class ParallelismConfig:
Sequence parallelism is only used in combination with tensor-model parallelism. Sequence parallelism is only used in combination with tensor-model parallelism.
""" """
model_parallel_size: int = field( tensor_parallel_size: int = field(
default=1, metadata={"help": "Size of tensor-model parallelism"} default=1, metadata={"help": "Size of tensor-model parallelism"}
) )
pipeline_parallel_size: int = field( pipeline_parallel_size: int = field(
@ -155,7 +148,7 @@ class ParallelismConfig:
def __str__(self): def __str__(self):
"""Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'.""" """Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'."""
return ( return (
f"Parallel(mp={self.model_parallel_size}," f"Parallel(mp={self.tensor_parallel_size},"
f"pp={self.pipeline_parallel_size}," f"pp={self.pipeline_parallel_size},"
f"dp={self.data_parallel_size})" f"dp={self.data_parallel_size})"
) )
@ -168,7 +161,7 @@ class ParallelismConfig:
Implemented as static method to avoid OmegaConf compatibility issues. Implemented as static method to avoid OmegaConf compatibility issues.
""" """
return ( return (
(this.model_parallel_size == other.model_parallel_size) (this.tensor_parallel_size == other.tensor_parallel_size)
and (this.pipeline_parallel_size == other.pipeline_parallel_size) and (this.pipeline_parallel_size == other.pipeline_parallel_size)
and (this.data_parallel_size == other.data_parallel_size) and (this.data_parallel_size == other.data_parallel_size)
) )
@ -186,7 +179,7 @@ class OptimizerConfig:
default="adam", default="adam",
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]}, metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
) )
lr: float = field(default=1e-5, metadata={"help": "Learning rate"}) lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"}) weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"}) beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"}) beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
@ -198,14 +191,14 @@ class OptimizerConfig:
}, },
) )
lr_scheduler_type: str = field( lr_scheduler_type: str = field(
default="cosine", default="constant",
metadata={ metadata={
"help": "Learning rate scheduler type", "help": "Learning rate scheduler type",
"choices": ["linear", "cosine", "constant"], "choices": ["linear", "cosine", "constant"],
}, },
) )
warmup_steps_proportion: float = field( warmup_steps_proportion: float = field(
default=0.02, default=0.001,
metadata={ metadata={
"help": "Proportion of training steps for warmup", "help": "Proportion of training steps for warmup",
}, },
@ -237,6 +230,7 @@ class vLLMConfig:
""" """
max_num_seqs: int = 256 max_num_seqs: int = 256
dtype: str = "float16"
kv_cache_type: str = "auto" kv_cache_type: str = "auto"
num_scheduler_steps: int = 1 num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = True multi_step_stream_outputs: bool = True
@ -278,7 +272,6 @@ class SGLangConfig:
enable_nccl_nvls: bool = False enable_nccl_nvls: bool = False
disable_outlines_disk_cache: bool = False disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_overlap_schedule: bool = False disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False enable_mixed_chunk: bool = False
enable_dp_attention: bool = False enable_dp_attention: bool = False
@ -296,7 +289,7 @@ class SGLangConfig:
enable_memory_saver: bool = False enable_memory_saver: bool = False
allow_auto_truncate: bool = False allow_auto_truncate: bool = False
# NOTE: to avoid the illegal memory access error # NOTE: to avoid the illegal memory access error
attention_backend: Optional[str] = "triton" attention_backend: Optional[str] = "flashinfer"
sampling_backend: Optional[str] = None sampling_backend: Optional[str] = None
context_length: Optional[int] = 32768 context_length: Optional[int] = 32768
mem_fraction_static: Optional[float] = 0.9 mem_fraction_static: Optional[float] = 0.9
@ -309,15 +302,19 @@ class SGLangConfig:
schedule_conservativeness: float = 1.0 schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0 cpu_offload_gb: int = 0
hybrid_train: bool = False hybrid_train: bool = False
dtype: str = "float16"
kv_cache_dtype: str = "auto"
# logging # logging
log_level: str = "info" log_level: str = "warning"
log_level_http: Optional[str] = "warning" log_level_http: Optional[str] = "warning"
log_requests: bool = False log_requests: bool = False
log_requests_level: int = 0 log_requests_level: int = 0
show_time_cost: bool = False show_time_cost: bool = False
enable_metrics: bool = True # Exports Prometheus-like metrics enable_metrics: bool = True # Exports Prometheus-like metrics
decode_log_interval: int = 1000 # How often (in tokens) to log decode progress. # The interval (in decoding iterations) to log throughput
# and update prometheus metrics
decode_log_interval: int = 1
# Use staticmethod to make OmegaConf happy. # Use staticmethod to make OmegaConf happy.
@staticmethod @staticmethod
@ -327,6 +324,7 @@ class SGLangConfig:
tp_size, tp_size,
server_index, server_index,
base_gpu_id, base_gpu_id,
dist_init_addr: Optional[str] = None,
): ):
from realhf.base import constants, network, pkg_version, seeding from realhf.base import constants, network, pkg_version, seeding
from realhf.experiments.common.utils import asdict as conf_as_dict from realhf.experiments.common.utils import asdict as conf_as_dict
@ -345,7 +343,6 @@ class SGLangConfig:
tokenizer_mode="auto", tokenizer_mode="auto",
load_format="auto", load_format="auto",
trust_remote_code=True, trust_remote_code=True,
kv_cache_dtype="auto",
device="cuda", device="cuda",
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}", served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}",
is_embedding=False, is_embedding=False,
@ -365,6 +362,7 @@ class SGLangConfig:
ep_size=1, # TODO: check ep_size=1, # TODO: check
nnodes=1, nnodes=1,
node_rank=0, node_rank=0,
dist_init_addr=dist_init_addr,
**args, **args,
) )
@ -385,6 +383,10 @@ class SGLangConfig:
if v is True: if v is True:
flags.append(f"--{k.replace('_','-')} ") flags.append(f"--{k.replace('_','-')} ")
continue continue
if isinstance(v, list):
values = " ".join(map(str, v))
flags.append(f"--{k.replace('_','-')} {values}")
continue
flags.append(f"--{k.replace('_','-')} {v}") flags.append(f"--{k.replace('_','-')} {v}")
flags = " ".join(flags) flags = " ".join(flags)
return f"python3 -m sglang.launch_server {flags}" return f"python3 -m sglang.launch_server {flags}"
@ -444,7 +446,7 @@ class ModelTrainEvalConfig:
# Model Architecture Configuration # Model Architecture Configuration
type: ModelFamily = field( type: ModelFamily = field(
default=ModelFamily("llama", 7, False), default=ModelFamily("llama", False),
metadata={"help": "Model family specification"}, metadata={"help": "Model family specification"},
) )
path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"}) path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
@@ -679,13 +681,13 @@ class PPOHyperparameters:
value_norm_eps: float = field(
default=1e-5, metadata={"help": "Epsilon term for numerical stability"}
)
-# Experimental Features
recompute_logprob: bool = field(
default=False,
-metadata={
-"help": "Recompute log probabilities after generation. Used mainly for debugging purposes"
-},
+metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
+use_decoupled_loss: bool = field(
+default=False,
+metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
+)
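For context (an assumption about intent, not text from the diff): use_decoupled_loss pairs with recompute_logprob because the decoupled objective needs log-probabilities recomputed under the current model before training. One common formulation, with pi_behav the stale policy that generated the rollout and pi_prox the recomputed proximal policy:

% Sketch of a decoupled PPO objective; \hat{A} is the advantage estimate.
J(\theta) = \mathbb{E}_{a \sim \pi_{\mathrm{behav}}}\!\left[
    \frac{\pi_{\mathrm{prox}}(a \mid s)}{\pi_{\mathrm{behav}}(a \mid s)}
    \min\!\Big( r_\theta \hat{A},\ \mathrm{clip}\big(r_\theta,\, 1-\epsilon,\, 1+\epsilon\big) \hat{A} \Big)
\right],
\qquad r_\theta = \frac{\pi_\theta(a \mid s)}{\pi_{\mathrm{prox}}(a \mid s)}.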
@ -772,6 +774,13 @@ class ExperimentSaveEvalControl:
"For benchmarking purposes only. None indicates normal training." "For benchmarking purposes only. None indicates normal training."
}, },
) )
benchmark_n_seqs: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after consuming this number of samples. "
"For benchmarking purposes only. None indicates normal training."
},
)
@dataclass @dataclass
@ -847,7 +856,7 @@ class BaseExperimentConfig:
Note: Note:
- Recovery modes: auto, fault, resume, disabled - Recovery modes: auto, fault, resume, disabled
- Allocation modes: manual, search, heuristic, or pattern-based - Allocation modes: manual, heuristic, or pattern-based
""" """
experiment_name: str = field( experiment_name: str = field(
@ -919,13 +928,9 @@ class BaseExperimentConfig:
default="", default="",
metadata={ metadata={
"help": "GPU parallel strategy allocation mode. " "help": "GPU parallel strategy allocation mode. "
"Options: manual/search/heuristic or pattern-based." "Options: manual/heuristic or pattern-based."
}, },
) )
allocation_use_cache: bool = field(
default=False,
metadata={"help": "Use allocation search cache (search mode only)."},
)
n_nodes: int = field( n_nodes: int = field(
default=1, metadata={"help": "Number of nodes for experiment."} default=1, metadata={"help": "Number of nodes for experiment."}
) )
@ -998,9 +1003,17 @@ class BaseExperimentConfig:
@dataclass @dataclass
class AsyncRLOptions: class AsyncRLOptions:
schedule_policy: str = field(
default="round_robin",
metadata={
"help": "The request schedule policy during generation. Available options: [round_robin]."
},
)
new_tokens_per_chunk: int = field( new_tokens_per_chunk: int = field(
default=1024, default=int(1e10),
metadata={"help": "The lenght of chunked generation."}, metadata={
"help": "The length of chunked generation. Only valid if inference can't be interrupted."
},
) )
max_head_offpolicyness: int = field( max_head_offpolicyness: int = field(
default=0, default=0,
@ -1013,9 +1026,11 @@ class AsyncRLOptions:
"help": "Number of rollout workers. None defaults to train world size." "help": "Number of rollout workers. None defaults to train world size."
}, },
) )
max_concurrent_rollouts: int = field( max_concurrent_rollouts: Optional[int] = field(
default=1024, default=None,
metadata={"help": "Max concurrent rollout jobs in each worker."}, metadata={
"help": "Max concurrent rollouts globally. Defaults to train batch size."
},
) )
flush_request_timeout: int = field( flush_request_timeout: int = field(
default=120, default=120,
@ -1225,6 +1240,12 @@ class PPOMATHExperimentOptions:
}, },
) )
# testing only
no_training: bool = field(
default=False,
metadata={"help": "Run without training. Test-only."},
)
@dataclass @dataclass
class MathCodeEvalOptions: class MathCodeEvalOptions:


@@ -100,8 +100,8 @@ class ModelShardID:
:type model_name: ModelName
:param dp_rank: The data parallel rank.
:type dp_rank: int
-:param mp_rank: The tensor-model parallel rank.
-:type mp_rank: int
+:param tp_rank: The tensor-model parallel rank.
+:type tp_rank: int
:param pp_rank: The pipeline-model parallel rank.
:type pp_rank: int
:param topo: The 3D parallelism topology of this model.
@@ -110,22 +110,22 @@ class ModelShardID:
model_name: ModelName
dp_rank: int
-mp_rank: int
+tp_rank: int
pp_rank: int
topo: topology.ProcessTopology
def __post_init__(self):
-assert self.dp_rank >= 0 and self.mp_rank >= 0 and self.pp_rank >= 0
+assert self.dp_rank >= 0 and self.tp_rank >= 0 and self.pp_rank >= 0
if "@" in self.model_name.role:
raise ValueError("model_name cannot contain @")
assert self.dp_rank < self.topo.get_dim("data")
-assert self.mp_rank < self.topo.get_dim("model")
+assert self.tp_rank < self.topo.get_dim("tensor")
assert self.pp_rank < self.topo.get_dim("pipe")
@property
def parallelism_rank(self):
return self.topo.get_rank(
-data=self.dp_rank, model=self.mp_rank, pipe=self.pp_rank
+data=self.dp_rank, tensor=self.tp_rank, pipe=self.pp_rank
)
@classmethod
@@ -134,14 +134,14 @@ class ModelShardID:
return cls(
model_name=model_name,
dp_rank=c.data,
-mp_rank=c.model,
+tp_rank=c.tensor,
pp_rank=c.pipe,
topo=topo,
)
def __repr__(self):
n = cluster.spec.suffix_n_digits
-return f"{self.model_name}@pp{self.pp_rank:0{n}d}@mp{self.mp_rank:0{n}d}@dp{self.dp_rank:0{n}d}"
+return f"{self.model_name}@pp{self.pp_rank:0{n}d}@tp{self.tp_rank:0{n}d}@dp{self.dp_rank:0{n}d}"
def __hash__(self):
return hash(str(self))
@@ -152,7 +152,7 @@ class ModelShardID:
return (
self.model_name == other.model_name
and self.dp_rank == other.dp_rank
-and self.mp_rank == other.mp_rank
+and self.tp_rank == other.tp_rank
and self.pp_rank == other.pp_rank
)
return False
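A toy illustration of the renamed identifier format produced by __repr__ above (rank values and suffix width are made up):

n = 2  # assumed cluster.spec.suffix_n_digits
model_name, pp_rank, tp_rank, dp_rank = "actor", 0, 1, 3
print(f"{model_name}@pp{pp_rank:0{n}d}@tp{tp_rank:0{n}d}@dp{dp_rank:0{n}d}")
# -> actor@pp00@tp01@dp03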


@@ -547,6 +547,7 @@ class SequenceSample:
return [[seqlen] for seqlen in seqlens]
elif key in [
"packed_logprobs",
+"prox_logp",
"logprobs",
"packed_ref_logprobs",
"ref_logprobs",


@@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
import networkx as nx
import realhf.base.logging as logging
-from realhf.api.cli_args import ModelFamily
from realhf.api.core.config import (
ModelInterfaceAbstraction,
ModelInterfaceType,
@@ -94,13 +93,6 @@ class MFCDef:
:type min_n_seqs_per_pass: int
:param log_return_value: Whether to log the return value of the interface implementation.
:type log_return_value: bool
-:param model_type: The specification of the LLM, e.g., LLaMA-7B. Used by the profiler and
-search engine to produce an optimal execution plan. Can be omitted if the search engine
-is not used.
-:type model_type: Optional[ModelFamily]
-:param model_path: The path to the model file. Used to get the config for the search engine.
-Can be omitted if the search engine is not used.
-:type model_path: Optional[str]
"""
# The unique identifier of this model function call.
@@ -126,10 +118,6 @@ class MFCDef:
min_n_seqs_per_pass: int | float = 1
log_return_value: bool = False
-# Only used by search.
-model_type: Optional[Any | ModelFamily] = None
-model_path: Optional[str] = None
# Reserved dataclasses.fields. Should not be set by the user.
_G: nx.DiGraph = None
_pre_hooks: List[RPCHook] = dataclasses.field(default_factory=lambda: [])


@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple
import aiohttp import aiohttp
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist
import torch.utils.data import torch.utils.data
import transformers import transformers
@ -24,6 +25,7 @@ from realhf.api.core.config import (
ModelWrapperAbstraction, ModelWrapperAbstraction,
) )
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample, load_hf_tokenizer from realhf.api.core.data_api import MicroBatchSpec, SequenceSample, load_hf_tokenizer
from realhf.base.datapack import flat2d
from realhf.base.recover import StepInfo from realhf.base.recover import StepInfo
logger = logging.getLogger("model_api") logger = logging.getLogger("model_api")
@ -37,15 +39,19 @@ class ZeroTotalLossWeightException(Exception):
class GenRespMeta: class GenRespMeta:
qid: str qid: str
accepted: bool accepted: bool
n_tokens: int
@dataclasses.dataclass @dataclasses.dataclass
class GenReqMeta: class GenReqMeta:
## Meta info used to schedule the request. ## ## Meta info used to schedule the request. ##
qid: Hashable
prompt_len: int prompt_len: int
group_size: int group_size: int
new_token_budget: int new_token_budget: int
predicted_new_tokens: int | None predicted_new_tokens: int | None
previous_server_url: str = ""
previous_version: int = -1
@dataclasses.dataclass @dataclasses.dataclass
@ -120,6 +126,7 @@ class APIGenerateOutput:
@staticmethod @staticmethod
def concat(outputs: List["APIGenerateOutput"]): def concat(outputs: List["APIGenerateOutput"]):
assert len(set([o.qid for o in outputs])) == 1
return APIGenerateOutput( return APIGenerateOutput(
qid=outputs[0].qid, qid=outputs[0].qid,
prompt_ids=outputs[0].prompt_ids, prompt_ids=outputs[0].prompt_ids,
@ -436,6 +443,8 @@ class ReaLModelConfig:
rotary_special_impl: Optional[str] = None rotary_special_impl: Optional[str] = None
# for gemma # for gemma
normalize_embed: bool = False normalize_embed: bool = False
# for qwen3
qk_layernorm: bool = False
# for opt, it's 2 # for opt, it's 2
abs_position_embedding_offset: int = 0 abs_position_embedding_offset: int = 0
do_layernorm_before: bool = True do_layernorm_before: bool = True
@ -798,7 +807,7 @@ class ModelInterface(abc.ABC):
model: Model, model: Model,
data: SequenceSample, data: SequenceSample,
mb_spec: MicroBatchSpec, mb_spec: MicroBatchSpec,
) -> Dict: ) -> Dict | List[Dict]:
raise NotImplementedError() raise NotImplementedError()
# Mock methods for creating data and profiling an individual MFC. # Mock methods for creating data and profiling an individual MFC.
@ -860,7 +869,17 @@ class NullInterface(ModelInterface):
def train_step( def train_step(
self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec
) -> Dict: ) -> Dict | List[Dict]:
from realhf.base import constants
n_tokens = sum(flat2d(data.seqlens[data._get_split_key()]))
n_tokens = torch.tensor(
n_tokens, dtype=torch.long, device=constants.current_device()
)
dist.all_reduce(n_tokens, group=constants.data_parallel_group())
if constants.parallelism_rank() == 0:
logger.info(f"Number of tokens in NullInterface training: {int(n_tokens)}")
model.inc_version()
return {} return {}
def save(self, model: Model, save_dir: str): def save(self, model: Model, save_dir: str):


@@ -462,8 +462,8 @@ class ExperimentConfig:
)
self_topo = model_topos[rpc.model_name]
if (
-self_topo.get_dim("model") % other_topo.get_dim("model") != 0
-and other_topo.get_dim("model") % self_topo.get_dim("model") != 0
+self_topo.get_dim("tensor") % other_topo.get_dim("tensor") != 0
+and other_topo.get_dim("tensor") % self_topo.get_dim("tensor") != 0
):
raise ValueError(
"To synchronize parameters between two models, "

realhf/api/from_hf/qwen3.py (new file, 252 lines)

@@ -0,0 +1,252 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
from typing import *
from transformers.configuration_utils import PretrainedConfig
from realhf.api.core.model_api import ReaLModelConfig, register_hf_family
from realhf.base.testing import (
TESTING_MODEL_HEAD_DIM,
TESTING_MODEL_HIDDEN_SIZE,
TESTING_MODEL_INTERMEDIATE_SIZE,
TESTING_MODEL_N_HEADS,
TESTING_MODEL_N_LAYERS,
TESTING_MODEL_N_POSITIONS,
TESTING_MODEL_VOCAB_SIZE,
)
from .llama import (
convert_state_dict_llama,
llama_embedding_layer_names,
llama_output_head_param_name,
to_llama_state_dict,
)
class Qwen3Config(PretrainedConfig):
model_type = "qwen3"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
head_dim=128,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
**kwargs,
):
from transformers.modeling_rope_utils import rope_config_validation
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = (
sliding_window # we check `use_sliding_window` in the modeling code
)
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def convert_config_qwen3(
hf_config: Qwen3Config,
) -> ReaLModelConfig:
return ReaLModelConfig(
n_layers=hf_config.num_hidden_layers,
n_kv_heads=hf_config.num_key_value_heads,
hidden_dim=hf_config.hidden_size,
n_q_heads=hf_config.num_attention_heads,
head_dim=getattr(
hf_config,
"head_dim",
hf_config.hidden_size // hf_config.num_attention_heads,
),
intermediate_dim=hf_config.intermediate_size,
vocab_size=hf_config.vocab_size,
n_positions=hf_config.max_position_embeddings,
embd_pdrop=0.0,
attn_pdrop=(
hf_config.attention_dropout
if hasattr(hf_config, "attention_dropout")
else 0.1
),
layer_norm_epsilon=hf_config.rms_norm_eps,
activation_function=hf_config.hidden_act,
use_attention_bias=False,
use_attn_proj_bias=False,
scale_attn_by_inverse_layer_idx=False,
layer_norm_type="rms",
qk_layernorm=True,
mlp_type="llama",
apply_rotary=True,
rotary_base=hf_config.rope_theta,
rotary_interleaved=False,
tied_embedding=hf_config.tie_word_embeddings,
)
def convert_config_back_qwen3(
config: ReaLModelConfig,
) -> Qwen3Config:
return Qwen3Config(
vocab_size=config.vocab_size,
hidden_size=config.hidden_dim,
intermediate_size=config.intermediate_dim,
num_hidden_layers=config.n_layers,
num_key_value_heads=config.n_kv_heads,
num_attention_heads=config.n_q_heads,
head_dim=config.head_dim,
max_position_embeddings=config.n_positions,
rms_norm_eps=config.layer_norm_epsilon,
hidden_act=config.activation_function,
attention_dropout=config.attn_pdrop,
rope_theta=config.rotary_base,
architectures=["Qwen3ForCausalLM"], # ["Qwen3ForCausalLM"],
tie_word_embeddings=config.tied_embedding,
)
def qwen3_config_maker():
hf_config = Qwen3Config(
vocab_size=TESTING_MODEL_VOCAB_SIZE,
max_position_embeddings=TESTING_MODEL_N_POSITIONS,
hidden_size=TESTING_MODEL_HIDDEN_SIZE,
intermediate_size=TESTING_MODEL_INTERMEDIATE_SIZE,
num_hidden_layers=TESTING_MODEL_N_LAYERS,
num_attention_heads=TESTING_MODEL_N_HEADS,
head_dim=TESTING_MODEL_HEAD_DIM,
num_key_value_heads=8,
hidden_act="silu",
rms_norm_eps=1e-5,
)
return convert_config_qwen3(hf_config)
def convert_state_dict_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict:
llama_state_dict = convert_state_dict_llama(state_dict, config)
# model.layers.0.self_attn.k_norm.weight -> 1.attn.k_ln.weight
new_state_dict = {}
for k, v in llama_state_dict.items():
if "k_norm" in k:
k = k.replace("k_norm", "k_ln")
if "q_norm" in k:
k = k.replace("q_norm", "q_ln")
new_state_dict[k] = v
return new_state_dict
def convert_state_dict_back_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict:
new_sd = to_llama_state_dict(state_dict, config)
layer_indices = list(set([int(k.split(".")[0]) for k in state_dict.keys()]))
for i in layer_indices:
if i == 0 or i == config.n_layers + 1:
continue
new_sd[f"model.layers.{i - 1}.self_attn.k_norm.weight"] = state_dict[
f"{i}.attn.k_ln.weight"
]
new_sd[f"model.layers.{i - 1}.self_attn.q_norm.weight"] = state_dict[
f"{i}.attn.q_ln.weight"
]
return new_sd
def qwen3_transformer_block_param_name(config: ReaLModelConfig, idx: int) -> List[str]:
names = []
for k in ["weight", "bias"]:
names += [
f"model.layers.{idx}.input_layernorm.{k}",
f"model.layers.{idx}.mlp.down_proj.{k}",
f"model.layers.{idx}.mlp.gate_proj.{k}",
f"model.layers.{idx}.mlp.up_proj.{k}",
f"model.layers.{idx}.post_attention_layernorm.{k}",
f"model.layers.{idx}.self_attn.k_proj.{k}",
f"model.layers.{idx}.self_attn.o_proj.{k}",
f"model.layers.{idx}.self_attn.q_proj.{k}",
# f"model.layers.{idx}.self_attn.rotary_emb.inv_freq",
f"model.layers.{idx}.self_attn.v_proj.{k}",
]
if idx == config.n_layers - 1:
names += [f"model.norm.{k}"]
# Qwen3
if config.qk_layernorm:
names += [
f"model.layers.{idx}.self_attn.q_norm.weight",
f"model.layers.{idx}.self_attn.k_norm.weight",
]
return names
register_hf_family(
name="qwen3",
hf_cls_name="Qwen3ForCausalLM", # "Qwen3ForCausalLM"
config_from_hf_converter=convert_config_qwen3,
config_to_hf_converter=convert_config_back_qwen3,
sd_from_hf_converter=convert_state_dict_qwen3,
sd_to_hf_converter=convert_state_dict_back_qwen3,
embedding_param_names=llama_embedding_layer_names,
tblock_param_names=qwen3_transformer_block_param_name,
head_param_names=llama_output_head_param_name,
real_config_maker=qwen3_config_maker,
)
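A toy, self-contained illustration of the Qwen3-specific key renaming performed by convert_state_dict_qwen3 above (keys are abbreviated examples and tensors are replaced by None):

llama_style_sd = {
    "1.attn.q_norm.weight": None,  # produced from model.layers.0.self_attn.q_norm.weight
    "1.attn.k_norm.weight": None,  # produced from model.layers.0.self_attn.k_norm.weight
}
renamed = {}
for k, v in llama_style_sd.items():
    if "k_norm" in k:
        k = k.replace("k_norm", "k_ln")
    if "q_norm" in k:
        k = k.replace("q_norm", "q_ln")
    renamed[k] = v
print(sorted(renamed))  # ['1.attn.k_ln.weight', '1.attn.q_ln.weight']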


@@ -224,18 +224,18 @@ def find_parallel_strategies(
) -> List[ParallelismConfig]:
n_gpus = np.sum(device_mesh.mapping)
res = []
-for num_mp in [1, 2, 4, 8]:
-if n_gpus >= num_mp:
-assert n_gpus % num_mp == 0
-num_dp_pp = n_gpus // num_mp
+for num_tp in [1, 2, 4, 8]:
+if n_gpus >= num_tp:
+assert n_gpus % num_tp == 0
+num_dp_pp = n_gpus // num_tp
num_pp = 1
while num_pp <= num_dp_pp:
-num_dp_mp = n_gpus // num_pp
+num_dp_tp = n_gpus // num_pp
valid = (
-num_dp_mp in [1, 2, 4, 8] or num_dp_mp % 8 == 0
+num_dp_tp in [1, 2, 4, 8] or num_dp_tp % 8 == 0
) and num_dp_pp % num_pp == 0
if valid:
-res.append(ParallelismConfig(num_pp, num_mp, num_dp_pp // num_pp))
+res.append(ParallelismConfig(num_pp, num_tp, num_dp_pp // num_pp))
num_pp += 1
return res
@@ -248,7 +248,7 @@ class RPCAllocation:
def __post_init__(self):
world_size = (
-self.parallel.model_parallel_size
+self.parallel.tensor_parallel_size
* self.parallel.pipeline_parallel_size
* self.parallel.data_parallel_size
)
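A standalone sketch of the enumeration above for an 8-GPU mesh (illustration only; the real function returns ParallelismConfig objects rather than tuples, and asserts divisibility instead of skipping):

n_gpus = 8
strategies = []
for num_tp in [1, 2, 4, 8]:
    if n_gpus < num_tp or n_gpus % num_tp != 0:
        continue
    num_dp_pp = n_gpus // num_tp  # remaining budget for data x pipeline parallelism
    num_pp = 1
    while num_pp <= num_dp_pp:
        num_dp_tp = n_gpus // num_pp
        valid = (num_dp_tp in [1, 2, 4, 8] or num_dp_tp % 8 == 0) and num_dp_pp % num_pp == 0
        if valid:
            strategies.append((num_tp, num_pp, num_dp_pp // num_pp))  # (tp, pp, dp)
        num_pp += 1
print(strategies)
# -> [(1, 1, 8), (1, 2, 4), (1, 4, 2), (1, 8, 1), (2, 1, 4), (2, 2, 2), (2, 4, 1), (4, 1, 2), (4, 2, 1), (8, 1, 1)]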


@@ -8,12 +8,10 @@ import functools
import inspect
import json
import os
-import pickle
-import subprocess
-from typing import Callable, Optional
+from typing import Callable
import hydra
-import omegaconf
+import yaml
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
@@ -29,6 +27,9 @@ def kind_reminder(config_name, logger, args):
logger.info(
f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}"
)
+logger.info(
+f"Experiment configs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name, 'config.yaml')}"
+)
logger.info(
f"Model checkpoints will be saved to {os.path.join(MODEL_SAVE_ROOT, args.experiment_name, args.trial_name)}"
)
@@ -69,7 +70,7 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
logger = logging.getLogger("quickstart", "colored")
-print_runtime_helper(OmegaConf.to_object(args))
+# print_runtime_helper(OmegaConf.to_object(args))
exp_name = args.experiment_name
if args.trial_name == MISSING:
@@ -80,6 +81,17 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
trial_name = args.trial_name
from realhf.apps.main import main_start, main_stop
+config_save_path = os.path.join(
+LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
+)
+with open(config_save_path, "w") as f:
+yaml.dump(
+dataclasses.asdict(OmegaConf.to_object(args)),
+f,
+default_flow_style=False,
+sort_keys=False,
+)
kind_reminder(config_name, logger, args)
exp_fn = functools.partial(exp_cls, **args)


@ -87,7 +87,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
job_group_id = str(uuid.uuid4()) job_group_id = str(uuid.uuid4())
logger.info(f"AReaL Version: {get_full_version_with_dirty_description()}") logger.info(f"AReaL Version: {get_full_version_with_dirty_description()}")
logger.info(f"AReaL Job Group ID: {job_group_id}") logger.info(f"AReaL Job Group ID: {job_group_id}")
logger.info(f"AReaL Job Group Index: {recover_count}") logger.info(f"AReaL Job Group Index (recover count): {recover_count}")
if recover_count == 0: if recover_count == 0:
constants.set_experiment_trial_names(args.experiment_name, args.trial_name) constants.set_experiment_trial_names(args.experiment_name, args.trial_name)
experiment = config_package.make_experiment(args.experiment_name) experiment = config_package.make_experiment(args.experiment_name)
@ -110,10 +110,6 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
assert ( assert (
args.recover_mode == "disabled" args.recover_mode == "disabled"
), "Recover mode is not supported for local runs!" ), "Recover mode is not supported for local runs!"
# Use search cache for recover runs
force_allocation_use_cache = (
recover_count > 1 or args.recover_mode == "resume"
) and args.allocation_mode == "search"
# handle args # handle args
args.ignore_worker_error = ( args.ignore_worker_error = (
args.ignore_worker_error and args.recover_mode == "disabled" args.ignore_worker_error and args.recover_mode == "disabled"
@ -174,12 +170,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
) )
for k, v in BASE_ENVIRONS.items(): for k, v in BASE_ENVIRONS.items():
os.environ[k] = v os.environ[k] = v
os.environ["REAL_IS_REMOTE"] = "0" if not force_allocation_use_cache else "1"
# setup experiments # setup experiments
if args.allocation_mode == "search":
experiment._search()
sched = sched_client.make( sched = sched_client.make(
mode=scheduler_mode(args.mode), mode=scheduler_mode(args.mode),
expr_name=expr_name, expr_name=expr_name,
@ -324,80 +316,6 @@ def main_find_config(args):
print(exp_name) print(exp_name)
def main_profile_layers(args):
from realhf.api.cli_args import ModelFamily
_main_profile_layers(
ModelFamily(args.model_class, args.model_size, args.is_critic),
args.model_path,
)
def _main_profile_layers(model_family, model_path):
from realhf.api.cli_args import ModelFamily
from realhf.base.slurm_utils import check_slurm_availability
from realhf.base.testing import clear_name_resolve
expr_name = trial_name = "profile"
cmd = (
f"python3 -m realhf.apps.profile_layers --expr_name {expr_name} --trial_name {trial_name} "
f"--model_path {model_path} --model_name {model_family} "
)
if check_slurm_availability():
if not os.environ.get("CLUSTER_SPEC_PATH", ""):
raise ValueError(
"Environment variable CLUSTER_SPEC_PATH must be set for slurm mode! "
"See example/cluster_config.json for a template."
)
BASE_ENVIRONS = constants.get_env_vars(
REAL_MODE="slurm",
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
)
clear_name_resolve(expr_name, trial_name)
sched = sched_client.make(
mode="slurm", expr_name=expr_name, trial_name=trial_name
)
print(
f"Profiling {model_family} layers, model path {model_path}, " f"cmd {cmd}"
)
sched.submit_array(
worker_type="profile_layer",
cmd=cmd,
count=1,
cpu=64,
gpu=8,
gpu_type="tesla",
mem=500000,
env_vars=BASE_ENVIRONS,
container_image=config_package._LLM_GPU_IMAGE,
)
try:
sched.wait(timeout=None)
except (
KeyboardInterrupt,
sched_client.JobException,
TimeoutError,
) as e:
sched.stop_all()
raise e
else:
try:
print(
f"Profiling {model_family} layers, model path {model_path}, "
f"cmd {cmd}"
)
clear_name_resolve(expr_name, trial_name)
os.system(cmd)
except (
KeyboardInterrupt,
sched_client.JobException,
TimeoutError,
) as e:
raise e
def main(): def main():
parser = argparse.ArgumentParser(prog="ReaLHF") parser = argparse.ArgumentParser(prog="ReaLHF")
subparsers = parser.add_subparsers(dest="cmd", help="sub-command help") subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
@ -482,7 +400,7 @@ def main():
type=str, type=str,
required=False, required=False,
default="pipe_model", default="pipe_model",
choices=["manual", "search", "heuristic", "pipe_model", "pipe_data"], choices=["manual", "heuristic", "pipe_model", "pipe_data"],
help="Mode of GPU resource/model parallel strategy allocation.", help="Mode of GPU resource/model parallel strategy allocation.",
) )
subparser.set_defaults(ignore_worker_error=False) subparser.set_defaults(ignore_worker_error=False)
@ -514,15 +432,6 @@ def main():
subparser.add_argument("--regex", "-r", type=str, required=True) subparser.add_argument("--regex", "-r", type=str, required=True)
subparser.set_defaults(func=main_find_config) subparser.set_defaults(func=main_find_config)
subparser = subparsers.add_parser(
"profile_layers", help="profile layers of a model."
)
subparser.add_argument("--model_class", type=str, required=True)
subparser.add_argument("--model_size", type=int, required=True)
subparser.add_argument("--is_critic", action="store_true")
subparser.add_argument("--model_path", type=str, required=True)
subparser.set_defaults(func=main_profile_layers)
args = parser.parse_args() args = parser.parse_args()
args.func(args) args.func(args)


@@ -1,91 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import argparse
import itertools
import time
import realhf.base.testing as testing
BATCH_SIZE_RANGE = [1, 2, 4, 8, 16, 32, 64, 128]
SEQ_LEN_RANGE = [128, 256, 512]
def profile_layer_func(
world_size,
model_path,
model_name,
warm_up_rounds,
profile_rounds,
batch_size_range,
seq_len_range,
use_sequence_parallel=False,
use_gradient_checkpointing=False,
):
# FIXME: use_sequence_parallel=True and use_gradient_checkpointing=True will cause bugs
import torch
import realhf.base.constants as constants
testing.init_global_constants(
1, world_size, 1, sequence_parallel=False, gradient_checkpointing=False
)
device = torch.device("cuda")
with constants.model_scope(testing.MODEL_NAME):
from realhf.search_engine.layers import make_profile_layers
profile_layers = make_profile_layers(device, model_path, model_name)
st = time.monotonic_ns()
for i in range(warm_up_rounds + profile_rounds):
for bs, seq_len in itertools.product(batch_size_range, seq_len_range):
profile_layers.fwd_gen(bs, seq_len)
profile_layers.fwd_bwd_opt(bs, seq_len)
if i < warm_up_rounds:
profile_layers.reset_stats()
profile_layers.make_dataframe_and_print()
profile_layers.dump_stats(world_size)
t = (time.monotonic_ns() - st) / int(1e9)
print(f"profile world size {world_size} cost {t:4f} seconds")
if __name__ == "__main__":
st = time.monotonic_ns()
parser = argparse.ArgumentParser(prog="profile_layers")
parser.add_argument(
"--model_path",
type=str,
required=True,
)
parser.add_argument("--expr_name", type=str, default="profile")
parser.add_argument("--trial_name", type=str, default="profile")
parser.add_argument("--model_name", type=str, default="Llama-2-70b")
parser.add_argument("--warm_up_rounds", type=int, default=1)
parser.add_argument("--profile_rounds", type=int, default=3)
# parser.add_argument("--use_sequence_parallel", action="store_true")
# parser.add_argument("--use_gradient_checkpointing", action="store_true")
args = parser.parse_args()
world_sizes = [1, 2, 4, 8]
for world_size in world_sizes:
testing.clear_name_resolve(args.expr_name, args.trial_name)
mp = testing.LocalMultiProcessTest(
world_size,
profile_layer_func,
world_size,
args.model_path,
args.model_name,
args.warm_up_rounds,
args.profile_rounds,
BATCH_SIZE_RANGE,
SEQ_LEN_RANGE,
expr_name=args.expr_name,
trial_name=args.trial_name,
)
mp.launch()
t = (time.monotonic_ns() - st) / int(1e9)
print(f"profile model {args.model_name} time cost {t:4f} seconds")

View File

@ -72,6 +72,7 @@ MODEL_SAVE_ROOT = f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}"
LOG_ROOT = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}" LOG_ROOT = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}"
RECOVER_ROOT = f"{cluster_spec.fileroot}/recover/{getpass.getuser()}" RECOVER_ROOT = f"{cluster_spec.fileroot}/recover/{getpass.getuser()}"
SLURM_LOCK_FILE_NAME = f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock" SLURM_LOCK_FILE_NAME = f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock"
PORT_LOCK_FILE_ROOT = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports"
PYTORCH_KERNEL_CACHE_PATH = ( PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels" f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
) )
@ -120,6 +121,9 @@ BASE_ENVIRONS = {
"REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv( "REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv(
"REAL_GPU_MEMORY_KILL_THRESHOLD", "0.95" "REAL_GPU_MEMORY_KILL_THRESHOLD", "0.95"
), ),
"LC_ALL": "C",
"LANG": "C",
"NCCL_DEBUG": "WARN",
} }
# Set PPU-specific environment variables for stable training. # Set PPU-specific environment variables for stable training.
@ -146,7 +150,6 @@ elif cluster_spec.name == "na132":
"NCCL_IB_SL": "5", "NCCL_IB_SL": "5",
"NCCL_IB_TC": "136", "NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond", "NCCL_IB_HCA": "mlx5_bond",
"NCCL_DEBUG": "WARN",
"NCCL_IB_QPS_PER_CONNECTION": "8", "NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1", "NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH", "NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
@ -165,6 +168,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True) os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True) os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True) os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
os.makedirs(PORT_LOCK_FILE_ROOT, exist_ok=True)
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True) os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
# _model_name will be changed in the model_scope context manager # _model_name will be changed in the model_scope context manager
@ -186,16 +190,12 @@ _self_group = None
_rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {} _rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
_global_memory_buffer: GlobalMemoryBuffer = GlobalMemoryBuffer() _global_memory_buffer: GlobalMemoryBuffer = GlobalMemoryBuffer()
# used only in scripts and tests
_fake_mp_world_size = None
_fake_mp_rank = None
# TODO: As in Megatron, we can set NCCL group options. Is it necessary? # TODO: As in Megatron, we can set NCCL group options. Is it necessary?
def reset_run(): def reset_run():
global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer, _fake_mp_world_size, _fake_mp_rank global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer
_model_name = None _model_name = None
_grids = {} _grids = {}
_pgroups = {} _pgroups = {}
@ -203,8 +203,6 @@ def reset_run():
_self_group = None _self_group = None
_rank_mapping = {} _rank_mapping = {}
_global_memory_buffer = GlobalMemoryBuffer() _global_memory_buffer = GlobalMemoryBuffer()
_fake_mp_world_size = None
_fake_mp_rank = None
@contextlib.contextmanager @contextlib.contextmanager
@ -284,7 +282,7 @@ def set_rank_mapping(
else: else:
msid2mwid = {k: v for k, v in msid2mwid.items() if k.model_name == model_name} msid2mwid = {k: v for k, v in msid2mwid.items() if k.model_name == model_name}
_rank_mapping[model_name] = { _rank_mapping[model_name] = {
topo.get_rank(data=s.dp_rank, model=s.mp_rank, pipe=s.pp_rank): mw_id topo.get_rank(data=s.dp_rank, tensor=s.tp_rank, pipe=s.pp_rank): mw_id
for s, mw_id in msid2mwid.items() for s, mw_id in msid2mwid.items()
} }
@ -408,7 +406,7 @@ def parallelism_group_ranks():
def parallelism_group_size() -> int: def parallelism_group_size() -> int:
"""The 3D parallelism group size of a specific model, normally dp_size * """The 3D parallelism group size of a specific model, normally dp_size *
pp_size * mp_size.""" pp_size * tp_size."""
import torch.distributed as dist import torch.distributed as dist
return dist.get_world_size(group=parallelism_group()) return dist.get_world_size(group=parallelism_group())
@ -470,37 +468,25 @@ def prev_pipe_stage():
def is_dp_head(): def is_dp_head():
return is_last_pipe_stage() and model_parallel_rank() == 0 return is_last_pipe_stage() and tensor_parallel_rank() == 0
def model_parallel_rank() -> int: def tensor_parallel_rank() -> int:
"""Return the rank inside the tensor parallelism group.""" """Return the rank inside the tensor parallelism group."""
try: return grid().get_tensor_model_parallel_rank()
return grid().get_tensor_model_parallel_rank()
except RuntimeError as e: # used only in scripts and tests
if _fake_mp_rank is not None:
return _fake_mp_rank
else:
raise e
def model_parallel_world_size() -> int: def tensor_parallel_world_size() -> int:
"""Return the world size of the tensor parallelism group.""" """Return the world size of the tensor parallelism group."""
try: return grid().get_tensor_model_parallel_world_size()
return grid().get_tensor_model_parallel_world_size()
except RuntimeError as e: # used only in scripts and tests
if _fake_mp_world_size is not None:
return _fake_mp_world_size
else:
raise e
def model_parallel_group(): def tensor_parallel_group():
"""Return the NCCL tensor parallelism process group.""" """Return the NCCL tensor parallelism process group."""
return grid().get_tensor_model_parallel_group() return grid().get_tensor_model_parallel_group()
def model_parallel_cpu_group(): def tensor_parallel_cpu_group():
"""Return the GLOO tensor parallelism process group.""" """Return the GLOO tensor parallelism process group."""
return grid().get_tensor_model_parallel_cpu_group() return grid().get_tensor_model_parallel_cpu_group()
@ -536,26 +522,6 @@ def data_parallel_group():
return grid().get_data_parallel_group() return grid().get_data_parallel_group()
def set_fake_mp_world_size(world_size):
# used only in scripts and tests
global _fake_mp_world_size
_fake_mp_world_size = world_size
def set_fake_mp_rank(rank):
# used only in scripts and tests
global _fake_mp_rank
_fake_mp_rank = rank
def set_fake_grid(model_name, rank, topo):
# used only in scripts and tests
from realhf.base.topology import FakeGrid
global _grids
_grids[model_name] = FakeGrid(rank=rank, topo=topo)
def get_global_memory_buffer(): def get_global_memory_buffer():
global _global_memory_buffer global _global_memory_buffer
assert _global_memory_buffer is not None, "global memory buffer is not set" assert _global_memory_buffer is not None, "global memory buffer is not set"

View File

@ -62,7 +62,7 @@ def reveal_pg_identity(expr_name, trial_name, worker_index):
master_group_name = names.distributed_peer( master_group_name = names.distributed_peer(
expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME
) )
name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=300) name_resolve.add_subentry(master_group_name, str(worker_index))
def isolate_cuda_device( def isolate_cuda_device(
@ -100,12 +100,10 @@ def isolate_cuda_device(
name_resolve_identifier, name_resolve_identifier,
), ),
rank, rank,
keepalive_ttl=60,
) )
name_resolve.add_subentry( name_resolve.add_subentry(
names.distributed_peer(experiment_name, trial_name, name_resolve_identifier), names.distributed_peer(experiment_name, trial_name, name_resolve_identifier),
rank, rank,
keepalive_ttl=30,
) )
logger.debug( logger.debug(
f"Worker type {worker_type} rank {rank} waiting for peers, world size {world_size}..." f"Worker type {worker_type} rank {rank} waiting for peers, world size {world_size}..."

View File

@ -141,9 +141,18 @@ def getLogger(
return logging.getLogger(name) return logging.getLogger(name)
_LATEST_WANDB_STEP = 0
def log_wandb_tensorboard(data, step=None, summary_writer=None): def log_wandb_tensorboard(data, step=None, summary_writer=None):
import wandb import wandb
global _LATEST_WANDB_STEP
if step is None:
step = _LATEST_WANDB_STEP
else:
_LATEST_WANDB_STEP = max(_LATEST_WANDB_STEP, step)
wandb.log(data, step=step) wandb.log(data, step=step)
if summary_writer is not None: if summary_writer is not None:
for key, val in data.items(): for key, val in data.items():
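
The change above caches the latest explicit step so that wandb.log calls made without a step reuse it instead of letting the backend auto-increment past it. A minimal sketch of that guard, with wandb.log replaced by a print stub so the snippet runs without a wandb account:

_LATEST_WANDB_STEP = 0


def _stub_wandb_log(data, step):
    # Stands in for wandb.log in this sketch.
    print(f"step={step}: {data}")


def log_metrics(data, step=None):
    global _LATEST_WANDB_STEP
    if step is None:
        # Reuse the most recent explicit step instead of letting the backend
        # auto-increment, keeping later step-less calls aligned.
        step = _LATEST_WANDB_STEP
    else:
        _LATEST_WANDB_STEP = max(_LATEST_WANDB_STEP, step)
    _stub_wandb_log(data, step=step)


log_metrics({"loss": 0.5}, step=10)
log_metrics({"eval/acc": 0.7})  # also logged at step 10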

View File

@ -618,7 +618,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
self._to_delete = set() self._to_delete = set()
logger.info(f"Connected to etcd3 at {self._host}:{self._port}") logger.debug(f"Connected to etcd3 at {self._host}:{self._port}")
def __del__(self): def __del__(self):
"""Clean up resources when the object is deleted.""" """Clean up resources when the object is deleted."""
@ -945,12 +945,13 @@ def make_repository(type_="nfs", **kwargs):
# DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs" # DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs"
DEFAULT_REPOSITORY_TYPE = "nfs" DEFAULT_REPOSITORY_TYPE = "nfs"
if ( if etcd3 is not None and os.getenv("REAL_ETCD_ADDR", ""):
etcd3 is not None
and cluster.spec.name in ["wa180", "na132", "su18"]
and os.getenv("REAL_ETCD_ADDR", "")
):
DEFAULT_REPOSITORY_TYPE = "etcd3" DEFAULT_REPOSITORY_TYPE = "etcd3"
if os.getenv("REAL_ETCD_ADDR", "") and etcd3 is None:
logger.warning(
f"Detected REAL_ETCD_ADDR but etcd3 client is not available. "
"Please run `pip install -r requirements.txt` if you want to use etcd name resolve."
)
DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE) DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE)
add = DEFAULT_REPOSITORY.add add = DEFAULT_REPOSITORY.add
add_subentry = DEFAULT_REPOSITORY.add_subentry add_subentry = DEFAULT_REPOSITORY.add_subentry

View File

@ -93,5 +93,9 @@ def gen_servers(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers" return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers"
def used_ports(experiment_name, trial_name, host_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/{host_name}/"
def gen_server_manager(experiment_name, trial_name): def gen_server_manager(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager" return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager"

View File

@ -2,31 +2,16 @@
# Copyright 2024 Wei Fu & Zhiyu Mei # Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License"). # Licensed under the Apache License, Version 2.0 (the "License").
import fcntl
import os
import socket import socket
import time
from contextlib import closing from contextlib import closing
from functools import wraps
from realhf.base import constants, logging, name_resolve, names
def find_free_port(low=1, high=65536, exclude_ports=None): logger = logging.getLogger(__name__)
"""Find a free port within the specified range, excluding certain ports."""
if exclude_ports is None:
exclude_ports = set()
while True:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
if low <= port <= high and port not in exclude_ports:
return port
def find_multiple_free_ports(count, low=1, high=65536):
"""Find multiple mutually exclusive free ports."""
free_ports = set()
for _ in range(count):
port = find_free_port(low, high, exclude_ports=free_ports)
free_ports.add(port)
return list(free_ports)
def gethostname(): def gethostname():
@ -35,3 +20,54 @@ def gethostname():
def gethostip(): def gethostip():
return socket.gethostbyname(socket.gethostname()) return socket.gethostbyname(socket.gethostname())
def find_free_port(
low=1, high=65536, exclude_ports=None, experiment_name="port", trial_name="port"
):
"""Find a free port within the specified range, excluding certain ports."""
ports_name = names.used_ports(experiment_name, trial_name, gethostip())
used_ports = list(map(int, name_resolve.get_subtree(ports_name)))
if exclude_ports is None:
exclude_ports = set(used_ports)
else:
exclude_ports = exclude_ports.union(set(used_ports))
free_port = None
lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip())
while True:
with open(lockfile, "w") as fd:
# This will block until lock is acquired
fcntl.flock(fd, fcntl.LOCK_EX)
try:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
if low <= port <= high and port not in exclude_ports:
name_resolve.add_subentry(ports_name, str(port))
logger.info(f"Found free port {port}")
free_port = port
break
finally:
fcntl.flock(fd, fcntl.LOCK_UN)
time.sleep(0.05)
return free_port
def find_multiple_free_ports(
count, low=1, high=65536, experiment_name="port", trial_name="port"
):
"""Find multiple mutually exclusive free ports."""
free_ports = set()
for _ in range(count):
port = find_free_port(
low=low,
high=high,
exclude_ports=free_ports,
experiment_name=experiment_name,
trial_name=trial_name,
)
free_ports.add(port)
return list(free_ports)
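
The rewritten find_free_port serializes allocation per host with an fcntl file lock and records claimed ports under a name_resolve subtree so concurrent workers cannot hand out the same port. Below is a self-contained sketch of that lock-and-probe pattern (POSIX-only because of fcntl); an in-memory set stands in for the shared name_resolve bookkeeping.

import fcntl
import socket
import tempfile
from contextlib import closing

# In-memory stand-in for the per-host name_resolve subtree of claimed ports.
_claimed_ports = set()
LOCKFILE = tempfile.gettempdir() + "/port_alloc_sketch.lock"


def find_free_port_sketch(low=1024, high=65535):
    """Bind to an OS-chosen port under an exclusive file lock, then claim it."""
    while True:
        with open(LOCKFILE, "w") as fd:
            fcntl.flock(fd, fcntl.LOCK_EX)  # blocks until the lock is acquired
            try:
                with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                    s.bind(("", 0))
                    port = s.getsockname()[1]
                    if low <= port <= high and port not in _claimed_ports:
                        _claimed_ports.add(port)  # real code: name_resolve.add_subentry
                        return port
            finally:
                fcntl.flock(fd, fcntl.LOCK_UN)


print(find_free_port_sketch())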

View File

@ -171,6 +171,9 @@ class DistributedStatsTracker:
else: else:
raise ValueError(f"Unknown reduce type: {reduce_type}") raise ValueError(f"Unknown reduce type: {reduce_type}")
keys_to_pop = [k for k, v in result.items() if v is None]
for k in keys_to_pop:
result.pop(k)
return result return result
def _sum_of(self, key, reduce_group): def _sum_of(self, key, reduce_group):
@ -209,7 +212,7 @@ class DistributedStatsTracker:
dist.all_reduce(x, group=reduce_group) dist.all_reduce(x, group=reduce_group)
dist.all_reduce(d, group=reduce_group) dist.all_reduce(d, group=reduce_group)
if d == 0: if d == 0:
return 0 return None
return x / d return x / d
def _min_of(self, key, reduce_group): def _min_of(self, key, reduce_group):
@ -224,7 +227,7 @@ class DistributedStatsTracker:
if reduce_group is not None: if reduce_group is not None:
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN) dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN)
if torch.isinf(x): if torch.isinf(x):
return float("nan") return None
return float(x) return float(x)
def _max_of(self, key, reduce_group): def _max_of(self, key, reduce_group):
@ -239,7 +242,7 @@ class DistributedStatsTracker:
if reduce_group is not None: if reduce_group is not None:
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX) dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX)
if torch.isinf(x): if torch.isinf(x):
return float("nan") return None
return float(x) return float(x)
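
The tracker now reports None instead of 0 or NaN when a reduction has nothing to aggregate, and strips those keys before returning, so empty metrics never reach the logging backends. A toy single-process sketch of that drop-None pattern:

def reduce_stats(raw):
    """Toy reducer: average each key, dropping keys with nothing to aggregate."""
    result = {}
    for key, values in raw.items():
        # An empty list now yields None (previously 0 or NaN leaked through).
        result[key] = sum(values) / len(values) if values else None
    # Mirror the new keys_to_pop logic: strip None entries before returning.
    for key in [k for k, v in result.items() if v is None]:
        result.pop(key)
    return result


print(reduce_stats({"loss": [0.4, 0.6], "kl": []}))  # -> {'loss': 0.5}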

View File

@ -22,9 +22,9 @@ import torch.utils.data
from realhf.api.core.data_api import SequenceSample from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology
from realhf.base.topology import ( from realhf.base.topology import (
DataPipeModelParallelTopology, DataPipeTensorParallelTopology,
ParallelGrid, ParallelGrid,
PipeDataModelParallelTopology, PipeDataTensorParallelTopology,
) )
logger = logging.getLogger("testing") logger = logging.getLogger("testing")
@ -106,9 +106,6 @@ class StandaloneTestingProcess(mp.Process):
self.expr_name, self.trial_name, self.rank, backend=self.dist_backend self.expr_name, self.trial_name, self.rank, backend=self.dist_backend
) )
# setup some useful constants
constants.set_experiment_trial_names(self.expr_name, self.trial_name)
# misc setup # misc setup
if constants.use_cuda(): if constants.use_cuda():
pynvml.nvmlInit() pynvml.nvmlInit()
@ -207,7 +204,7 @@ class LocalMultiProcessTest:
def init_global_constants( def init_global_constants(
num_dp=1, num_dp=1,
num_mp=1, num_tp=1,
num_pp=1, num_pp=1,
topo=None, topo=None,
model_name=None, model_name=None,
@ -227,9 +224,9 @@ def init_global_constants(
if topo is None: if topo is None:
if is_train: if is_train:
topo = PipeDataModelParallelTopology( topo = PipeDataTensorParallelTopology(
num_dp=num_dp, num_dp=num_dp,
num_mp=num_mp, num_tp=num_tp,
num_pp=num_pp, num_pp=num_pp,
sequence_parallel=sequence_parallel, sequence_parallel=sequence_parallel,
gradient_checkpointing=gradient_checkpointing, gradient_checkpointing=gradient_checkpointing,
@ -237,13 +234,13 @@ def init_global_constants(
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
) )
else: else:
topo = DataPipeModelParallelTopology( topo = DataPipeTensorParallelTopology(
num_dp=num_dp, num_dp=num_dp,
num_mp=num_mp, num_tp=num_tp,
num_pp=num_pp, num_pp=num_pp,
sequence_parallel=sequence_parallel, sequence_parallel=sequence_parallel,
) )
ws = num_dp * num_mp * num_pp ws = num_dp * num_tp * num_pp
else: else:
ws = topo.world_size() ws = topo.world_size()

View File

@ -65,22 +65,22 @@ def decompose_to_three_factors(n: int) -> List[Tuple[int, int, int]]:
return factors return factors
class PipeDataModelrocessCoord(NamedTuple): class PipeDataTensorProcessCoord(NamedTuple):
pipe: int pipe: int
data: int data: int
model: int tensor: int
class DataPipeModelrocessCoord(NamedTuple): class DataPipeTensorProcessCoord(NamedTuple):
data: int data: int
pipe: int pipe: int
model: int tensor: int
# Explicitly define these class to allow pickling. # Explicitly define these class to allow pickling.
PROCESS_COORD_REGISTRY = { PROCESS_COORD_REGISTRY = {
"pipe#data#model": PipeDataModelrocessCoord, "pipe#data#tensor": PipeDataTensorProcessCoord,
"data#pipe#model": DataPipeModelrocessCoord, "data#pipe#tensor": DataPipeTensorProcessCoord,
} }
@ -327,20 +327,20 @@ def _prime_factors(N):
return primes return primes
class PipeDataModelParallelTopology(ProcessTopology): class PipeDataTensorParallelTopology(ProcessTopology):
"""A topology for hybrid pipeline, model, and data parallelism.""" """A topology for hybrid pipeline, model, and data parallelism."""
def __init__( def __init__(
self, self,
num_pp: int, num_pp: int,
num_mp: int, num_tp: int,
num_dp: int, num_dp: int,
sequence_parallel: bool, sequence_parallel: bool,
gradient_checkpointing: bool, gradient_checkpointing: bool,
gradient_accumulation_fusion: bool, gradient_accumulation_fusion: bool,
max_prompt_len: Optional[int] = None, max_prompt_len: Optional[int] = None,
): ):
super().__init__(axes=["pipe", "data", "model"], dims=[num_pp, num_dp, num_mp]) super().__init__(axes=["pipe", "data", "tensor"], dims=[num_pp, num_dp, num_tp])
self.sequence_parallel = sequence_parallel self.sequence_parallel = sequence_parallel
self.gradient_checkpointing = gradient_checkpointing self.gradient_checkpointing = gradient_checkpointing
@ -348,7 +348,7 @@ class PipeDataModelParallelTopology(ProcessTopology):
self.gradient_accumulation_fusion = gradient_accumulation_fusion self.gradient_accumulation_fusion = gradient_accumulation_fusion
class DataPipeModelParallelTopology(ProcessTopology): class DataPipeTensorParallelTopology(ProcessTopology):
"""A topology for hybrid data, pipeline, and tensor parallelism. """A topology for hybrid data, pipeline, and tensor parallelism.
Note that DP is the most outer dimension. Used for inference only. Note that DP is the most outer dimension. Used for inference only.
@ -357,12 +357,12 @@ class DataPipeModelParallelTopology(ProcessTopology):
def __init__( def __init__(
self, self,
num_pp: int, num_pp: int,
num_mp: int, num_tp: int,
num_dp: int, num_dp: int,
sequence_parallel: bool, sequence_parallel: bool,
max_prompt_len: Optional[int] = None, max_prompt_len: Optional[int] = None,
): ):
super().__init__(axes=["data", "pipe", "model"], dims=[num_dp, num_pp, num_mp]) super().__init__(axes=["data", "pipe", "tensor"], dims=[num_dp, num_pp, num_tp])
self.sequence_parallel = sequence_parallel self.sequence_parallel = sequence_parallel
self.max_prompt_len = max_prompt_len self.max_prompt_len = max_prompt_len
@ -414,7 +414,7 @@ class ParallelGrid:
self.data_parallel_size = max(self._topo.get_dim("data"), 1) self.data_parallel_size = max(self._topo.get_dim("data"), 1)
self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1) self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1)
self.model_parallel_size = max(self._topo.get_dim("model"), 1) self.model_parallel_size = max(self._topo.get_dim("tensor"), 1)
self.slice_parallel_size = self.model_parallel_size self.slice_parallel_size = self.model_parallel_size
assert self._is_grid_valid(), ( assert self._is_grid_valid(), (
"Invalid Grid", "Invalid Grid",
@ -520,7 +520,7 @@ class ParallelGrid:
self.slice_group = None self.slice_group = None
self.slice_proc_group = self.slice_proc_group_gloo = None self.slice_proc_group = self.slice_proc_group_gloo = None
self.mp_group = [] self.mp_group = []
self.model_groups = self._topo.get_axis_comm_lists("model") self.model_groups = self._topo.get_axis_comm_lists("tensor")
for g in self.model_groups: for g in self.model_groups:
proc_group = new_or_get_group(ranks=[rank_mapping[x] for x in g]) proc_group = new_or_get_group(ranks=[rank_mapping[x] for x in g])
# NOTE: We must create the GLOO group for vLLM's usage. # NOTE: We must create the GLOO group for vLLM's usage.
@ -634,8 +634,8 @@ class ParallelGrid:
def get_tensor_model_parallel_rank(self): def get_tensor_model_parallel_rank(self):
if self.global_rank == -1: if self.global_rank == -1:
return -1 return -1
if "model" in self._topo.get_axis_names(): if "tensor" in self._topo.get_axis_names():
return self._topo.get_coord(rank=self.global_rank).model return self._topo.get_coord(rank=self.global_rank).tensor
else: else:
return 0 return 0
@ -662,12 +662,12 @@ class FakeGrid:
self.data_parallel_size = max(self._topo.get_dim("data"), 1) self.data_parallel_size = max(self._topo.get_dim("data"), 1)
self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1) self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1)
self.model_parallel_size = max(self._topo.get_dim("model"), 1) self.model_parallel_size = max(self._topo.get_dim("tensor"), 1)
self.coord: ProcessCoord = self._topo.get_coord(self.rank) self.coord = self._topo.get_coord(self.rank)
self.dp_id = self.coord.data self.dp_id = self.coord.data
self.pp_id = self.coord.pipe self.pp_id = self.coord.pipe
self.mp_id = self.coord.model self.mp_id = self.coord.tensor
self.world_size = ( self.world_size = (
self.data_parallel_size * self.pipe_parallel_size * self.model_parallel_size self.data_parallel_size * self.pipe_parallel_size * self.model_parallel_size
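
With the rename, the third topology axis is "tensor" and the constructors take num_tp. A hedged usage sketch, assuming a realhf checkout with this change applied is importable; the argument names follow the constructor signatures shown above.

from realhf.base.topology import (
    DataPipeTensorParallelTopology,
    PipeDataTensorParallelTopology,
)

# Training topology: pipe is the outermost axis, tensor the innermost.
train_topo = PipeDataTensorParallelTopology(
    num_pp=2,
    num_tp=2,
    num_dp=2,
    sequence_parallel=True,
    gradient_checkpointing=True,
    gradient_accumulation_fusion=False,
)
print(train_topo.get_axis_names())  # expected: ['pipe', 'data', 'tensor']

# Inference topology: data parallelism is the outermost axis.
gen_topo = DataPipeTensorParallelTopology(
    num_pp=1, num_tp=4, num_dp=2, sequence_parallel=False
)
print(gen_topo.world_size())  # expected: 8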

View File

@ -6,7 +6,11 @@ from typing import Any, Dict, List, Tuple
import realhf.base.logging as logging import realhf.base.logging as logging
from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions
from realhf.api.core.config import AgentAbstraction, EnvServiceAbstraction from realhf.api.core.config import (
AgentAbstraction,
EnvServiceAbstraction,
ModelInterfaceAbstraction,
)
from realhf.api.core.model_api import GenerationHyperparameters from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.quickstart.entrypoint import register_quickstart_exp from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig
@ -36,7 +40,7 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
@property @property
def env(self) -> EnvServiceAbstraction: def env(self) -> EnvServiceAbstraction:
return EnvServiceAbstraction( return EnvServiceAbstraction(
"math-single-step", args=dict(dataset_path=self.dataset.path) "math-code-single-step", args=dict(dataset_path=self.dataset.path)
) )
@property @property
@ -71,6 +75,11 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
rpcs["ref_inf"].output_keys = ("packed_ref_logprobs",) rpcs["ref_inf"].output_keys = ("packed_ref_logprobs",)
if "rew_inf" in rpcs: if "rew_inf" in rpcs:
rpcs.pop("rew_inf") rpcs.pop("rew_inf")
if self.no_training:
rpcs["actor_train"].interface_impl = ModelInterfaceAbstraction("null")
rpcs["actor_gen"].interface_impl = ModelInterfaceAbstraction("null")
if "actor_inf" in rpcs:
rpcs["actor_inf"].interface_impl = ModelInterfaceAbstraction("null")
return rpcs return rpcs
@property @property


realhf/experiments/async_exp/async_rl_exp.py Normal file → Executable file
View File

@ -60,7 +60,6 @@ GEN_WORKER_DEFAULT_CAPACITY = 512
@dataclasses.dataclass @dataclasses.dataclass
class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions): class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
@property @property
def generation_config(self) -> GenerationHyperparameters: def generation_config(self) -> GenerationHyperparameters:
raise NotImplementedError() raise NotImplementedError()
@ -203,16 +202,17 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
"config_from_hf_converter" "config_from_hf_converter"
](hf_config) ](hf_config)
if ( if (
model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0 != 0
) or ( ) or (
model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0 model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0
): ):
raise ValueError( raise ValueError(
f"The number of KV heads {model_config.n_kv_heads} or " f"The number of KV heads {model_config.n_kv_heads} or "
f"Q heads {model_config.n_q_heads} is not" f"Q heads {model_config.n_q_heads} is not"
f" divisible by the configured TP size " f" divisible by the configured TP size "
f"({rpc_alloc.parallel.model_parallel_size}). " f"({rpc_alloc.parallel.tensor_parallel_size}). "
f"Please decrease TP size." f"Please decrease TP size."
) )
mapping = rpc_alloc.device_mesh.mapping mapping = rpc_alloc.device_mesh.mapping
@ -250,7 +250,7 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
topo=topo, topo=topo,
dp_rank=topo.get_coord(shard_idx).data, dp_rank=topo.get_coord(shard_idx).data,
pp_rank=topo.get_coord(shard_idx).pipe, pp_rank=topo.get_coord(shard_idx).pipe,
mp_rank=topo.get_coord(shard_idx).model, tp_rank=topo.get_coord(shard_idx).tensor,
), ),
model=model, model=model,
backend=backend, backend=backend,
@ -308,15 +308,18 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
model_name = gen_rpc_alloc.rpc.model_name model_name = gen_rpc_alloc.rpc.model_name
train_rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.is_train()] train_rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.is_train()]
assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs) assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
max_concurrent_rollouts = self.max_concurrent_rollouts
if max_concurrent_rollouts is None:
max_concurrent_rollouts = train_rpcs[0].n_seqs
return [ return [
GserverManager( GserverManager(
model_name=model_name, model_name=model_name,
flush_request_timeout=self.flush_request_timeout, flush_request_timeout=self.flush_request_timeout,
n_servers=gen_world_size // gen_tp_size, n_servers=gen_world_size // gen_tp_size,
schedule_policy="round_robin", schedule_policy=self.schedule_policy,
max_head_offpolicyness=self.max_head_offpolicyness, max_head_offpolicyness=self.max_head_offpolicyness,
train_batch_size=train_rpcs[0].n_seqs, train_batch_size=train_rpcs[0].n_seqs,
max_concurrent_rollouts=self.max_concurrent_rollouts, max_concurrent_rollouts=max_concurrent_rollouts,
) )
] ]

View File

@ -37,21 +37,21 @@ def default_parallel_config(n_gpus: int) -> List[Dict[str, Any]]:
x = [ x = [
{ {
"data_parallel_size": dp, "data_parallel_size": dp,
"model_parallel_size": mp, "tensor_parallel_size": tp,
"pipeline_parallel_size": pp, "pipeline_parallel_size": pp,
"use_sequence_parallel": mp > 1, "use_sequence_parallel": tp > 1,
} }
for dp, mp, pp in factors for dp, tp, pp in factors
] ]
x += [ x += [
{ {
"data_parallel_size": dp, "data_parallel_size": dp,
"model_parallel_size": mp, "tensor_parallel_size": tp,
"pipeline_parallel_size": pp, "pipeline_parallel_size": pp,
"use_sequence_parallel": False, "use_sequence_parallel": False,
} }
for dp, mp, pp in factors for dp, tp, pp in factors
if mp > 1 if tp > 1
] ]
return x return x
@ -122,7 +122,7 @@ class ProfileConfig(CommonExperimentConfig):
k k
in [ in [
"data_parallel_size", "data_parallel_size",
"model_parallel_size", "tensor_parallel_size",
"pipeline_parallel_size", "pipeline_parallel_size",
"use_sequence_parallel", "use_sequence_parallel",
] ]
@ -130,7 +130,7 @@ class ProfileConfig(CommonExperimentConfig):
), pcfg.keys() ), pcfg.keys()
assert (self.n_nodes * self.n_gpus_per_node) == ( assert (self.n_nodes * self.n_gpus_per_node) == (
pcfg.get("data_parallel_size", 1) pcfg.get("data_parallel_size", 1)
* pcfg.get("model_parallel_size", 1) * pcfg.get("tensor_parallel_size", 1)
* pcfg.get("pipeline_parallel_size", 1) * pcfg.get("pipeline_parallel_size", 1)
) )
@ -246,8 +246,6 @@ class ProfileConfig(CommonExperimentConfig):
model_name="default", model_name="default",
input_keys=["packed_prompts"], input_keys=["packed_prompts"],
log_return_value=False, log_return_value=False,
model_type=self._tmp_model.type,
model_path=self._tmp_model.path,
balanced_dp=True, balanced_dp=True,
) )

View File

@ -70,7 +70,7 @@ def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation):
mb_spec = rpc.mb_spec mb_spec = rpc.mb_spec
dp_size = rpc_alloc.parallel.data_parallel_size dp_size = rpc_alloc.parallel.data_parallel_size
tp_size = rpc_alloc.parallel.model_parallel_size tp_size = rpc_alloc.parallel.tensor_parallel_size
pp_size = rpc_alloc.parallel.pipeline_parallel_size pp_size = rpc_alloc.parallel.pipeline_parallel_size
factor = 1 factor = 1

View File

@ -44,7 +44,6 @@ from realhf.api.quickstart.device_mesh import (
) )
from realhf.base.cluster import spec as cluster_spec from realhf.base.cluster import spec as cluster_spec
from realhf.experiments.common.check import ( from realhf.experiments.common.check import (
check_is_realhf_native_model_interface,
check_valid_model_and_path, check_valid_model_and_path,
check_valid_optimizer, check_valid_optimizer,
check_valid_parallel_batch_size, check_valid_parallel_batch_size,
@ -61,7 +60,6 @@ from realhf.experiments.common.utils import (
resolve_replica_ids, resolve_replica_ids,
resolve_rpc_hooks, resolve_rpc_hooks,
) )
from realhf.search_engine.search import search_rpc_allocations
# Register all HF models # Register all HF models
import realhf.api.from_hf # isort:skip import realhf.api.from_hf # isort:skip
@ -144,10 +142,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
""" """
return None return None
@property
def search_kwargs(self) -> Dict[str, Any]:
return {}
@property @property
def global_device_mesh(self) -> DeviceMesh: def global_device_mesh(self) -> DeviceMesh:
return DeviceMesh( return DeviceMesh(
@ -161,20 +155,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
f"_heuristic_rpc_allocation is not implemented in {self.__class__}" f"_heuristic_rpc_allocation is not implemented in {self.__class__}"
) )
def _search(self):
# called in both api.main and controller
gradient_checkpointing = any(
model.gradient_checkpointing for model in self.models.values()
)
rpc_allocs: List[RPCAllocation] = search_rpc_allocations(
device_mesh=self.global_device_mesh,
rpcs=list(self.rpcs.values()),
gradient_checkpointing=gradient_checkpointing,
use_cache=self.allocation_use_cache,
**self.search_kwargs,
)
return rpc_allocs
def scheduling_setup(self) -> ExperimentScheduling: def scheduling_setup(self) -> ExperimentScheduling:
"""The resourced occupied by each worker. """The resourced occupied by each worker.
@ -221,24 +201,11 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
self._check_legal_allocation_options() self._check_legal_allocation_options()
rpcs = self.rpcs rpcs = self.rpcs
if self.allocation_mode == "search": if self._allocation_mode.is_decoupled():
# assert self.mode == "slurm"
# assumes gradient checkpointing for all training RPCs if one is enabled
# for the simplicity of search configurations
rpc_allocs = self._search()
for rpc_alloc in rpc_allocs:
assert isinstance(rpc_alloc.rpc, str)
for rpc in rpcs.values():
if rpc.name == rpc_alloc.rpc:
rpc_alloc.rpc = rpc
break
else:
raise ValueError(f"RPC {rpc_alloc.rpc} not found in rpcs.")
elif self._allocation_mode.is_decoupled():
paras = self._allocation_mode.parallel_strat paras = self._allocation_mode.parallel_strat
gdp, gpp, gmp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"] gdp, gpp, gtp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
gen_world_size = gdp * gpp * gmp gen_world_size = gdp * gpp * gtp
assert ( assert (
gen_world_size < self.n_gpus_per_node gen_world_size < self.n_gpus_per_node
or gen_world_size % self.n_gpus_per_node == 0 or gen_world_size % self.n_gpus_per_node == 0
@ -268,7 +235,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
parallel=ParallelismConfig( parallel=ParallelismConfig(
data_parallel_size=gdp, data_parallel_size=gdp,
pipeline_parallel_size=gpp, pipeline_parallel_size=gpp,
model_parallel_size=gmp, tensor_parallel_size=gtp,
use_sequence_parallel=False, use_sequence_parallel=False,
), ),
) )
@ -276,7 +243,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
else: else:
rpc_name = rpc.name rpc_name = rpc.name
if rpc_name in paras: if rpc_name in paras:
dp, pp, mp = ( dp, pp, tp = (
paras[rpc_name]["d"], paras[rpc_name]["d"],
paras[rpc_name]["p"], paras[rpc_name]["p"],
paras[rpc_name]["m"], paras[rpc_name]["m"],
@ -287,9 +254,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
f"RPC {rpc_name} parallel strategy not given, " f"RPC {rpc_name} parallel strategy not given, "
"expect a `*` to specify the default parallel strategy." "expect a `*` to specify the default parallel strategy."
) )
dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"] dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
if ( if (
dp * pp * mp + gdp * gpp * gmp dp * pp * tp + gdp * gpp * gtp
!= self.n_nodes * self.n_gpus_per_node != self.n_nodes * self.n_gpus_per_node
): ):
raise ValueError( raise ValueError(
@ -297,7 +264,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
"does not equal to the number of gpus. " "does not equal to the number of gpus. "
"Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, " "Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, "
"so their summation should be equal to the total number of gpus. " "so their summation should be equal to the total number of gpus. "
f"dp={dp}, pp={pp}, mp={mp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gmp}, " f"dp={dp}, pp={pp}, mp={tp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gtp}, "
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}" f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
) )
alloc = RPCAllocation( alloc = RPCAllocation(
@ -306,10 +273,10 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
parallel=ParallelismConfig( parallel=ParallelismConfig(
data_parallel_size=dp, data_parallel_size=dp,
pipeline_parallel_size=pp, pipeline_parallel_size=pp,
model_parallel_size=mp, tensor_parallel_size=tp,
use_sequence_parallel=( use_sequence_parallel=(
rpc.interface_type == ModelInterfaceType.TRAIN_STEP rpc.interface_type == ModelInterfaceType.TRAIN_STEP
and mp > 1 and tp > 1
), ),
), ),
) )
@ -323,7 +290,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
rpc_allocs = [] rpc_allocs = []
for rpc_name, rpc in self.rpcs.items(): for rpc_name, rpc in self.rpcs.items():
if rpc_name in paras: if rpc_name in paras:
dp, pp, mp = ( dp, pp, tp = (
paras[rpc_name]["d"], paras[rpc_name]["d"],
paras[rpc_name]["p"], paras[rpc_name]["p"],
paras[rpc_name]["m"], paras[rpc_name]["m"],
@ -334,18 +301,18 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
f"RPC {rpc_name} parallel strategy not given, " f"RPC {rpc_name} parallel strategy not given, "
"expect a `*` to specify the default parallel strategy." "expect a `*` to specify the default parallel strategy."
) )
dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"] dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
assert dp * pp * mp == self.n_nodes * self.n_gpus_per_node assert dp * pp * tp == self.n_nodes * self.n_gpus_per_node
alloc = RPCAllocation( alloc = RPCAllocation(
rpc=rpc, rpc=rpc,
device_mesh=self.global_device_mesh, device_mesh=self.global_device_mesh,
parallel=ParallelismConfig( parallel=ParallelismConfig(
data_parallel_size=dp, data_parallel_size=dp,
pipeline_parallel_size=pp, pipeline_parallel_size=pp,
model_parallel_size=mp, tensor_parallel_size=tp,
use_sequence_parallel=( use_sequence_parallel=(
rpc.interface_type == ModelInterfaceType.TRAIN_STEP rpc.interface_type == ModelInterfaceType.TRAIN_STEP
and mp > 1 and tp > 1
), ),
), ),
) )
@ -455,7 +422,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
topo=topo, topo=topo,
dp_rank=topo.get_coord(shard_idx).data, dp_rank=topo.get_coord(shard_idx).data,
pp_rank=topo.get_coord(shard_idx).pipe, pp_rank=topo.get_coord(shard_idx).pipe,
mp_rank=topo.get_coord(shard_idx).model, tp_rank=topo.get_coord(shard_idx).tensor,
), ),
model=ModelAbstraction( model=ModelAbstraction(
"tokenizer", args=dict(tokenizer_path=model_cfg.path) "tokenizer", args=dict(tokenizer_path=model_cfg.path)
@ -464,7 +431,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
gen_backend_name, gen_backend_name,
args=dict( args=dict(
model_path=model_cfg.path, model_path=model_cfg.path,
dtype="bfloat16" if model_cfg.bf16 else "float16",
**dict_args, **dict_args,
), ),
), ),
@ -503,16 +469,17 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
"config_from_hf_converter" "config_from_hf_converter"
](hf_config) ](hf_config)
if ( if (
model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0 != 0
) or ( ) or (
model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0 model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0
): ):
raise ValueError( raise ValueError(
f"The number of KV heads {model_config.n_kv_heads} or " f"The number of KV heads {model_config.n_kv_heads} or "
f"Q heads {model_config.n_q_heads} is not" f"Q heads {model_config.n_q_heads} is not"
f" divisible by the configured TP size " f" divisible by the configured TP size "
f"({rpc_alloc.parallel.model_parallel_size}). " f"({rpc_alloc.parallel.tensor_parallel_size}). "
f"Please decrease TP size." f"Please decrease TP size."
) )
mapping = rpc_alloc.device_mesh.mapping mapping = rpc_alloc.device_mesh.mapping
@ -572,7 +539,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
topo=topo, topo=topo,
dp_rank=topo.get_coord(shard_idx).data, dp_rank=topo.get_coord(shard_idx).data,
pp_rank=topo.get_coord(shard_idx).pipe, pp_rank=topo.get_coord(shard_idx).pipe,
mp_rank=topo.get_coord(shard_idx).model, tp_rank=topo.get_coord(shard_idx).tensor,
), ),
model=model, model=model,
backend=backend, backend=backend,
@ -612,12 +579,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
"please setup slurm for distributed runs." "please setup slurm for distributed runs."
) )
if self.n_gpus_per_node != 8 and self.allocation_mode in [ if self.n_gpus_per_node != 8 and self.allocation_mode == "heuristic":
"search",
"heuristic",
]:
raise ValueError( raise ValueError(
f"Cannot run search or heuristic allocation with " f"Cannot run heuristic allocation with "
f"n_gpus_per_node {self.n_gpus_per_node}, " f"n_gpus_per_node {self.n_gpus_per_node}, "
"please set n_gpus_per_node to 8." "please set n_gpus_per_node to 8."
) )
@ -627,13 +591,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
raise KeyError( raise KeyError(
f"RPC name {rpc_name} does not match the name in the MFCDef object {rpc.name}." f"RPC name {rpc_name} does not match the name in the MFCDef object {rpc.name}."
) )
if not check_is_realhf_native_model_interface(
rpc.interface_impl.type_
) and self.allocation_mode in ["search"]:
raise ValueError(
f"RPC {rpc.name} interface is not a realhf native implementation. "
f"The search allocation mode are not available."
)
if self.allocation_mode == "manual" and rpc_name not in self.allocations: if self.allocation_mode == "manual" and rpc_name not in self.allocations:
if rpc_name not in self.allocations: if rpc_name not in self.allocations:
raise ValueError( raise ValueError(

View File

@ -65,8 +65,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
model_name="actor", model_name="actor",
mb_spec=self.actor_gen.mb_spec, mb_spec=self.actor_gen.mb_spec,
interface_type=ModelInterfaceType.GENERATE, interface_type=ModelInterfaceType.GENERATE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface, interface_impl=actor_interface,
input_keys=("packed_prompts", "task_ids"), input_keys=("packed_prompts", "task_ids"),
output_keys=("packed_input_ids",), output_keys=("packed_input_ids",),
@ -79,8 +77,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
mb_spec=self.rew_inf.mb_spec, mb_spec=self.rew_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE, interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface, interface_impl=rw_interface,
model_type=self.rew.type,
model_path=self.rew.path,
min_n_seqs_per_pass=1 / self.group_size, min_n_seqs_per_pass=1 / self.group_size,
input_keys=("packed_input_ids", "packed_prompts", "task_ids"), input_keys=("packed_input_ids", "packed_prompts", "task_ids"),
output_keys=("rewards",), output_keys=("rewards",),

View File

@ -39,8 +39,6 @@ class NullSFTConfig(CommonExperimentConfig, SFTExperimentOptions):
model_name="default", model_name="default",
input_keys=("packed_input_ids", "prompt_mask"), input_keys=("packed_input_ids", "prompt_mask"),
log_return_value=True, log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
) )
return {"trainDefault": rpc} return {"trainDefault": rpc}
@ -88,8 +86,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
model_name="default", model_name="default",
input_keys=("packed_prompts",), input_keys=("packed_prompts",),
output_keys=("rewards",), output_keys=("rewards",),
model_type=self.model.type,
model_path=self.model.path,
) )
rpc = MFCDef( rpc = MFCDef(
n_seqs=self.dataset.train_bs_n_seqs, n_seqs=self.dataset.train_bs_n_seqs,
@ -100,8 +96,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
model_name="default", model_name="default",
input_keys=("packed_prompts", "rewards"), input_keys=("packed_prompts", "rewards"),
log_return_value=True, log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
) )
return {"trainDefault": rpc, "reward": rw} return {"trainDefault": rpc, "reward": rw}

View File

@ -148,32 +148,31 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
"packed_logprobs", "packed_logprobs",
"prompt_mask", "prompt_mask",
] ]
if self.ppo.recompute_logprob: if self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
rollout_output_keys.remove("packed_logprobs") rollout_output_keys.remove("packed_logprobs")
rollout = MFCDef( rollout = MFCDef(
name="actor_gen", name="actor_gen",
model_name="actor", model_name="actor",
mb_spec=self.actor_gen.mb_spec, mb_spec=self.actor_gen.mb_spec,
interface_type=ModelInterfaceType.GENERATE, interface_type=ModelInterfaceType.GENERATE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface, interface_impl=actor_interface,
input_keys=("packed_prompts", "task_ids"), input_keys=("packed_prompts", "task_ids"),
output_keys=tuple(rollout_output_keys), output_keys=tuple(rollout_output_keys),
n_seqs=self.dataset.train_bs_n_seqs, n_seqs=self.dataset.train_bs_n_seqs,
) )
actor_inf_outputs = ("packed_logprobs",)
if self.ppo.use_decoupled_loss:
actor_inf_outputs = ("proximal_logprobs",)
actor_inf = MFCDef( actor_inf = MFCDef(
name="actor_inf", name="actor_inf",
model_name="actor", model_name="actor",
mb_spec=self.actor_inf.mb_spec, mb_spec=self.actor_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE, interface_type=ModelInterfaceType.INFERENCE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface, interface_impl=actor_interface,
input_keys=("packed_input_ids",), input_keys=("packed_input_ids",),
output_keys=("packed_logprobs",), output_keys=actor_inf_outputs,
output_key_remap=dict(logprobs="packed_logprobs"), output_key_remap=dict(logprobs=actor_inf_outputs[0]),
n_seqs=self.dataset.train_bs_n_seqs, n_seqs=self.dataset.train_bs_n_seqs,
) )
@ -200,8 +199,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
model_name="ref", model_name="ref",
mb_spec=self.ref_inf.mb_spec, mb_spec=self.ref_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE, interface_type=ModelInterfaceType.INFERENCE,
model_type=self.ref.type,
model_path=self.ref.path,
interface_impl=ref_interface, interface_impl=ref_interface,
min_n_seqs_per_pass=1 / self.group_size, min_n_seqs_per_pass=1 / self.group_size,
input_keys=tuple(inf_ref_inputs), input_keys=tuple(inf_ref_inputs),
@ -216,8 +213,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
mb_spec=self.critic_inf.mb_spec, mb_spec=self.critic_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE, interface_type=ModelInterfaceType.INFERENCE,
interface_impl=critic_interface, interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
min_n_seqs_per_pass=1 / self.group_size, min_n_seqs_per_pass=1 / self.group_size,
input_keys=("packed_input_ids", "seq_no_eos_mask"), input_keys=("packed_input_ids", "seq_no_eos_mask"),
output_keys=("values",), output_keys=("values",),
@ -238,13 +233,13 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
train_actor_inputs.remove("values") train_actor_inputs.remove("values")
if self.ppo.kl_ctl == 0: if self.ppo.kl_ctl == 0:
train_actor_inputs.remove("packed_ref_logprobs") train_actor_inputs.remove("packed_ref_logprobs")
if self.ppo.use_decoupled_loss:
train_actor_inputs.append("proximal_logprobs")
train_actor = MFCDef( train_actor = MFCDef(
name="actor_train", name="actor_train",
model_name="actor", model_name="actor",
mb_spec=self.actor_train.mb_spec, mb_spec=self.actor_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP, interface_type=ModelInterfaceType.TRAIN_STEP,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface, interface_impl=actor_interface,
input_keys=tuple(train_actor_inputs), input_keys=tuple(train_actor_inputs),
log_return_value=True, log_return_value=True,
@ -269,8 +264,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
mb_spec=self.critic_train.mb_spec, mb_spec=self.critic_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP, interface_type=ModelInterfaceType.TRAIN_STEP,
interface_impl=critic_interface, interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
input_keys=tuple(train_critic_inputs), input_keys=tuple(train_critic_inputs),
log_return_value=True, log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size, min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
@ -289,7 +282,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
if self.ppo.disable_value: if self.ppo.disable_value:
rpcs.pop("critic_inf") rpcs.pop("critic_inf")
rpcs.pop("critic_train") rpcs.pop("critic_train")
if not self.ppo.recompute_logprob: if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
rpcs.pop("actor_inf") rpcs.pop("actor_inf")
if self.ppo.kl_ctl == 0: if self.ppo.kl_ctl == 0:
rpcs.pop("ref_inf") rpcs.pop("ref_inf")
@ -311,7 +304,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
if self.ppo.disable_value: if self.ppo.disable_value:
allocs.pop("critic_inf") allocs.pop("critic_inf")
allocs.pop("critic_train") allocs.pop("critic_train")
if not self.ppo.recompute_logprob: if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
allocs.pop("actor_inf") allocs.pop("actor_inf")
if self.ppo.kl_ctl == 0: if self.ppo.kl_ctl == 0:
allocs.pop("ref_inf") allocs.pop("ref_inf")
@ -337,14 +330,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
def tokenizer_name_or_path(self) -> str: def tokenizer_name_or_path(self) -> str:
return self.actor.path return self.actor.path
@property
def search_kwargs(self):
return {
"num_gen_tokens": self.ppo.gen.max_new_tokens,
"n_ppo_minibatches": self.ppo.ppo_n_minibatches,
"seq_len": self.dataset.max_prompt_len,
}
@property @property
def max_prompt_len(self): def max_prompt_len(self):
return self.dataset.max_prompt_len return self.dataset.max_prompt_len
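
With use_decoupled_loss enabled, actor_inf recomputes log-probs under the current weights and forwards them as proximal_logprobs to actor_train, alongside the packed_logprobs recorded at generation time. The sketch below illustrates one common form of such a decoupled PPO policy loss (clip against the proximal policy, reweight by the detached proximal-to-behavior ratio); it is an illustration of the idea, not the repository's implementation.

import torch


def decoupled_ppo_loss_sketch(
    logprobs,           # log-probs under the policy being optimized
    proximal_logprobs,  # recomputed just before training (actor_inf output)
    behav_logprobs,     # recorded at generation time (actor_gen output)
    advantages,
    eps_clip=0.2,
):
    # Clip against the proximal policy, then reweight by the (detached)
    # proximal-to-behavior importance ratio to account for off-policy rollouts.
    ratio = torch.exp(logprobs - proximal_logprobs)
    clipped = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip)
    surrogate = -torch.min(ratio * advantages, clipped * advantages)
    behav_iw = torch.exp(proximal_logprobs - behav_logprobs).detach()
    return (behav_iw * surrogate).mean()


lp = torch.randn(8)
print(decoupled_ppo_loss_sketch(lp, lp.detach(), lp.detach() - 0.1, torch.randn(8)))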

View File

@ -36,8 +36,6 @@ class SFTConfig(CommonExperimentConfig, SFTExperimentOptions):
model_name="default", model_name="default",
input_keys=("packed_input_ids", "prompt_mask"), input_keys=("packed_input_ids", "prompt_mask"),
log_return_value=True, log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
) )
return {"trainDefault": rpc} return {"trainDefault": rpc}

View File

@ -34,8 +34,8 @@ from realhf.api.core.dfg import OffloadHook, ParamReallocHook
from realhf.api.quickstart.device_mesh import RPCAllocation from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.base import logging from realhf.base import logging
from realhf.base.topology import ( from realhf.base.topology import (
DataPipeModelParallelTopology, DataPipeTensorParallelTopology,
PipeDataModelParallelTopology, PipeDataTensorParallelTopology,
ProcessTopology, ProcessTopology,
) )
@ -73,8 +73,8 @@ def get_topo(
max_prompt_len: Optional[int] = None, max_prompt_len: Optional[int] = None,
) -> ProcessTopology: ) -> ProcessTopology:
if is_train: if is_train:
return PipeDataModelParallelTopology( return PipeDataTensorParallelTopology(
num_mp=parallel.model_parallel_size, num_tp=parallel.tensor_parallel_size,
num_pp=parallel.pipeline_parallel_size, num_pp=parallel.pipeline_parallel_size,
num_dp=parallel.data_parallel_size, num_dp=parallel.data_parallel_size,
sequence_parallel=parallel.use_sequence_parallel, sequence_parallel=parallel.use_sequence_parallel,
@ -82,8 +82,8 @@ def get_topo(
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
gradient_accumulation_fusion=gradient_accumulation_fusion, gradient_accumulation_fusion=gradient_accumulation_fusion,
) )
return DataPipeModelParallelTopology( return DataPipeTensorParallelTopology(
num_mp=parallel.model_parallel_size, num_tp=parallel.tensor_parallel_size,
num_pp=parallel.pipeline_parallel_size, num_pp=parallel.pipeline_parallel_size,
num_dp=parallel.data_parallel_size, num_dp=parallel.data_parallel_size,
sequence_parallel=parallel.use_sequence_parallel, sequence_parallel=parallel.use_sequence_parallel,
@ -93,7 +93,7 @@ def get_topo(
def get_world_size(parallel: ParallelismConfig) -> int: def get_world_size(parallel: ParallelismConfig) -> int:
return ( return (
parallel.model_parallel_size parallel.tensor_parallel_size
* parallel.pipeline_parallel_size * parallel.pipeline_parallel_size
* parallel.data_parallel_size * parallel.data_parallel_size
) )
@ -247,9 +247,8 @@ class AllocationType(enum.Enum):
GLOBAL_HYBRID = 2 GLOBAL_HYBRID = 2
MANUAL = 3 MANUAL = 3
HEURISTIC = 4 HEURISTIC = 4
SEARCH = 5 DECOUPLED_SGLANG = 5
DECOUPLED_SGLANG = 6 DECOUPLED_MOCK = 6
DECOUPLED_MOCK = 7
@dataclasses.dataclass @dataclasses.dataclass
@ -293,8 +292,6 @@ class AllocationMode:
return cls(AllocationType.MANUAL, None) return cls(AllocationType.MANUAL, None)
if allocation_mode == "heuristic": if allocation_mode == "heuristic":
return cls(AllocationType.HEURISTIC, None) return cls(AllocationType.HEURISTIC, None)
if allocation_mode == "search":
return cls(AllocationType.SEARCH, None)
alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode) alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode) alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)

View File

@ -1 +1 @@
import realhf.impl.environment.math_single_step_env import realhf.impl.environment.math_code_single_step_env

View File

@ -0,0 +1,75 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import os
import re
from typing import List, Tuple
from functioncall.code.local_verify import code_verify as local_code_verify
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.env_api import EnvironmentService, register_environment
from realhf.base import logging
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
logger = logging.getLogger("Math Single Step Environment")
def extract_code(text, min_length=20):
code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
valid_blocks.append(clean_block)
if not valid_blocks:
# logger.warning(f"failed to extract python code from {text}")
return None
# return the last code block
return valid_blocks[-1]
class MathCodeSingleStepEnv(EnvironmentService):
def __init__(self, dataset_path: str):
self.id2info, _ = load_metadata(dataset_path)
async def reset(self, seed=None, options=None):
return None, {}
async def step(self, action: Tuple[str, List[str]]):
qid, answers = action
group_size = len(answers)
qid = qid.split("@")[0]
cur_task = self.id2info[qid]["task"]
if cur_task == "math":
format_rewards = await asyncio.to_thread(
math_verify_call,
self.id2info,
answers,
[qid for _ in range(group_size)],
)
elif cur_task == "code":
answers = [extract_code(x) for x in answers]
format_rewards = await asyncio.to_thread(
code_verify_call,
self.id2info,
answers,
[qid for _ in range(group_size)],
)
else:
raise NotImplementedError()
return None, format_rewards, True, False, {}
register_environment("math-code-single-step", MathCodeSingleStepEnv)
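
extract_code keeps only the last fenced block of at least min_length characters, which is what the code task hands to the verifier. A quick standalone check of that regex (the test string builds its fence programmatically):

import re


def extract_code(text, min_length=20):
    # Same regex as in MathCodeSingleStepEnv: fenced blocks, optionally tagged
    # python/py/cpp, keeping only the last block of at least min_length chars.
    code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
    blocks = [b.strip() for b in re.findall(code_pattern, text, re.DOTALL)]
    valid = [b for b in blocks if len(b) >= min_length]
    return valid[-1] if valid else None


fence = "`" * 3
reply = f"Reasoning first.\n{fence}python\nprint('hello world from the sandbox')\n{fence}\nDone."
print(extract_code(reply))  # -> print('hello world from the sandbox')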

View File

@@ -1,38 +0,0 @@
-# Copyright 2025 Ant Group Inc.
-
-import asyncio
-import os
-from typing import List, Tuple
-
-from functioncall.math.verify import math_verify
-from realhf.api.core.env_api import EnvironmentService, register_environment
-from realhf.base import logging
-from realhf.impl.dataset.math_code_dataset import load_metadata
-from realhf.impl.dataset.math_parser import parse_lines_in_parallel
-
-ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
-math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
-
-logger = logging.getLogger("Math Single Step Environment")
-
-
-class MathSingleStepEnv(EnvironmentService):
-    def __init__(self, dataset_path: str):
-        self.id2info, _ = load_metadata(dataset_path)
-
-    async def reset(self, seed=None, options=None):
-        return None, {}
-
-    async def step(self, action: Tuple[str, List[str]]):
-        qid, answers = action
-        group_size = len(answers)
-        format_rewards = await asyncio.to_thread(
-            math_verify_call,
-            self.id2info,
-            answers,
-            [qid for _ in range(group_size)],
-        )
-        return None, format_rewards, True, False, {}
-
-
-register_environment("math-single-step", MathSingleStepEnv)

View File

@@ -13,7 +13,7 @@ import realhf.api.from_hf
 import realhf.base.logging as logging
 from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
 from realhf.base.importing import import_module
-from realhf.base.pkg_version import is_version_less
+from realhf.base.pkg_version import is_available, is_version_less
 from realhf.impl.model.conversion.hf_registry import HFModelRegistry
 from realhf.impl.model.nn.real_llm_api import ReaLModel
@@ -27,8 +27,9 @@ import_module(os.path.join(_filepath, "nn"), _p)
 # NOTE: skip importing vLLM for now to avoid an
 # "invalid device context" issue for the 25.01 image
-if is_version_less("vllm", "0.6.4"):
+if is_available("vllm") and is_version_less("vllm", "0.6.4"):
     import realhf.impl.model.backend.vllm
 import realhf.impl.model.backend.inference
 import realhf.impl.model.backend.megatron
 import realhf.impl.model.backend.mock_train
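The vLLM import is now gated on both package availability and version. A plausible shape for such an availability check (an illustrative sketch, not necessarily the actual `realhf.base.pkg_version.is_available` implementation):

    import importlib.util

    def is_available(pkg_name: str) -> bool:
        # The package counts as available if the import system can locate it.
        return importlib.util.find_spec(pkg_name) is not None

With this guard, environments without vLLM installed skip the backend import instead of failing at module load time.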

View File

@@ -62,7 +62,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
                 f"num_layers(this stage)={self.module.num_layers} "
                 f"pp_size={constants.pipe_parallel_world_size()} "
                 f"dp_size={constants.data_parallel_world_size()} "
-                f"mp_size={constants.model_parallel_world_size()} "
+                f"tp_size={constants.tensor_parallel_world_size()} "
             )
         if constants.data_parallel_rank() == 0:
             logger.info(

View File

@@ -126,7 +126,7 @@ def megatron_ctx():
     # Build the tensor model-parallel groups.
     parallel_state._TENSOR_MODEL_PARALLEL_GROUP = g
     if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
-        g = constants.model_parallel_group()
+        g = constants.tensor_parallel_group()
         parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = (
             dist.get_process_group_ranks(g)
         )
@@ -155,7 +155,7 @@ def megatron_ctx():
     if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
         # Build the tensor + context parallel groups
         parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = (
-            constants.model_parallel_group()
+            constants.tensor_parallel_group()
         )
     # Build expert parallel groups.
@@ -173,7 +173,7 @@ def megatron_ctx():
         )
     else:
         parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = (
-            constants.model_parallel_group()
+            constants.tensor_parallel_group()
         )
         parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = (
             constants.data_parallel_group()
@@ -227,7 +227,7 @@ class MegatronEngine:
     def _all_reduce_layernorm_grads(self):
         if not (
-            constants.sequence_parallel() and constants.model_parallel_world_size() > 1
+            constants.sequence_parallel() and constants.tensor_parallel_world_size() > 1
         ):
             return
         real_model: ReaLModel = self.ddp.module
@@ -255,7 +255,7 @@ class MegatronEngine:
         assert all(x is not None for x in grads)
         coalesced = _flatten_dense_tensors(grads)
-        dist.all_reduce(coalesced, group=constants.model_parallel_group())
+        dist.all_reduce(coalesced, group=constants.tensor_parallel_group())
         for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
             buf.copy_(synced)
@@ -362,7 +362,10 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
         )
         dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
         grad_norm /= constants.tp_and_pp_world_size()
-        if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
+        if (
+            constants.data_parallel_rank() == 0
+            and constants.tensor_parallel_rank() == 0
+        ):
             logger.info(
                 f"Model name {constants.model_name()}, "
                 f"Pipeline rank {constants.pipe_parallel_rank()}. "
@@ -539,7 +542,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
         )
         dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
         grad_norm /= constants.tp_and_pp_world_size()
-        if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
+        if (
+            constants.data_parallel_rank() == 0
+            and constants.tensor_parallel_rank() == 0
+        ):
             logger.info(
                 f"Megatron backend update success? {update_successful}. "
                 f"Grad Norm: {grad_norm}. "
@@ -700,7 +706,8 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
         # Deleting models directly will not release the memory.
         # We must disable hooks at first.
         if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
-            model.module.engine.ddp.disable_forward_pre_hook()
+            if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
+                model.module.engine.ddp.disable_forward_pre_hook()
         else:
             optimizer = model.module.engine.optim
             if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
@@ -726,7 +733,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
             sd = optimizer.state_dict()
             dp = constants.data_parallel_rank()
             pp = constants.pipe_parallel_rank()
-            tp = constants.model_parallel_rank()
+            tp = constants.tensor_parallel_rank()
             # HACK: (bowei) I'm not sure whether there's duplicated information.
             torch.save(
                 sd, pathlib.Path(save_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt"
@@ -742,7 +749,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
             dp = constants.data_parallel_rank()
             pp = constants.pipe_parallel_rank()
-            tp = constants.model_parallel_rank()
+            tp = constants.tensor_parallel_rank()
             sd = torch.load(
                 pathlib.Path(load_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt"

View File

@@ -82,7 +82,7 @@ def _split_and_prefill_pipe_input(
             raise PipelineError(
                 "Partitioned seqlens are not equal across pipeline parallel ranks. "
                 f"Current rank (dp={constants.data_parallel_rank()},"
-                f"tp={constants.model_parallel_rank()},pp={constants.pipe_parallel_rank()}), "
+                f"tp={constants.tensor_parallel_rank()},pp={constants.pipe_parallel_rank()}), "
                 f"gathered batch seqlens={_batch_seqlen_all_gathered}, "
                 f"Have you ensured that the order of dataset across ranks is the same?",
             )
@@ -118,7 +118,7 @@ def _split_and_prefill_pipe_input(
         total_len = (
             packed_input_ids.shape[0]
             if not constants.sequence_parallel()
-            else packed_input_ids.shape[0] // constants.model_parallel_world_size()
+            else packed_input_ids.shape[0] // constants.tensor_parallel_world_size()
         )
         mb_seq_lens.append(total_len)
     return (x, ys)
@@ -569,7 +569,7 @@ class PipeGenInstrSet:
             "batch_lengths", micro_batch_id, remove=False
         )
         batch_length = (
-            batch_length // constants.model_parallel_world_size()
+            batch_length // constants.tensor_parallel_world_size()
             if constants.sequence_parallel()
             else batch_length
         )

View File

@@ -198,8 +198,8 @@ class SGLangGenerationEngine(PipelinableEngine):
         hybrid_train: bool,
         request_timeout: int = 1800,
     ):
-        if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_cpu_group())
+        if constants.tensor_parallel_rank() != 0:
+            dist.barrier(group=constants.tensor_parallel_cpu_group())
             return
         # Start the serving process
         self.server_proc = mp.Process(
@@ -224,8 +224,8 @@ class SGLangGenerationEngine(PipelinableEngine):
         if server_args_dict["enable_metrics"]:
             dp_rank = constants.data_parallel_rank()
             pp_rank = constants.pipe_parallel_rank()
-            mp_rank = constants.model_parallel_rank()
-            metric_server_name = f"d{dp_rank}p{pp_rank}m{mp_rank}"
+            tp_rank = constants.tensor_parallel_rank()
+            metric_server_name = f"d{dp_rank}p{pp_rank}t{tp_rank}"
             key = names.metric_server(
                 constants.experiment_name(),
                 constants.trial_name(),
@@ -243,7 +243,7 @@ class SGLangGenerationEngine(PipelinableEngine):
         # offload weights/cache
         self.hybrid_train = hybrid_train
-        dist.barrier(group=constants.model_parallel_cpu_group())
+        dist.barrier(group=constants.tensor_parallel_cpu_group())

     def __del__(self):
         if hasattr(self, "server_proc"):
@@ -381,8 +381,8 @@ class SGLangGenerationEngine(PipelinableEngine):
                 "NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 "
                 "because we force to skip_tokenizer_init."
             )
-        if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_cpu_group())
+        if constants.tensor_parallel_rank() != 0:
+            dist.barrier(group=constants.tensor_parallel_cpu_group())
             return None, None, None

         results = asyncio.run(
@@ -393,12 +393,12 @@ class SGLangGenerationEngine(PipelinableEngine):
                 gconfig=gconfig,
             )
         )
-        dist.barrier(group=constants.model_parallel_cpu_group())
+        dist.barrier(group=constants.tensor_parallel_cpu_group())
         return results

     def update_weights_from_disk(self, path):
-        if constants.model_parallel_rank() != 0:
-            dist.barrier(group=constants.model_parallel_cpu_group())
+        if constants.tensor_parallel_rank() != 0:
+            dist.barrier(group=constants.tensor_parallel_cpu_group())
             return

         async def _fn():
@@ -409,18 +409,17 @@ class SGLangGenerationEngine(PipelinableEngine):
             await client.async_update_weights_from_disk(path)

         asyncio.run(_fn())
-        dist.barrier(group=constants.model_parallel_cpu_group())
+        dist.barrier(group=constants.tensor_parallel_cpu_group())


 @dataclasses.dataclass
 class SGLangGenerationBackend(ModelBackend, SGLangConfig):
     model_path: str = ""
-    dtype: str = "float16"

     def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
         if constants.pipe_parallel_world_size() != 1:
             raise RuntimeError("SGLang does not support pipe parallel size > 1.")
-        if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node:
+        if constants.tensor_parallel_world_size() > cluster.spec.n_gpus_per_node:
             raise RuntimeError(
                 "AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
             )
@@ -436,7 +435,13 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
         ) != len(datapack.flat2d(ports)):
             dist.all_gather_object(
                 ports,
-                network.find_multiple_free_ports(2, low=20000, high=40000),
+                network.find_multiple_free_ports(
+                    2,
+                    low=10000,
+                    high=60000,
+                    experiment_name=constants.experiment_name(),
+                    trial_name=constants.trial_name(),
+                ),
                 group=constants.data_parallel_group(),
             )
         api_server_port, dist_port = ports[constants.data_parallel_rank()]
@@ -450,13 +455,12 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
             tokenizer_mode="auto",
             load_format="auto",
             trust_remote_code=True,
-            kv_cache_dtype="auto",
             device="cuda",
             served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}",
             is_embedding=False,
             skip_tokenizer_init=True,
             # Other runtime options
-            tp_size=constants.model_parallel_world_size(),
+            tp_size=constants.tensor_parallel_world_size(),
             # Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
             base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
             file_storage_path=os.path.join(
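The wider port range plus the experiment/trial arguments let the launcher track which ports a trial has already claimed, so concurrent SGLang servers on the same host stop racing for identical ports. A generic sketch of that pattern (illustrative only; `network.find_multiple_free_ports` in realhf may record reservations through name-resolve rather than an in-process set):

    import socket
    from typing import List, Set

    def find_multiple_free_ports_sketch(count: int, low: int, high: int, reserved: Set[int]) -> List[int]:
        # Bind to port 0 so the OS hands out a free port; keep it only if it falls
        # inside [low, high) and was not already given to another server of this trial.
        ports: List[int] = []
        while len(ports) < count:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("", 0))
                port = s.getsockname()[1]
            if low <= port < high and port not in reserved:
                reserved.add(port)
                ports.append(port)
        return ports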

View File

@@ -36,7 +36,7 @@ def _vllm_group_rank(group_type: _vLLMGroupType):
     if group_type == _vLLMGroupType.WORLD:
         return constants.tp_and_pp_rank()
     elif group_type == _vLLMGroupType.TP:
-        return constants.model_parallel_rank()
+        return constants.tensor_parallel_rank()
     elif group_type == _vLLMGroupType.PP:
         return constants.pipe_parallel_rank()
@@ -45,7 +45,7 @@ def _vllm_group_size(group_type: _vLLMGroupType):
     if group_type == _vLLMGroupType.WORLD:
         return constants.tp_and_pp_world_size()
     elif group_type == _vLLMGroupType.TP:
-        return constants.model_parallel_world_size()
+        return constants.tensor_parallel_world_size()
     elif group_type == _vLLMGroupType.PP:
         return constants.pipe_parallel_world_size()
@@ -54,7 +54,7 @@ def _vllm_parallel_group(group_type: _vLLMGroupType):
     if group_type == _vLLMGroupType.WORLD:
         return constants.tp_and_pp_group()
     elif group_type == _vLLMGroupType.TP:
-        return constants.model_parallel_group()
+        return constants.tensor_parallel_group()
     elif group_type == _vLLMGroupType.PP:
         return constants.pipe_parallel_group()

View File

@@ -213,7 +213,7 @@ class GPUExecutor_(GPUExecutor):
         tok = time.perf_counter()
         after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used)
         is_dp_head = (
-            constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0
+            constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0
         )
         if is_dp_head:
             logger.info(
@@ -241,7 +241,7 @@ class GPUExecutor_(GPUExecutor):
         tok = time.perf_counter()
         after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used)
         is_dp_head = (
-            constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0
+            constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0
         )
         if is_dp_head:
             logger.info(

View File

@@ -166,7 +166,6 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM):
 @dataclasses.dataclass
 class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
     model_path: str = ""
-    dtype: str = "bfloat16"

     def _initialize(
         self, model: model_api.Model, spec: model_api.FinetuneSpec
@@ -192,7 +191,7 @@ class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
             kv_cache_dtype=self.kv_cache_type,
             device=constants.current_device(),
             # Parallelism.
-            tensor_parallel_size=constants.model_parallel_world_size(),
+            tensor_parallel_size=constants.tensor_parallel_world_size(),
             pipeline_parallel_size=constants.pipe_parallel_world_size(),
             # KV cahce and scheduling.
             num_scheduler_steps=self.num_scheduler_steps,

View File

@@ -100,7 +100,7 @@ def setup_global_comm(
     if worker_index == 0:
         host_ip = socket.gethostbyname(socket.gethostname())
-        port = network.find_free_port()
+        port = network.find_free_port(experiment_name=expr_name, trial_name=trial_name)
         pg_init_addr = f"tcp://{host_ip}:{port}"
         name_resolve.add(pg_master_name, pg_init_addr, keepalive_ttl=300)
     else:

View File

@@ -43,10 +43,10 @@ def is_trainable(model_name: ModelName) -> bool:
 class ParamReallocPair:
     src: ModelName
     src_dp_rank: int
-    src_mp_rank: int
+    src_tp_rank: int
     src_pp_rank: int
     dst: ModelName
-    dst_mp_rank: int
+    dst_tp_rank: int
     dst_pp_rank: int
@@ -171,32 +171,34 @@ def _create_param_realloc_groups(
         range(from_topo.get_dim("pipe")), range(to_topo.get_dim("pipe"))
     ):
         # create tensor reshard groups
-        src_mp_size = from_topo.get_dim("model")
-        dst_mp_size = to_topo.get_dim("model")
-        for mp_j in range(dst_mp_size):
+        src_tp_size = from_topo.get_dim("tensor")
+        dst_tp_size = to_topo.get_dim("tensor")
+        for tp_j in range(dst_tp_size):
             _all_dst_ranks = filter_match_mwids(
-                dst, to_topo, msid2mwid, pipe=pp_j, model=mp_j
+                dst, to_topo, msid2mwid, pipe=pp_j, tensor=tp_j
             )
-            if src_mp_size > dst_mp_size:
-                factor = src_mp_size // dst_mp_size
-                mp_is = list(range(factor * mp_j, factor * (mp_j + 1)))
+            if src_tp_size > dst_tp_size:
+                factor = src_tp_size // dst_tp_size
+                tp_is = list(range(factor * tp_j, factor * (tp_j + 1)))
                 _all_src_ranks = [
-                    filter_match_mwids(src, from_topo, msid2mwid, model=mp_i, pipe=pp_i)
-                    for mp_i in mp_is
+                    filter_match_mwids(
+                        src, from_topo, msid2mwid, tensor=tp_i, pipe=pp_i
+                    )
+                    for tp_i in tp_is
                 ]
             else:
-                factor = dst_mp_size // src_mp_size
+                factor = dst_tp_size // src_tp_size
                 _all_src_ranks = [
                     filter_match_mwids(
                         src,
                         from_topo,
                         msid2mwid,
-                        model=mp_j // factor,
+                        tensor=tp_j // factor,
                         pipe=pp_i,
                     )
                 ]
-            # All GPUs in _src_ranks have the data required by (pp_j, mp_j)
+            # All GPUs in _src_ranks have the data required by (pp_j, tp_j)
             for _src_ranks in _all_src_ranks:
                 # NOTE: inter-node communication cost is significantly larger than intra-node communication cost.
                 # We only select one sender per host/node to prevent multiple senders occupying the same network bandwidth.
@@ -209,42 +211,42 @@ def _create_param_realloc_groups(
                 )
                 _idle_src_ranks = [r for r in _src_ranks if r not in assignment]
                 for _src_rank in _idle_src_ranks:
-                    dp_i, mp_i = (
+                    dp_i, tp_i = (
                         from_topo.get_coord(
                             mwid2msid[_src_rank][src].parallelism_rank
                         ).data,
                         from_topo.get_coord(
                             mwid2msid[_src_rank][src].parallelism_rank
-                        ).model,
+                        ).tensor,
                     )
                     key = ParamReallocPair(
                         src=src,
                         src_dp_rank=dp_i,
-                        src_mp_rank=mp_i,
+                        src_tp_rank=tp_i,
                         src_pp_rank=pp_i,
                         dst=dst,
-                        dst_mp_rank=mp_j,
+                        dst_tp_rank=tp_j,
                         dst_pp_rank=pp_j,
                     )
                     param_realloc_dst_ranks[key] = []
                     param_realloc_groups[key] = None
                     param_realloc_src_ranks[key] = _src_rank
                 for _src_rank, _dst_ranks in assignment.items():
-                    dp_i, mp_i = (
+                    dp_i, tp_i = (
                         from_topo.get_coord(
                             mwid2msid[_src_rank][src].parallelism_rank
                         ).data,
                         from_topo.get_coord(
                             mwid2msid[_src_rank][src].parallelism_rank
-                        ).model,
+                        ).tensor,
                     )
                     key = ParamReallocPair(
                         src=src,
                         src_dp_rank=dp_i,
-                        src_mp_rank=mp_i,
+                        src_tp_rank=tp_i,
                         src_pp_rank=pp_i,
                         dst=dst,
-                        dst_mp_rank=mp_j,
+                        dst_tp_rank=tp_j,
                         dst_pp_rank=pp_j,
                     )
                     param_realloc_dst_ranks[key] = _dst_ranks
@@ -315,8 +317,8 @@ def setup_param_realloc(
 @dataclasses.dataclass
 class ReparallelizeSenderStep:
     rank: int
-    sender_mp_portion_id: int
-    receiver_mp_portion_id: int
+    sender_tp_portion_id: int
+    receiver_tp_portion_id: int
     param_keys: List[str]
     param_intervals_cpu: List[Tuple[int, int]]
     param_intervals_cuda: torch.Tensor
@@ -330,8 +332,8 @@ class ReparallelizeSenderStep:
 @dataclasses.dataclass
 class ReparallelizeReceiverStep:
     rank: int
-    sender_mp_portion_id: int
-    receiver_mp_portion_id: int
+    sender_tp_portion_id: int
+    receiver_tp_portion_id: int
     sender_param_intervals_cpu: List[Tuple[int, int]]
     sender_param_intervals_cuda: torch.Tensor
     sender_max_interval_size: int
@@ -356,9 +358,9 @@ def _derive_reparallelize_comm_plan(
     pg_info: ParamReallocInfo,
     dtype: Optional[torch.dtype] = torch.float16,
 ) -> List[ReparallelizeReceiverStep | ReparallelizeSenderStep]:
-    src_mp_size = from_topo.get_dim("model")
-    dst_mp_size = to_topo.get_dim("model")
-    assert src_mp_size % dst_mp_size == 0 or dst_mp_size % src_mp_size == 0
+    src_tp_size = from_topo.get_dim("tensor")
+    dst_tp_size = to_topo.get_dim("tensor")
+    assert src_tp_size % dst_tp_size == 0 or dst_tp_size % src_tp_size == 0
     for k, v in dataclasses.asdict(to_model_config).items():
         if k not in ["is_critic"] and v != getattr(from_model_config, k):
             raise ValueError(
@@ -366,8 +368,8 @@ def _derive_reparallelize_comm_plan(
                 f"value in checkpoint is `{v}`, current value is `{getattr(from_model_config, k)}`)."
             )
     if from_model_config.n_kv_heads > 1 and (
-        from_model_config.n_kv_heads % src_mp_size == 0
-    ) != (from_model_config.n_kv_heads % dst_mp_size == 0):
+        from_model_config.n_kv_heads % src_tp_size == 0
+    ) != (from_model_config.n_kv_heads % dst_tp_size == 0):
         raise ValueError("Whether to partition kv heads should remain the same.")

     from_layer_mapping = partition_pipeline_layers(
@@ -400,7 +402,7 @@ def _derive_reparallelize_comm_plan(
     from_model_param_specs, _ = build_param_spec(
         from_layer_indices,
         from_model_config,
-        mp_size=from_topo.get_dim("model"),
+        tp_size=from_topo.get_dim("tensor"),
         dp_size=from_topo.get_dim("data"),
         pp_size=from_topo.get_dim("pipe"),
         head_param_point_to_embedding=from_model_head_param_point_to_embedding,
@@ -411,7 +413,7 @@ def _derive_reparallelize_comm_plan(
     to_model_param_specs, _ = build_param_spec(
         to_layer_indices,
         to_model_config,
-        mp_size=to_topo.get_dim("model"),
+        tp_size=to_topo.get_dim("tensor"),
         pp_size=to_topo.get_dim("pipe"),
         dp_size=to_topo.get_dim("data"),
         head_param_point_to_embedding=to_model_head_param_point_to_embedding,
@@ -428,25 +430,25 @@ def _derive_reparallelize_comm_plan(
         if len(layer_indices) == 0:
             continue
-        for mp_i in range(src_mp_size):
-            if dst_mp_size > src_mp_size:
-                factor = dst_mp_size // src_mp_size
-                mp_js = [i + factor * mp_i for i in range(factor)]
-                receiver_mp_portion_id = 0
+        for tp_i in range(src_tp_size):
+            if dst_tp_size > src_tp_size:
+                factor = dst_tp_size // src_tp_size
+                tp_js = [i + factor * tp_i for i in range(factor)]
+                receiver_tp_portion_id = 0
             else:
-                factor = src_mp_size // dst_mp_size
-                mp_js = [mp_i // factor]
-                receiver_mp_portion_id = mp_i % factor
-            for sender_mp_portion_id, mp_j in enumerate(mp_js):
+                factor = src_tp_size // dst_tp_size
+                tp_js = [tp_i // factor]
+                receiver_tp_portion_id = tp_i % factor
+            for sender_tp_portion_id, tp_j in enumerate(tp_js):
                 for dp_i in range(src_dp_size):
                     key = ParamReallocPair(
                         src=from_model_name,
                         src_dp_rank=dp_i,
-                        src_mp_rank=mp_i,
+                        src_tp_rank=tp_i,
                         src_pp_rank=pp_i,
                         dst=to_model_name,
-                        dst_mp_rank=mp_j,
+                        dst_tp_rank=tp_j,
                         dst_pp_rank=pp_j,
                     )
                     src = pg_info.param_realloc_src_ranks[key]
@@ -462,10 +464,10 @@ def _derive_reparallelize_comm_plan(
                     )
                     param_size = param_size_from_keys(
                         config=from_model_config,
-                        src_mp_size=src_mp_size,
+                        src_tp_size=src_tp_size,
                         sd_keys=param_keys,
-                        src2dst_tp_size=max(dst_mp_size // src_mp_size, 1),
-                        src2dst_tp_rank=sender_mp_portion_id,
+                        src2dst_tp_size=max(dst_tp_size // src_tp_size, 1),
+                        src2dst_tp_rank=sender_tp_portion_id,
                         head_param_point_to_embedding=from_model_head_param_point_to_embedding,
                     )
                     if torch.distributed.is_initialized():
@@ -474,11 +476,11 @@ def _derive_reparallelize_comm_plan(
                         param_intervals_cpu = param_intervals_from_keys(
                             model_name=from_model_name,
                             config=from_model_config,
-                            mp_size=src_mp_size,
+                            tp_size=src_tp_size,
                             param_spec=from_model_param_specs,
                             sd_keys=param_keys,
-                            portion_size=max(dst_mp_size // src_mp_size, 1),
-                            portion_rank=sender_mp_portion_id,
+                            portion_size=max(dst_tp_size // src_tp_size, 1),
+                            portion_rank=sender_tp_portion_id,
                             head_param_point_to_embedding=from_model_head_param_point_to_embedding,
                         )
                         param_intervals_cuda = torch.tensor(
@@ -493,11 +495,11 @@ def _derive_reparallelize_comm_plan(
                         receiver_param_intervals_cpu = param_intervals_from_keys(
                             model_name=to_model_name,
                             config=to_model_config,
-                            mp_size=dst_mp_size,
+                            tp_size=dst_tp_size,
                             param_spec=to_model_param_specs,
                             sd_keys=param_keys,
-                            portion_size=max(src_mp_size // dst_mp_size, 1),
-                            portion_rank=receiver_mp_portion_id,
+                            portion_size=max(src_tp_size // dst_tp_size, 1),
+                            portion_rank=receiver_tp_portion_id,
                             head_param_point_to_embedding=to_model_head_param_point_to_embedding,
                         )
                         receiver_param_intervals_cuda = torch.tensor(
@@ -513,8 +515,8 @@ def _derive_reparallelize_comm_plan(
                     comm_plan.append(
                         ReparallelizeReceiverStep(
                             rank=dst_rank,
-                            sender_mp_portion_id=sender_mp_portion_id,
-                            receiver_mp_portion_id=receiver_mp_portion_id,
+                            sender_tp_portion_id=sender_tp_portion_id,
+                            receiver_tp_portion_id=receiver_tp_portion_id,
                             param_keys=param_keys,
                             sender_param_intervals_cpu=param_intervals_cpu,
                             sender_param_intervals_cuda=param_intervals_cuda,
@@ -532,8 +534,8 @@ def _derive_reparallelize_comm_plan(
                     comm_plan.append(
                         ReparallelizeSenderStep(
                             rank=src,
-                            sender_mp_portion_id=sender_mp_portion_id,
-                            receiver_mp_portion_id=receiver_mp_portion_id,
+                            sender_tp_portion_id=sender_tp_portion_id,
+                            receiver_tp_portion_id=receiver_tp_portion_id,
                             param_keys=param_keys,
                             param_intervals_cpu=param_intervals_cpu,
                             param_intervals_cuda=param_intervals_cuda,
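To make the renamed resharding arithmetic concrete, here is the factor/portion computation from the loops above on a small topology (pure illustration of the existing logic, no new behaviour):

    # Resharding from src_tp_size=4 to dst_tp_size=2: each destination TP rank tp_j
    # receives portions from factor = 4 // 2 = 2 source ranks.
    src_tp_size, dst_tp_size = 4, 2
    factor = src_tp_size // dst_tp_size
    for tp_j in range(dst_tp_size):
        tp_is = list(range(factor * tp_j, factor * (tp_j + 1)))
        print(tp_j, tp_is)  # 0 -> [0, 1], 1 -> [2, 3]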

View File

@@ -22,8 +22,8 @@ from realhf.base.saveload_utils import (
 )
 from realhf.impl.model.nn.real_llm_api import ReaLModel
 from realhf.impl.model.nn.real_llm_parallel import (
-    mp_merge_key,
-    mp_partition_real_model_state_dict,
+    tp_merge_key,
+    tp_partition_real_model_state_dict,
 )

 logger = logging.getLogger("HF Registry")
@@ -141,11 +141,11 @@ class HFModelRegistry:
         partition_tik = time.perf_counter()
         sd = {k: v for k, v in sd.items() if k in required_hf_sd_names}
         sd = self.sd_from_hf_converter(sd, model.config)
-        psd = mp_partition_real_model_state_dict(
+        psd = tp_partition_real_model_state_dict(
             sd,
             model.config,
-            constants.model_parallel_world_size(),
-            constants.model_parallel_rank(),
+            constants.tensor_parallel_world_size(),
+            constants.tensor_parallel_rank(),
         )
         return psd, partition_tik - load_tik, time.perf_counter() - partition_tik
@@ -222,8 +222,8 @@ class HFModelRegistry:
         dp_rank = constants.data_parallel_rank()
         pp_rank = constants.pipe_parallel_rank()
-        mp_rank = constants.model_parallel_rank()
-        mp_size = constants.model_parallel_world_size()
+        tp_rank = constants.tensor_parallel_rank()
+        tp_size = constants.tensor_parallel_world_size()
         pp_size = constants.pipe_parallel_world_size()
         dp_size = constants.data_parallel_world_size()
@@ -234,7 +234,7 @@ class HFModelRegistry:
         # of each pipeline stage into smaller shards.
         approx_param_size = (
             sum(v.numel() * v.element_size() for v in model.state_dict().values())
-            * mp_size
+            * tp_size
         )
         # By default a shard is at most 1GB. A small size enables parallel saving during training.
@@ -274,9 +274,9 @@ class HFModelRegistry:
                 and k == f"{model.config.n_layers + 1}.weight"
             ):
                 continue
-            gather_list = [torch.zeros_like(v) for _ in range(mp_size)]
-            dist.all_gather(gather_list, v, group=constants.model_parallel_group())
-            gathered = mp_merge_key(k, gather_list, model.config)
+            gather_list = [torch.zeros_like(v) for _ in range(tp_size)]
+            dist.all_gather(gather_list, v, group=constants.tensor_parallel_group())
+            gathered = tp_merge_key(k, gather_list, model.config)
             cpu_sd[k] = gathered.cpu()

         t2 = time.perf_counter()
@@ -299,7 +299,7 @@ class HFModelRegistry:
             param_size = param_size.item()

         # Save tokenizer and huggingface model config.
-        if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
+        if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
             hf_config.save_pretrained(save_dir)
             if tokenizer is not None:
                 tokenizer.save_pretrained(save_dir)
@@ -307,7 +307,7 @@ class HFModelRegistry:
         # Dump parameters to disk.
         if len(pp_stage_n_shards) == 1 and pp_stage_n_shards[0] == 1:
             fn = "pytorch_model.bin"
-            if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
+            if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
                 torch.save(hf_sd, os.path.join(save_dir, fn))
         else:
             output_fn = (
@@ -326,8 +326,8 @@ class HFModelRegistry:
             bin_index["weight_map"] = {}
             weight_map = {}

-            mesh_size = dp_size * mp_size
-            mesh_idx = dp_rank * mp_size + mp_rank
+            mesh_size = dp_size * tp_size
+            mesh_idx = dp_rank * tp_size + tp_rank
             n_shards_per_gpu = (n_shards + mesh_size - 1) // mesh_size
             if mesh_idx < len(range(0, n_shards, n_shards_per_gpu)):
                 s = list(range(0, n_shards, n_shards_per_gpu))[mesh_idx]
@@ -357,7 +357,7 @@ class HFModelRegistry:
             for wm in weight_map_list:
                 bin_index["weight_map"].update(wm)

-            if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
+            if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
                 with open(
                     os.path.join(save_dir, "pytorch_model.bin.index.json"), "w"
                 ) as f:

View File

@@ -240,7 +240,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
             return data
         local_rank = constants.grid().topo.get_rank(
             data=constants.data_parallel_rank(),
-            model=0,
+            tensor=0,
             pipe=constants.pipe_parallel_world_size() - 1,
         )
         dst = constants.to_global_pg_rank(local_rank)

View File

@@ -3,7 +3,7 @@
 # Licensed under the Apache License, Version 2.0 (the "License").

 import dataclasses
-from typing import Dict, Literal, Optional
+from typing import Dict, List, Literal, Optional

 import torch
 import torch.distributed as dist
@@ -86,6 +86,7 @@ def _ppo_actor_loss_from_model_outputs(
         eps_clip=eps_clip,
         loss_mask=ppo_loss_mask,
         c_clip=c_clip,
+        proximal_logprobs=input_.data.get("prox_logp", None),
     )

     # Log training statistics
@@ -106,13 +107,20 @@ def _ppo_actor_loss_from_model_outputs(
         dual_clip_ratio=ppo_stat["dual_clip_mask"].float(),
         denominator="n_valid_tokens",
     )
+    if "behave_imp_weight" in ppo_stat:
+        stats_tracker.denominator(unclipped_behave_tokens=ppo_stat["behave_mask"])
+        stats_tracker.stat(
+            behave_imp_weight=ppo_stat["behave_imp_weight"],
+            behave_approx_kl=ppo_stat["behave_approx_kl"],
+            denominator="unclipped_behave_tokens",
+        )

     vocab_min_logits = logits.detach().min(-1).values.float()
     vocab_max_logits = logits.detach().max(-1).values.float()
     dist.all_reduce(
-        vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
+        vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN
     )
     dist.all_reduce(
-        vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
+        vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX
     )
     stats_tracker.stat(
         vocab_min_logits=vocab_min_logits,
@@ -505,7 +513,7 @@ class PPOActorInterface(model_api.ModelInterface):
         model: model_api.Model,
         input_: SequenceSample,
         mb_spec: MicroBatchSpec,
-    ) -> Dict:
+    ) -> Dict | List[Dict]:
         module = model.module
         # We call module.eval() because dropout causes the computation of incorrect of log probs.
         module.eval()
@@ -656,15 +664,20 @@ class PPOActorInterface(model_api.ModelInterface):
         advantages = torch.cat(adv_list, 0)

         # Prepare data to be splitted into mini-batches.
+        flat_data = dict(
+            advantages=advantages,
+            old_logp=old_logp,
+            ppo_loss_mask=loss_mask,
+            packed_input_ids=input_.data["packed_input_ids"],
+            kl_rewards=kl_rewards,
+        )
+        use_prox_logp = "proximal_logprobs" in input_.data
+        if use_prox_logp:
+            flat_data["prox_logp"] = input_.data["proximal_logprobs"].float()
         flat_input = SequenceSample.from_default(
             ids=list(range(input_.bs * self.group_size)),
-            data=dict(
-                advantages=advantages,
-                old_logp=old_logp,
-                ppo_loss_mask=loss_mask,
-                packed_input_ids=input_.data["packed_input_ids"],
-                kl_rewards=kl_rewards,
-            ),
+            data=flat_data,
             seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
         )
@@ -672,6 +685,7 @@ class PPOActorInterface(model_api.ModelInterface):
             dense_reward_score = dense_reward_score[shift_one_indices]

         ### Logging code starts. ###
+        all_stats = []
         with stats_tracker.scope("ppo_actor"):
             assert (
                 task_ids.shape == reward_score.shape
@@ -682,12 +696,13 @@ class PPOActorInterface(model_api.ModelInterface):
                 for idx, task in enumerate(RL_TASKS)
             }
-            stats_tracker.denominator(
+            global_denominators = dict(
                 n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
                 n_tokens=torch.ones_like(prompt_mask, dtype=torch.bool),
                 n_valid_tokens=loss_mask.bool(),
                 **task_denominators,
             )
+            stats_tracker.denominator(**global_denominators)

             for task in RL_TASKS:
                 stats_tracker.stat(
@@ -721,6 +736,22 @@ class PPOActorInterface(model_api.ModelInterface):
                 **seq_stats,
                 denominator="n_seqs",
             )
+            scalars = dict(
+                disable_value=self.disable_value,
+                mask_no_eos_with_zero=self.mask_no_eos_with_zero,
+                eps_clip=self.eps_clip,
+                use_prox_logp=use_prox_logp,
+            )
+            if self.c_clip is not None:
+                scalars["c_clip"] = self.c_clip
+                scalars["use_dual_clip"] = 1
+            else:
+                scalars["use_dual_clip"] = 0
+            stats_tracker.scalar(**scalars)
+
+            global_stats = stats_tracker.export()
+            for k in global_denominators:
+                global_stats.pop(f"ppo_actor/{k}")

         # Run mini-batched PPO training!
         def _loss_fn(logits, input_):
@@ -736,43 +767,37 @@ class PPOActorInterface(model_api.ModelInterface):
             )

         for reuse in range(self.sample_reuse):
-            with stats_tracker.scope(f"reuse{reuse}"):
-                # NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
-                flat_input = SequenceSample.shuffled(flat_input)
-                bs = flat_input.bs
-                sizes = [0 for _ in range(self.n_minibatches)]
-                for idx in range(bs):
-                    sizes[idx % self.n_minibatches] += 1
-                spec = SequenceSplitSpec(sizes=sizes)
-                datas = flat_input.split_with_spec(spec)
-                logger.info(
-                    f"PPO minibatch split (size {self.n_minibatches}): "
-                    f"#seqs: {[s.bs for s in datas]}, "
-                    f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
-                )
-                for mb_i, data in enumerate(datas):
-                    with stats_tracker.scope(f"mb{mb_i}"):
-                        train_stat = module.train_batch(
-                            input_=data,
-                            mb_spec=mb_spec,
-                            version_steps=model.version.global_step,
-                            loss_fn=_loss_fn,
-                            loss_weight_fn=lambda x: x.data[
-                                "ppo_loss_mask"
-                            ].count_nonzero(),
-                            token_normalize_scope=self.token_normalize_scope,
-                        )
-                        stats_tracker.scalar(**train_stat)
-                stats_tracker.scalar(
-                    disable_value=self.disable_value,
-                    mask_no_eos_with_zero=self.mask_no_eos_with_zero,
-                    c_clip=self.c_clip if self.c_clip is not None else float("nan"),
-                    eps_clip=self.eps_clip,
-                )
+            # NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
+            flat_input = SequenceSample.shuffled(flat_input)
+            bs = flat_input.bs
+            sizes = [0 for _ in range(self.n_minibatches)]
+            for idx in range(bs):
+                sizes[idx % self.n_minibatches] += 1
+            spec = SequenceSplitSpec(sizes=sizes)
+            datas = flat_input.split_with_spec(spec)
+            logger.info(
+                f"PPO minibatch split (size {self.n_minibatches}): "
+                f"#seqs: {[s.bs for s in datas]}, "
+                f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
+            )
+            for mb_i, data in enumerate(datas):
+                train_stat = module.train_batch(
+                    input_=data,
+                    mb_spec=mb_spec,
+                    version_steps=model.version.global_step,
+                    loss_fn=_loss_fn,
+                    loss_weight_fn=lambda x: x.data[
+                        "ppo_loss_mask"
+                    ].count_nonzero(),
+                    token_normalize_scope=self.token_normalize_scope,
+                )
+                stats_tracker.scalar(**train_stat)
+                all_stats.append(stats_tracker.export())

         model.inc_version()
+        all_stats[0].update(global_stats)

-        return stats_tracker.export()
+        return all_stats

     # Mock methods for profiling only.
     def _mock_inference(
@@ -1033,7 +1058,7 @@ class PPOCriticInterface(model_api.ModelInterface):
         model: model_api.Model,
         input_: SequenceSample,
         mb_spec: MicroBatchSpec,
-    ) -> Dict:
+    ) -> Dict | List[Dict]:
         assert model.module.module.config.is_critic

         if self.disable_value:
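The `prox_logp` field threaded through `flat_data` feeds the decoupled PPO loss added in this PR: the PPO ratio is taken against a recent "proximal" policy, while an extra importance weight corrects for the stale behaviour policy that generated the rollout (which is what the `behave_imp_weight` / `behave_approx_kl` statistics above track). A conceptual sketch of such an objective, not the exact formula in `ppo_functional` (the cap value and the plain mean over tokens are illustrative simplifications):

    import torch

    def decoupled_ppo_loss_sketch(logp, prox_logp, old_logp, advantages,
                                  eps_clip=0.2, behave_imp_cap=5.0):
        # PPO-style clipped surrogate, with the ratio anchored to the proximal policy.
        ratio = torch.exp(logp - prox_logp)
        clipped = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip)
        surrogate = torch.min(ratio * advantages, clipped * advantages)
        # Importance weight of proximal vs. behaviour policy; capped and detached
        # so that stale rollouts cannot dominate the gradient.
        behave_imp_weight = torch.exp(prox_logp - old_logp).clamp(max=behave_imp_cap).detach()
        # A real implementation averages only over tokens selected by ppo_loss_mask.
        return -(behave_imp_weight * surrogate).mean()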

View File

@@ -3,7 +3,7 @@
 # Licensed under the Apache License, Version 2.0 (the "License").

 import dataclasses
-from typing import Dict, Literal
+from typing import Dict, List, Literal

 import torch
 import torch.distributed as dist
@@ -68,10 +68,10 @@ def compute_packed_sft_loss(
     vocab_min_logits = logits.detach().min(-1).values.float()
     vocab_max_logits = logits.detach().max(-1).values.float()
     dist.all_reduce(
-        vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
+        vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN
    )
     dist.all_reduce(
-        vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
+        vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX
     )
     stats_tracker.stat(
         vocab_min_logits=vocab_min_logits,
@@ -88,7 +88,7 @@ class SFTInterface(model_api.ModelInterface):
     def train_step(
         self, model: model_api.Model, data: SequenceSample, mb_spec: MicroBatchSpec
-    ) -> Dict:
+    ) -> Dict | List[Dict]:
         module = model.module
         module.train()

View File

@@ -10,14 +10,14 @@ import torch.utils.checkpoint
 import realhf.base.constants as constants
 import realhf.base.logging as logging
-from realhf.impl.model.parallelism.model_parallel.modules import RowParallelLinear
+from realhf.impl.model.parallelism.tensor_parallel.modules import RowParallelLinear
 from realhf.impl.model.utils.functional import (
     apply_rotary_varlen,
     compute_varlen_position_indices,
     torch_attn_func,
 )

-from .mlp import LayerNormQKVLinear
+from .mlp import GemmaRMSNorm, LayerNormQKVLinear, LlamaRMSNorm
 from .rotary import RotaryEmbedding

 try:
@@ -53,6 +53,8 @@ class CausalSelfAttentionLayer(nn.Module):
         layer_norm_type: Optional[str] = None,
         # opt applies layer norm after attn
         do_layernorm_before: bool = True,
+        # qk layer norm (Qwen3)
+        qk_layernorm: bool = False,
         # rotary embedding
         apply_rotary: bool = False,
         rotary_base: float = 10000.0,
@@ -67,7 +69,7 @@ class CausalSelfAttentionLayer(nn.Module):
         super().__init__()
         if dtype is None:
             dtype = torch.float16
-        assert hidden_dim % head_dim == 0
+        assert hidden_dim % head_dim == 0, (hidden_dim, head_dim)
         self.c_attn = LayerNormQKVLinear(
             input_dim=hidden_dim,
             head_dim=head_dim,
@@ -82,7 +84,7 @@ class CausalSelfAttentionLayer(nn.Module):
             layer_index=layer_index,
         )

-        if constants.model_parallel_world_size() > 1:
+        if constants.tensor_parallel_world_size() > 1:
             self.c_proj = RowParallelLinear(
                 n_q_heads * head_dim,
                 hidden_dim,
@@ -100,6 +102,21 @@ class CausalSelfAttentionLayer(nn.Module):
                 device=device,
             )

+        self.qk_layernorm = qk_layernorm
+        if qk_layernorm:
+            if layer_norm_type is None:
+                layer_norm_fn = nn.LayerNorm
+            elif layer_norm_type == "rms":
+                layer_norm_fn = LlamaRMSNorm
+            elif layer_norm_type == "gemma":
+                layer_norm_fn = GemmaRMSNorm
+            self.q_ln = layer_norm_fn(
+                head_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
+            )
+            self.k_ln = layer_norm_fn(
+                head_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
+            )
+
         self.resid_dropout = nn.Dropout(resid_pdrop)

         self.attn_pdrop = attn_pdrop
@@ -173,6 +190,10 @@ class CausalSelfAttentionLayer(nn.Module):
         q, k, v = self.c_attn(hidden_states)

+        if self.qk_layernorm:
+            q = self.q_ln(q)
+            k = self.k_ln(k)
+
         if self.apply_rotary and (k_cache is None or str(q.device) == "cpu"):
             # otherwise, we input rotary cos/sin directly into flash_attn_with_kvcache
             rotary_cache_len = max_seqlen
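The new `qk_layernorm` path implements Qwen3-style QK normalization: each head's query and key vectors are normalized over the head dimension before rotary embeddings and attention. A standalone sketch of the operation (the learned scale of `LlamaRMSNorm`/`GemmaRMSNorm` is omitted here):

    import torch

    def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Per-head RMSNorm over the last (head_dim) axis.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    q = torch.randn(16, 32, 128)  # [tokens, n_q_heads, head_dim]
    k = torch.randn(16, 8, 128)   # [tokens, n_kv_heads, head_dim]
    q, k = rms_norm(q), rms_norm(k)  # applied before RoPE and attention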

View File

@@ -8,7 +8,7 @@ import torch
 import torch.nn as nn
 from torch.nn import init

-from realhf.impl.model.parallelism.model_parallel.modules import ParallelEmbedding
+from realhf.impl.model.parallelism.tensor_parallel.modules import ParallelEmbedding


 class OffsetPositionalEmbedding(nn.Embedding):

View File

@@ -15,7 +15,7 @@ from transformers.activations import ACT2FN
 import realhf.base.constants as constants
 import realhf.base.logging as logging
-from realhf.impl.model.parallelism.model_parallel.modules import (
+from realhf.impl.model.parallelism.tensor_parallel.modules import (
     ColumnParallelLinear,
     RowParallelLinear,
     merged_linear_with_grad_accumulation_and_async_allreduce,
@@ -49,10 +49,10 @@ class LayerNormQKVLinear(nn.Module):
         layer_index=None,
     ):
         super().__init__()
-        model_parallel = constants.model_parallel_world_size() > 1
+        tensor_parallel = constants.tensor_parallel_world_size() > 1
         sequence_parallel = constants.sequence_parallel()
         gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
-        if not model_parallel and (sequence_parallel or gradient_accumulation_fusion):
+        if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
             global SEQUENCE_PARALLEL_WARNED
             if not SEQUENCE_PARALLEL_WARNED:
                 logger.warning(
@@ -73,16 +73,16 @@ class LayerNormQKVLinear(nn.Module):
             input_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
         )

-        self.model_parallel = model_parallel
+        self.tensor_parallel = tensor_parallel
         self.layer_index = layer_index
-        self.mp_worldsize = constants.model_parallel_world_size()
-        assert n_q_heads % self.mp_worldsize == 0, (
+        self.tp_worldsize = constants.tensor_parallel_world_size()
+        assert n_q_heads % self.tp_worldsize == 0, (
             f"n_q_heads {n_q_heads} must be divisible by "
-            f"mp_worldsize {self.mp_worldsize}"
+            f"tp_worldsize {self.tp_worldsize}"
         )
-        assert n_kv_heads % self.mp_worldsize == 0, (
+        assert n_kv_heads % self.tp_worldsize == 0, (
             f"n_kv_heads {n_kv_heads} must be divisible by "
-            f"mp_worldsize {self.mp_worldsize}"
+            f"tp_worldsize {self.tp_worldsize}"
         )
         hidden_dim = input_dim
         # TODO: we can fuse the forward of qkv attention
@@ -141,9 +141,9 @@ class LayerNormQKVLinear(nn.Module):
                 self.v_attn.weight,
                 self.v_attn.bias,
             )
-        q = q.view(*q.shape[:-1], self.nq // self.mp_worldsize, self.d)
-        k = k.view(*k.shape[:-1], self.nkv // self.mp_worldsize, self.d)
-        v = v.view(*v.shape[:-1], self.nkv // self.mp_worldsize, self.d)
+        q = q.view(*q.shape[:-1], self.nq // self.tp_worldsize, self.d)
+        k = k.view(*k.shape[:-1], self.nkv // self.tp_worldsize, self.d)
+        v = v.view(*v.shape[:-1], self.nkv // self.tp_worldsize, self.d)
         return q, k, v
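As a concrete instance of the reshape above: with `n_q_heads=32`, `n_kv_heads=8` and a TP world size of 8, every rank keeps 32 // 8 = 4 query heads and 8 // 8 = 1 KV head, which is exactly why both head counts must divide `tp_worldsize`:

    n_q_heads, n_kv_heads, tp_worldsize = 32, 8, 8
    assert n_q_heads % tp_worldsize == 0 and n_kv_heads % tp_worldsize == 0
    q_heads_per_rank = n_q_heads // tp_worldsize    # 4
    kv_heads_per_rank = n_kv_heads // tp_worldsize  # 1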
@@ -163,10 +163,10 @@ class LayerNormMLP(nn.Module):
         device: Optional[Union[str, torch.device]] = None,
     ):
         super().__init__()
-        model_parallel = constants.model_parallel_world_size() > 1
+        tensor_parallel = constants.tensor_parallel_world_size() > 1
         sequence_parallel = constants.sequence_parallel()
         gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
-        if not model_parallel and (sequence_parallel or gradient_accumulation_fusion):
+        if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
             global SEQUENCE_PARALLEL_WARNED
             if not SEQUENCE_PARALLEL_WARNED:
                 logger.warning(
@@ -228,12 +228,12 @@ class LlamaLayerNormMLP(nn.Module):
         device: Optional[Union[str, torch.device]] = None,
     ):
         super().__init__()
-        self.model_parallel = constants.model_parallel_world_size() > 1
+        self.tensor_parallel = constants.tensor_parallel_world_size() > 1
         gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
         self.is_expert = is_expert
         # when used as experts the MLP always compute without sequence parallel
         sequence_parallel = constants.sequence_parallel() and not is_expert
-        if not self.model_parallel and (
+        if not self.tensor_parallel and (
             sequence_parallel or gradient_accumulation_fusion
         ):
             global SEQUENCE_PARALLEL_WARNED
@@ -418,13 +418,13 @@ if constants.use_te_impl():
                 eps=layer_norm_epsilon,
                 sequence_parallel=constants.sequence_parallel(),
                 return_bias=False,
-                tp_group=constants.model_parallel_group(),
-                tp_size=constants.model_parallel_world_size(),
+                tp_group=constants.tensor_parallel_group(),
+                tp_size=constants.tensor_parallel_world_size(),
                 bias=False,
                 normalization="RMSNorm",
                 activation="swiglu",
                 fuse_wgrad_accumulation=constants.gradient_accumulation_fusion(),
                 params_dtype=dtype,
-                set_parallel_mode=constants.model_parallel_world_size() > 1,
+                set_parallel_mode=constants.tensor_parallel_world_size() > 1,
                 device=device,
             )

View File

@ -10,11 +10,11 @@ from torch.nn.parameter import Parameter
import realhf.base.constants as constants import realhf.base.constants as constants
from realhf.api.core.model_api import ReaLModelConfig from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.modules.mlp import LlamaLayerNormMLP, get_activation_fn from realhf.impl.model.modules.mlp import LlamaLayerNormMLP, get_activation_fn
from realhf.impl.model.parallelism.model_parallel.mappings import ( from realhf.impl.model.parallelism.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region, copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region,
) )
from realhf.impl.model.parallelism.model_parallel.utils import divide from realhf.impl.model.parallelism.tensor_parallel.utils import divide
from realhf.impl.model.utils.random import _initialize_affine_weight_gpu from realhf.impl.model.utils.random import _initialize_affine_weight_gpu
try: try:
@ -125,7 +125,7 @@ class GroupedMLP(torch.nn.Module):
self.activation_func = get_activation_fn(self.config.activation_function) self.activation_func = get_activation_fn(self.config.activation_function)
# How many feature each rank holds for fc1 and fc2, respectively. # How many feature each rank holds for fc1 and fc2, respectively.
tp_size = constants.model_parallel_world_size() tp_size = constants.tensor_parallel_world_size()
intermediate_dim_per_partition = divide(self.config.intermediate_dim, tp_size) intermediate_dim_per_partition = divide(self.config.intermediate_dim, tp_size)
# Note: The current kernel implementations of grouped_gemm # Note: The current kernel implementations of grouped_gemm
@ -186,7 +186,7 @@ class GroupedMLP(torch.nn.Module):
): ):
tokens_per_expert = tokens_per_expert.cpu() tokens_per_expert = tokens_per_expert.cpu()
if permuted_local_hidden_states.nelement() != 0: if permuted_local_hidden_states.nelement() != 0:
if constants.model_parallel_world_size() > 1: if constants.tensor_parallel_world_size() > 1:
permuted_local_hidden_states = copy_to_tensor_model_parallel_region( permuted_local_hidden_states = copy_to_tensor_model_parallel_region(
permuted_local_hidden_states permuted_local_hidden_states
) )
@ -208,7 +208,7 @@ class GroupedMLP(torch.nn.Module):
output = grouped_gemm.ops.gmm( output = grouped_gemm.ops.gmm(
inter, self.grouped_down_proj, tokens_per_expert, trans_b=False inter, self.grouped_down_proj, tokens_per_expert, trans_b=False
) )
if constants.model_parallel_world_size() > 1: if constants.tensor_parallel_world_size() > 1:
output = reduce_from_tensor_model_parallel_region(output) output = reduce_from_tensor_model_parallel_region(output)
else: else:
# No token is allocated for local experts. # No token is allocated for local experts.

View File

@@ -8,7 +8,7 @@ import torch.nn.init as init
import realhf.base.constants as constants
from realhf.api.core.model_api import ReaLModelConfig
- from realhf.impl.model.parallelism.model_parallel.mappings import (
+ from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
)
from realhf.impl.model.utils.moe import (
@@ -117,10 +117,10 @@ class TopKRouter(torch.nn.Module):
torch.Tensor: The activation tensor with the attached gradient function.
"""
moe_aux_loss_coeff = self.config.moe.aux_loss_coeff
- moe_aux_loss_coeff /= constants.model_parallel_world_size()
+ moe_aux_loss_coeff /= constants.tensor_parallel_world_size()
scale_for_logging = 1.0
if constants.sequence_parallel():
- scale_for_logging *= constants.model_parallel_world_size()
+ scale_for_logging *= constants.tensor_parallel_world_size()
aux_loss = switch_load_balancing_loss_func(
probs,
@@ -128,7 +128,7 @@ class TopKRouter(torch.nn.Module):
self.config.moe.top_k,
moe_aux_loss_coeff,
sequence_partition_group=(
- constants.model_parallel_group()
+ constants.tensor_parallel_group()
if constants.sequence_parallel()
else None
),
@@ -155,7 +155,7 @@ class TopKRouter(torch.nn.Module):
"""
if self.config.moe.z_loss_coeff > 0:
moe_z_loss_coeff = (
- self.config.moe.z_loss_coeff / constants.model_parallel_world_size()
+ self.config.moe.z_loss_coeff / constants.tensor_parallel_world_size()
)
z_loss = z_loss_func(logits, moe_z_loss_coeff)
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
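The router divides the auxiliary-loss coefficient by the tensor-parallel world size because the loss is later reduced across that group. As a rough, stand-alone illustration only: this is the generic Switch-Transformer-style balancing loss, not necessarily the exact `switch_load_balancing_loss_func` used in the codebase.

```python
import torch

def switch_load_balancing_loss(probs: torch.Tensor, top_k: int, coeff: float) -> torch.Tensor:
    # probs: [num_tokens, num_experts] softmaxed router probabilities.
    num_tokens, num_experts = probs.shape
    topk_idx = probs.topk(top_k, dim=-1).indices
    dispatch = torch.zeros_like(probs).scatter_(1, topk_idx, 1.0)
    tokens_per_expert = dispatch.sum(0) / (num_tokens * top_k)  # routed fraction per expert
    mean_probs = probs.mean(0)                                  # mean router prob per expert
    return coeff * num_experts * torch.sum(tokens_per_expert * mean_probs)

probs = torch.softmax(torch.randn(16, 8), dim=-1)
tp_world_size = 4  # mirrors the division by tensor_parallel_world_size() above
print(float(switch_load_balancing_loss(probs, top_k=2, coeff=0.01 / tp_world_size)))
```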

View File

@ -7,7 +7,7 @@ import torch
import realhf.base.constants as constants import realhf.base.constants as constants
from realhf.api.core.model_api import ReaLModelConfig from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.parallelism.model_parallel.mappings import ( from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_sequence_parallel_region, gather_from_sequence_parallel_region,
scatter_to_sequence_parallel_region, scatter_to_sequence_parallel_region,
) )

View File

@@ -23,8 +23,8 @@ from .real_llm_base import ReaLModelParamKeys
from .real_llm_parallel import (
get_real_model_param_shape,
intervals_partition_fn,
- mp_partition_key,
shape_partition_fn,
+ tp_partition_key,
)
try:
@ -188,7 +188,7 @@ def set_intervals(
def param_size_from_keys( def param_size_from_keys(
config: model_api.ReaLModelConfig, config: model_api.ReaLModelConfig,
src_mp_size: int, src_tp_size: int,
sd_keys: List[str], sd_keys: List[str],
src2dst_tp_size: int, src2dst_tp_size: int,
src2dst_tp_rank: int, src2dst_tp_rank: int,
@ -202,9 +202,9 @@ def param_size_from_keys(
and "0.wte.weight" in sd_keys and "0.wte.weight" in sd_keys
): ):
continue continue
new_shape = mp_partition_key( new_shape = tp_partition_key(
k, k,
get_real_model_param_shape(k, config, src_mp_size), get_real_model_param_shape(k, config, src_tp_size),
src2dst_tp_rank, src2dst_tp_rank,
src2dst_tp_size, src2dst_tp_size,
config, config,
@ -218,7 +218,7 @@ def build_param_spec(
layer_indices: List[int], layer_indices: List[int],
config: model_api.ReaLModelConfig, config: model_api.ReaLModelConfig,
dp_size: int, dp_size: int,
mp_size: int, tp_size: int,
pp_size: int, pp_size: int,
head_param_point_to_embedding: bool, head_param_point_to_embedding: bool,
bucket_size: int = 40000000, bucket_size: int = 40000000,
@ -273,7 +273,7 @@ def build_param_spec(
if head_param_point_to_embedding and k == f"{config.n_layers + 1}.weight": if head_param_point_to_embedding and k == f"{config.n_layers + 1}.weight":
continue continue
shape = get_real_model_param_shape(k, config, mp_size) shape = get_real_model_param_shape(k, config, tp_size)
numel = int(np.prod(shape)) numel = int(np.prod(shape))
data_end_index = data_start_index + numel data_end_index = data_start_index + numel
@ -307,14 +307,14 @@ def param_intervals_from_keys(
config: model_api.ReaLModelConfig, config: model_api.ReaLModelConfig,
head_param_point_to_embedding: bool, head_param_point_to_embedding: bool,
param_spec: Dict[str, ContiguousParamSpec], param_spec: Dict[str, ContiguousParamSpec],
mp_size: int, tp_size: int,
sd_keys: List[str], sd_keys: List[str],
portion_size: int, portion_size: int,
portion_rank: int, portion_rank: int,
) -> List[int]: ) -> List[int]:
param_size = param_size_from_keys( param_size = param_size_from_keys(
config=config, config=config,
src_mp_size=mp_size, src_tp_size=tp_size,
sd_keys=sd_keys, sd_keys=sd_keys,
src2dst_tp_size=portion_size, src2dst_tp_size=portion_size,
src2dst_tp_rank=portion_rank, src2dst_tp_rank=portion_rank,
@ -333,13 +333,13 @@ def param_intervals_from_keys(
if ( if (
model_name, model_name,
k.split(".", 1)[1], k.split(".", 1)[1],
mp_size, tp_size,
portion_rank, portion_rank,
portion_size, portion_size,
) not in _FLAT_PARAM_INDICES_CACHE: ) not in _FLAT_PARAM_INDICES_CACHE:
zero_start_intervals = mp_partition_key( zero_start_intervals = tp_partition_key(
k, k,
get_real_model_param_shape(k, config, mp_size), get_real_model_param_shape(k, config, tp_size),
portion_rank, portion_rank,
portion_size, portion_size,
config, config,
@ -349,7 +349,7 @@ def param_intervals_from_keys(
( (
model_name, model_name,
k.split(".", 1)[1], k.split(".", 1)[1],
mp_size, tp_size,
portion_rank, portion_rank,
portion_size, portion_size,
) )
@ -359,7 +359,7 @@ def param_intervals_from_keys(
( (
model_name, model_name,
k.split(".", 1)[1], k.split(".", 1)[1],
mp_size, tp_size,
portion_rank, portion_rank,
portion_size, portion_size,
) )

View File

@@ -167,7 +167,7 @@ class ReaLModel(nn.Module):
self._param_spec, self._param_size = build_param_spec(
list(range(self.layer_idx_start, self.layer_idx_end)),
self.config,
- mp_size=constants.model_parallel_world_size(),
+ tp_size=constants.tensor_parallel_world_size(),
pp_size=constants.pipe_parallel_world_size(),
dp_size=constants.data_parallel_world_size(),
head_param_point_to_embedding=self.head_param_point_to_embedding,
@@ -282,7 +282,7 @@ class ReaLModel(nn.Module):
device=device,
dtype=dtype,
)
- elif not config.is_critic and constants.model_parallel_world_size() > 1:
+ elif not config.is_critic and constants.tensor_parallel_world_size() > 1:
l = ParallelActorHead(
config.hidden_dim,
config.vocab_size,
@@ -428,14 +428,14 @@ class ReaLModel(nn.Module):
x.cu_seqlens = x.cu_seqlens.int()
# Copy input tensor to a pinned buffer.
- mp_size = constants.model_parallel_world_size()
+ tp_size = constants.tensor_parallel_world_size()
batch_length = None
if ys[0].packed_input_ids is not None:
batch_length = ys[0].packed_input_ids.shape[0]
if x.pp_input is not None:
batch_length = x.pp_input.shape[0]
assert batch_length is not None
- padded_batch_length = (batch_length + mp_size - 1) // mp_size * mp_size
+ padded_batch_length = (batch_length + tp_size - 1) // tp_size * tp_size
pad_size = padded_batch_length - batch_length
if (
@@ -609,7 +609,7 @@ class ReaLModel(nn.Module):
to_param_spec, to_param_size = build_param_spec(
to_layer_indices,
to_model_config,
- mp_size=to_topo.get_dim("model"),
+ tp_size=to_topo.get_dim("tensor"),
dp_size=to_topo.get_dim("data"),
pp_size=to_topo.get_dim("pipe"),
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
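The padded batch length above is the usual ceiling-to-a-multiple idiom: with tensor (and sequence) parallelism the packed token dimension must split evenly across ranks. A tiny illustrative helper, not part of the codebase:

```python
def pad_to_multiple(batch_length: int, tp_size: int) -> tuple[int, int]:
    # (batch_length + tp_size - 1) // tp_size rounds up, then scaling by tp_size
    # gives the smallest multiple of tp_size that fits the batch.
    padded = (batch_length + tp_size - 1) // tp_size * tp_size
    return padded, padded - batch_length

assert pad_to_multiple(10, 4) == (12, 2)  # 2 padding tokens appended
assert pad_to_multiple(12, 4) == (12, 0)  # already divisible, no padding
```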

View File

@@ -17,7 +17,7 @@ import transformers
import realhf.base.constants as constants
import realhf.base.logging as logging
- import realhf.impl.model.parallelism.model_parallel.mappings as tensor_parallel
+ import realhf.impl.model.parallelism.tensor_parallel.mappings as tensor_parallel
from realhf.api.core import model_api
from realhf.impl.model.modules import (
CausalSelfAttentionLayer,
@@ -28,9 +28,8 @@ from realhf.impl.model.modules import (
LlamaRMSNorm,
OffsetParallelPositionalEmbedding,
OffsetPositionalEmbedding,
- scatter_to_sequence_parallel_region,
)
- from realhf.impl.model.parallelism.model_parallel.modules import (
+ from realhf.impl.model.parallelism.tensor_parallel.modules import (
ColumnParallelLinear,
ParallelEmbedding,
gather_from_sequence_parallel_region,
@@ -139,6 +138,7 @@ class ReaLModelBlock(nn.Module):
use_attention_bias=config.use_attention_bias,
use_attn_proj_bias=config.use_attn_proj_bias,
do_layernorm_before=config.do_layernorm_before,
+ qk_layernorm=config.qk_layernorm,
apply_rotary=config.apply_rotary,
rotary_base=config.rotary_base,
rotary_interleaved=config.rotary_interleaved,
@@ -281,8 +281,8 @@ class VocabPositionEmbedding(nn.Module):
self.n_positions = config.n_positions
self.hidden_dim = config.hidden_dim
- model_parallel = constants.model_parallel_world_size() > 1
- if model_parallel:
+ tensor_parallel = constants.tensor_parallel_world_size() > 1
+ if tensor_parallel:
embed_cls = ParallelEmbedding
else:
embed_cls = nn.Embedding
@@ -295,7 +295,7 @@ class VocabPositionEmbedding(nn.Module):
if self.apply_abs_pos_embed:
p_embed_cls = (
OffsetParallelPositionalEmbedding
- if model_parallel
+ if tensor_parallel
else OffsetPositionalEmbedding
)
self.wpe = p_embed_cls(
@@ -416,7 +416,7 @@ class ParallelActorHead(ColumnParallelLinear):
def _forward(self, x: torch.Tensor):
weight = self.weight
if self._norm_head:
- from realhf.impl.model.parallelism.model_parallel.mappings import (
+ from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
)
@@ -431,7 +431,7 @@ class ParallelActorHead(ColumnParallelLinear):
).transpose(1, 0)
head_norm = unnormed_head.norm(dim=0, keepdim=True, p=2)
normed_head = unnormed_head / (head_norm + 1e-7)
- weight = scatter_to_sequence_parallel_region(normed_head)
+ weight = tensor_parallel.scatter_to_sequence_parallel_region(normed_head)
output = parallel_lm_logits(
x,
@@ -486,6 +486,12 @@ class ReaLModelParamKeys:
keys += [f"{idx + 1}.attn.c_proj.weight"]
if config.use_attn_proj_bias:
keys += [f"{idx + 1}.attn.c_proj.bias"]
+ if config.qk_layernorm:
+ keys += [f"{idx + 1}.attn.q_ln.weight"]
+ keys += [f"{idx + 1}.attn.k_ln.weight"]
+ if config.layer_norm_type is None:
+ keys += [f"{idx + 1}.attn.q_ln.bias"]
+ keys += [f"{idx + 1}.attn.k_ln.bias"]
keys += [f"{idx + 1}.mlp.ln.weight"]
if config.layer_norm_type is None:
keys += [f"{idx + 1}.mlp.ln.bias"]

View File

@@ -55,8 +55,8 @@ def genstep(
unfinished_sequences: Bool tensor indicator of whether a sequence is finished.
Shape [bs].
"""
- if constants.model_parallel_world_size() > 1:
- from realhf.impl.model.parallelism.model_parallel.mappings import (
+ if constants.tensor_parallel_world_size() > 1:
+ from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
)
@@ -95,20 +95,20 @@ def genstep(
next_tokens = distrb.mode if gconfig.greedy else distrb.sample()
logprob = distrb.log_prob(next_tokens)
- if constants.model_parallel_world_size() > 1:
- if constants.model_parallel_rank() > 0:
+ if constants.tensor_parallel_world_size() > 1:
+ if constants.tensor_parallel_rank() > 0:
logprob[:] = 0
next_tokens[:] = 0
handle = torch.distributed.all_reduce(
logprob,
torch.distributed.ReduceOp.SUM,
async_op=True,
- group=constants.model_parallel_group(),
+ group=constants.tensor_parallel_group(),
)
torch.distributed.all_reduce(
next_tokens,
torch.distributed.ReduceOp.SUM,
- group=constants.model_parallel_group(),
+ group=constants.tensor_parallel_group(),
)
if tokenizer.eos_token_id is not None:
@@ -139,7 +139,7 @@ def genstep(
if not logits_mask.any():
logits_mask = None
- if constants.model_parallel_world_size() > 1:
+ if constants.tensor_parallel_world_size() > 1:
handle.wait()
return next_tokens, logprob, logits_mask, terminate, unfinished_sequences
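The zero-then-sum pattern above is how all tensor-parallel ranks end up with identical samples without an explicit broadcast. A local simulation of the idea, with no process group and made-up sizes:

```python
import torch

# Every rank samples, non-zero ranks zero out their result, and a SUM all-reduce
# leaves every rank holding rank 0's tokens.
world_size, bs = 4, 3
per_rank_tokens = [torch.randint(0, 100, (bs,)) for _ in range(world_size)]

contributions = []
for rank, tokens in enumerate(per_rank_tokens):
    t = tokens.clone()
    if rank > 0:  # mirrors `if constants.tensor_parallel_rank() > 0`
        t[:] = 0
    contributions.append(t)

# ReduceOp.SUM across ranks is equivalent to summing the contributions.
agreed = torch.stack(contributions).sum(0)
assert torch.equal(agreed, per_rank_tokens[0])  # everyone keeps rank 0's sample
```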

View File

@ -42,27 +42,27 @@ if constants.use_te_impl():
def tensor_slice_partition_fn( def tensor_slice_partition_fn(
tensor: torch.Tensor, tensor: torch.Tensor,
mp_rank: Optional[int], tp_rank: Optional[int],
mp_world_size: int, tp_world_size: int,
dim: Optional[int], dim: Optional[int],
) -> Union[List[torch.Tensor], torch.Tensor]: ) -> Union[List[torch.Tensor], torch.Tensor]:
"""Partition a tensor by slicing along a dimension for tensor-model """Partition a tensor by slicing along a dimension for tensor-model
parallelism.""" parallelism."""
if dim is None: if dim is None:
splits = [tensor for _ in range(mp_world_size)] splits = [tensor for _ in range(tp_world_size)]
else: else:
assert tensor.shape[dim] % mp_world_size == 0 assert tensor.shape[dim] % tp_world_size == 0
splits = torch.split(tensor, tensor.shape[dim] // mp_world_size, dim=dim) splits = torch.split(tensor, tensor.shape[dim] // tp_world_size, dim=dim)
if mp_rank is None: if tp_rank is None:
return [s.contiguous() for s in splits] return [s.contiguous() for s in splits]
else: else:
return splits[mp_rank].contiguous() return splits[tp_rank].contiguous()
def intervals_partition_fn( def intervals_partition_fn(
shape: torch.Size, shape: torch.Size,
mp_rank: Optional[int], tp_rank: Optional[int],
mp_world_size: int, tp_world_size: int,
dim: Optional[int], dim: Optional[int],
) -> Union[List[torch.Tensor], torch.Tensor]: ) -> Union[List[torch.Tensor], torch.Tensor]:
"""Get the intervals of a MP-partitioned tensor in the flatten view. """Get the intervals of a MP-partitioned tensor in the flatten view.
@ -74,34 +74,34 @@ def intervals_partition_fn(
Used by parameter reallocation. Return a numpy array of shape [N, Used by parameter reallocation. Return a numpy array of shape [N,
2], where N is the number of intervals. 2], where N is the number of intervals.
""" """
assert mp_rank is not None assert tp_rank is not None
param_size = int(np.prod(shape)) param_size = int(np.prod(shape))
if dim is None: if dim is None:
return np.array([(0, param_size)], dtype=np.int64) return np.array([(0, param_size)], dtype=np.int64)
if dim < 0: if dim < 0:
dim = len(shape) + dim dim = len(shape) + dim
assert shape[dim] % mp_world_size == 0 assert shape[dim] % tp_world_size == 0
if len(shape) == 1: if len(shape) == 1:
assert dim == 0 assert dim == 0
partition_size = shape[0] // mp_world_size partition_size = shape[0] // tp_world_size
return np.array( return np.array(
[(partition_size * mp_rank, partition_size * (mp_rank + 1))], [(partition_size * tp_rank, partition_size * (tp_rank + 1))],
dtype=np.int64, dtype=np.int64,
) )
else: else:
assert len(shape) == 2, shape assert len(shape) == 2, shape
if dim == 0: if dim == 0:
row_start = mp_rank * shape[0] // mp_world_size row_start = tp_rank * shape[0] // tp_world_size
row_end = (mp_rank + 1) * shape[0] // mp_world_size row_end = (tp_rank + 1) * shape[0] // tp_world_size
return np.array( return np.array(
[(row_start * shape[1], row_end * shape[1])], dtype=np.int64 [(row_start * shape[1], row_end * shape[1])], dtype=np.int64
) )
else: else:
assert dim == 1 assert dim == 1
col_start = mp_rank * shape[1] // mp_world_size col_start = tp_rank * shape[1] // tp_world_size
col_end = (mp_rank + 1) * shape[1] // mp_world_size col_end = (tp_rank + 1) * shape[1] // tp_world_size
return np.arange(shape[0], dtype=np.int64)[:, None] * shape[1] + np.array( return np.arange(shape[0], dtype=np.int64)[:, None] * shape[1] + np.array(
[(col_start, col_end)], dtype=np.int64 [(col_start, col_end)], dtype=np.int64
) )
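Since `intervals_partition_fn` is easy to misread in the flattened diff, here is an illustrative recomputation of what it returns for a 2D weight in the row-major flattened view; the shape, rank, and world size are made-up example values:

```python
import numpy as np

shape, tp_rank, tp_world_size = (4, 6), 1, 2

# dim=0: the rank's rows form one contiguous block in the flat view.
rows = shape[0] // tp_world_size
dim0 = np.array([(tp_rank * rows * shape[1], (tp_rank + 1) * rows * shape[1])])
print(dim0)  # [[12 24]]

# dim=1: the rank's columns give one interval per row.
cols = shape[1] // tp_world_size
col_start, col_end = tp_rank * cols, (tp_rank + 1) * cols
dim1 = np.arange(shape[0])[:, None] * shape[1] + np.array([(col_start, col_end)])
print(dim1)  # [[ 3  6] [ 9 12] [15 18] [21 24]]
```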
@ -109,32 +109,32 @@ def intervals_partition_fn(
def shape_partition_fn( def shape_partition_fn(
shape: torch.Size, shape: torch.Size,
mp_rank: Optional[int], tp_rank: Optional[int],
mp_world_size: int, tp_world_size: int,
dim: Optional[int], dim: Optional[int],
): ):
"""Get the partitioned shape of a tensor for tensor-model parallelism.""" """Get the partitioned shape of a tensor for tensor-model parallelism."""
if dim is None: if dim is None:
splits = [shape for _ in range(mp_world_size)] splits = [shape for _ in range(tp_world_size)]
else: else:
if dim < 0: if dim < 0:
dim = len(shape) + dim dim = len(shape) + dim
assert shape[dim] % mp_world_size == 0 assert shape[dim] % tp_world_size == 0
splits = [ splits = [
(*shape[:dim], shape[dim] // mp_world_size, *shape[dim + 1 :]) (*shape[:dim], shape[dim] // tp_world_size, *shape[dim + 1 :])
for _ in range(mp_world_size) for _ in range(tp_world_size)
] ]
if mp_rank is None: if tp_rank is None:
return [s for s in splits] return [s for s in splits]
else: else:
return splits[mp_rank] return splits[tp_rank]
def mp_partition_key( def tp_partition_key(
key: str, key: str,
tensor_or_shape: torch.Tensor | torch.Size, tensor_or_shape: torch.Tensor | torch.Size,
mp_rank: Optional[int], tp_rank: Optional[int],
mp_size: Optional[int], tp_size: Optional[int],
config: model_api.ReaLModelConfig, config: model_api.ReaLModelConfig,
partition_fn: Callable[ partition_fn: Callable[
[torch.Tensor, Optional[int], int, Optional[int]], [torch.Tensor, Optional[int], int, Optional[int]],
@ -149,7 +149,7 @@ def mp_partition_key(
if any([ek in key for ek in EMBEDDING_KEYS]): if any([ek in key for ek in EMBEDDING_KEYS]):
assert "weight" in key assert "weight" in key
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
elif key == f"{config.n_layers + 1}.weight": # output head elif key == f"{config.n_layers + 1}.weight": # output head
if ( if (
isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape.shape[0] == 1 isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape.shape[0] == 1
@ -157,88 +157,90 @@ def mp_partition_key(
not isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape[0] == 1 not isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape[0] == 1
): ):
assert config.is_critic assert config.is_critic
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
else: else:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
elif any([ck in key for ck in COLUMN_LINEAR_KEYS]): elif any([ck in key for ck in COLUMN_LINEAR_KEYS]):
if ( if (
("k_attn" in key) or ("v_attn" in key) ("k_attn" in key) or ("v_attn" in key)
) and config.n_kv_heads % mp_size != 0: ) and config.n_kv_heads % tp_size != 0:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
# partition both weight and bias # partition both weight and bias
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
elif any([rk in key for rk in ROW_LINEAR_KEYS]): elif any([rk in key for rk in ROW_LINEAR_KEYS]):
# only partition weight # only partition weight
if "weight" in key: if "weight" in key:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=1) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=1)
else: else:
assert "bias" in key, key assert "bias" in key, key
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
else: else:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None) return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
- def mp_partition_real_model_state_dict(
+ def tp_partition_real_model_state_dict(
state_dict: Dict[str, torch.Tensor],
config: model_api.ReaLModelConfig,
- mp_size: int,
- mp_rank: Optional[int] = None,
+ tp_size: int,
+ tp_rank: Optional[int] = None,
) -> Union[Dict, List[Dict]]:
- """A helper function to partition a state dict using `mp_partition_key`."""
- if mp_size == 1:
- if mp_rank is None:
+ """A helper function to partition a state dict using `tp_partition_key`."""
+ if tp_size == 1:
+ if tp_rank is None:
return [state_dict]
else:
return state_dict
new_state_dict = {}
for k, v in state_dict.items():
- new_state_dict[k] = mp_partition_key(k, v, mp_rank, mp_size, config)
- if mp_rank is None:
+ new_state_dict[k] = tp_partition_key(k, v, tp_rank, tp_size, config)
+ if tp_rank is None:
return [
- {k: v[mp_rank] for k, v in new_state_dict.items()}
- for mp_rank in range(mp_size)
+ {k: v[tp_rank] for k, v in new_state_dict.items()}
+ for tp_rank in range(tp_size)
]
else:
return new_state_dict
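For orientation only, a toy sketch of tensor-parallel state-dict partitioning; the key names and the routing rules below are hypothetical and much simpler than the repository's `tp_partition_key` logic:

```python
import torch

def split_for_rank(name: str, tensor: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    if "q_attn" in name or "gate_proj" in name:   # column-parallel style: split dim 0
        return torch.chunk(tensor, tp_size, dim=0)[tp_rank].contiguous()
    if "c_proj" in name or "down_proj" in name:   # row-parallel style: split dim 1
        return torch.chunk(tensor, tp_size, dim=1)[tp_rank].contiguous()
    return tensor                                  # layer norms etc. are replicated

state_dict = {
    "0.attn.q_attn.weight": torch.randn(8, 4),
    "0.attn.c_proj.weight": torch.randn(4, 8),
    "0.attn.ln.weight": torch.randn(4),
}
shard = {k: split_for_rank(k, v, tp_rank=1, tp_size=2) for k, v in state_dict.items()}
print({k: tuple(v.shape) for k, v in shard.items()})
# {'0.attn.q_attn.weight': (4, 4), '0.attn.c_proj.weight': (4, 4), '0.attn.ln.weight': (4,)}
```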
def get_real_model_param_shape(
- k: str, config: model_api.ReaLModelConfig, mp_size: int
+ k: str, config: model_api.ReaLModelConfig, tp_size: int
) -> Tuple:
if "wte.weight" in k:
- assert config.vocab_size % mp_size == 0
- return (config.vocab_size // mp_size, config.hidden_dim)
+ assert config.vocab_size % tp_size == 0
+ return (config.vocab_size // tp_size, config.hidden_dim)
elif "wpe.weight" in k:
- assert config.n_positions % mp_size == 0
- if (config.n_positions + config.abs_position_embedding_offset) % mp_size != 0:
+ assert config.n_positions % tp_size == 0
+ if (config.n_positions + config.abs_position_embedding_offset) % tp_size != 0:
raise ValueError(
f"The dimenstion of position embedding "
f"({config.n_positions} + offset {config.abs_position_embedding_offset}) "
- f"is not divisible by mp_size ({mp_size}). "
+ f"is not divisible by tp_size ({tp_size}). "
"Models like this (e.g. OPT-350m) inherently do not support tensor parallelism."
)
return (
- (config.n_positions + config.abs_position_embedding_offset) // mp_size,
+ (config.n_positions + config.abs_position_embedding_offset) // tp_size,
config.hidden_dim,
)
elif ".ln." in k or ".ln_f." in k:
return (config.hidden_dim,)
+ elif ".q_ln." in k or ".k_ln." in k:
+ return (config.head_dim,)
elif k == f"{config.n_layers + 1}.weight": # output head elif k == f"{config.n_layers + 1}.weight": # output head
if config.is_critic: if config.is_critic:
return (1, config.hidden_dim) return (1, config.hidden_dim)
elif mp_size > 1: elif tp_size > 1:
assert config.vocab_size % mp_size == 0 assert config.vocab_size % tp_size == 0
return (config.vocab_size // mp_size, config.hidden_dim) return (config.vocab_size // tp_size, config.hidden_dim)
else: else:
return (config.vocab_size, config.hidden_dim) return (config.vocab_size, config.hidden_dim)
elif any([ck in k for ck in COLUMN_LINEAR_KEYS]): elif any([ck in k for ck in COLUMN_LINEAR_KEYS]):
if "k_attn" in k or "v_attn" in k: if "k_attn" in k or "v_attn" in k:
if "weight" in k: if "weight" in k:
if config.n_kv_heads % mp_size == 0: if config.n_kv_heads % tp_size == 0:
return ( return (
config.head_dim * config.n_kv_heads // mp_size, config.head_dim * config.n_kv_heads // tp_size,
config.hidden_dim, config.hidden_dim,
) )
else: else:
@ -248,27 +250,27 @@ def get_real_model_param_shape(
) )
else: else:
assert "bias" in k assert "bias" in k
if config.n_kv_heads % mp_size == 0: if config.n_kv_heads % tp_size == 0:
return (config.head_dim * config.n_kv_heads // mp_size,) return (config.head_dim * config.n_kv_heads // tp_size,)
else: else:
return (config.head_dim * config.n_kv_heads,) return (config.head_dim * config.n_kv_heads,)
if "mlp" in k: if "mlp" in k:
if "weight" in k: if "weight" in k:
return (config.intermediate_dim // mp_size, config.hidden_dim) return (config.intermediate_dim // tp_size, config.hidden_dim)
else: else:
assert "bias" in k assert "bias" in k
return (config.intermediate_dim // mp_size,) return (config.intermediate_dim // tp_size,)
if "weight" in k: if "weight" in k:
assert config.n_q_heads % mp_size == 0 assert config.n_q_heads % tp_size == 0
return (config.n_q_heads * config.head_dim // mp_size, config.hidden_dim) return (config.n_q_heads * config.head_dim // tp_size, config.hidden_dim)
else: else:
assert "bias" in k assert "bias" in k
return (config.n_q_heads * config.head_dim // mp_size,) return (config.n_q_heads * config.head_dim // tp_size,)
elif any([rk in k for rk in ROW_LINEAR_KEYS]): elif any([rk in k for rk in ROW_LINEAR_KEYS]):
if "mlp" in k and "weight" in k: if "mlp" in k and "weight" in k:
return (config.hidden_dim, config.intermediate_dim // mp_size) return (config.hidden_dim, config.intermediate_dim // tp_size)
elif "attn" in k and "weight" in k: elif "attn" in k and "weight" in k:
return (config.hidden_dim, config.n_q_heads * config.head_dim // mp_size) return (config.hidden_dim, config.n_q_heads * config.head_dim // tp_size)
elif "bias" in k: elif "bias" in k:
return (config.hidden_dim,) return (config.hidden_dim,)
else: else:
@ -280,7 +282,7 @@ def get_real_model_param_shape(
raise NotImplementedError(f"unkown shape of key {k}.") raise NotImplementedError(f"unkown shape of key {k}.")
- def mp_merge_key(
+ def tp_merge_key(
k: str,
tensors: List[torch.Tensor],
config: model_api.ReaLModelConfig,
@@ -297,17 +299,17 @@ def mp_merge_key(
return tensors[0]
- def mp_merge_real_model_state_dict(
+ def tp_merge_real_model_state_dict(
state_dicts: List[Dict[str, torch.Tensor]],
config: model_api.ReaLModelConfig,
) -> Dict:
- mp_size = len(state_dicts)
- if mp_size == 1:
+ tp_size = len(state_dicts)
+ if tp_size == 1:
return state_dicts[0]
new_state_dict = {}
for k in state_dicts[0].keys():
- new_state_dict[k] = mp_merge_key(k, [sd[k] for sd in state_dicts], config)
+ new_state_dict[k] = tp_merge_key(k, [sd[k] for sd in state_dicts], config)
return new_state_dict
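Merging is the inverse of partitioning. A round-trip sketch, purely illustrative and independent of the repository's key handling:

```python
import torch

# Split a column-parallel weight across ranks, then merge the shards back by
# concatenating along the partitioned dimension.
full = torch.randn(8, 4)
tp_size = 2
shards = [s.contiguous() for s in torch.chunk(full, tp_size, dim=0)]

merged = torch.cat(shards, dim=0)  # column-parallel keys concatenate on dim 0
assert torch.equal(merged, full)

# Row-parallel weights would instead be concatenated along dim 1, and replicated
# tensors (e.g. layer-norm weights) are merged by simply taking shards[0].
```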
@ -317,37 +319,37 @@ class ReaLModelParamCount:
@staticmethod @staticmethod
def _derive_count_from_keys( def _derive_count_from_keys(
keys: List[str], config: model_api.ReaLModelConfig, mp_size: int keys: List[str], config: model_api.ReaLModelConfig, tp_size: int
) -> int: ) -> int:
count = 0 count = 0
for k in keys: for k in keys:
count += np.prod(get_real_model_param_shape(k, config, mp_size)) count += np.prod(get_real_model_param_shape(k, config, tp_size))
return int(count) return int(count)
@staticmethod @staticmethod
def embed(config: model_api.ReaLModelConfig, mp_size: int) -> int: def embed(config: model_api.ReaLModelConfig, tp_size: int) -> int:
return ReaLModelParamCount._derive_count_from_keys( return ReaLModelParamCount._derive_count_from_keys(
ReaLModelParamKeys.embed(config), config, mp_size ReaLModelParamKeys.embed(config), config, tp_size
) )
@staticmethod @staticmethod
def tblock(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int: def tblock(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int:
return ReaLModelParamCount._derive_count_from_keys( return ReaLModelParamCount._derive_count_from_keys(
ReaLModelParamKeys.tblock(config, idx), config, mp_size ReaLModelParamKeys.tblock(config, idx), config, tp_size
) )
@staticmethod @staticmethod
def head(config: model_api.ReaLModelConfig, mp_size: int) -> int: def head(config: model_api.ReaLModelConfig, tp_size: int) -> int:
return ReaLModelParamCount._derive_count_from_keys( return ReaLModelParamCount._derive_count_from_keys(
ReaLModelParamKeys.head(config), config, mp_size ReaLModelParamKeys.head(config), config, tp_size
) )
@staticmethod @staticmethod
def total(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int: def total(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int:
return ( return (
config.n_layers * ReaLModelParamCount.tblock(config, idx, mp_size) config.n_layers * ReaLModelParamCount.tblock(config, idx, tp_size)
+ ReaLModelParamCount.head(config, mp_size) + ReaLModelParamCount.head(config, tp_size)
+ ReaLModelParamCount.embed(config, mp_size) + ReaLModelParamCount.embed(config, tp_size)
) )
@ -356,7 +358,7 @@ def partition_pipeline_layers(
num_stages: int, num_stages: int,
method: str = "parameters", method: str = "parameters",
) -> Dict[int, Tuple[int, int]]: ) -> Dict[int, Tuple[int, int]]:
# Ignoring mp_size in param count because tensor parallel equally partitions parameters. # Ignoring tp_size in param count because tensor parallel equally partitions parameters.
# It is irrelevant to how we partition pipeline stages. # It is irrelevant to how we partition pipeline stages.
param_counts = ( param_counts = (
[ReaLModelParamCount.embed(config, 1)] [ReaLModelParamCount.embed(config, 1)]

View File

@ -13,11 +13,11 @@ def _reduce(input_):
"""All-reduce the input tensor across model parallel group.""" """All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if constants.model_parallel_world_size() == 1: if constants.tensor_parallel_world_size() == 1:
return input_ return input_
# All-reduce. # All-reduce.
torch.distributed.all_reduce(input_, group=constants.model_parallel_group()) torch.distributed.all_reduce(input_, group=constants.tensor_parallel_group())
return input_ return input_
@ -25,7 +25,7 @@ def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the corresponding """Split the tensor along its last dimension and keep the corresponding
slice.""" slice."""
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size == 1: if world_size == 1:
return input_ return input_
@ -34,7 +34,7 @@ def _split_along_last_dim(input_):
input_list = split_tensor_along_last_dim(input_, world_size) input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default. # Note: torch.split does not create contiguous tensors by default.
rank = constants.model_parallel_rank() rank = constants.tensor_parallel_rank()
output = input_list[rank].contiguous() output = input_list[rank].contiguous()
return output return output
@ -44,7 +44,7 @@ def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the corresponding """Split the tensor along its first dimension and keep the corresponding
slice.""" slice."""
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size == 1: if world_size == 1:
return input_ return input_
@ -55,7 +55,7 @@ def _split_along_first_dim(input_):
dim_size % world_size == 0 dim_size % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size" ), "First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size local_dim_size = dim_size // world_size
rank = constants.model_parallel_rank() rank = constants.tensor_parallel_rank()
dim_offset = rank * local_dim_size dim_offset = rank * local_dim_size
output = input_[dim_offset : dim_offset + local_dim_size].contiguous() output = input_[dim_offset : dim_offset + local_dim_size].contiguous()
@ -66,19 +66,19 @@ def _split_along_first_dim(input_):
def _gather_along_last_dim(input_): def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension.""" """Gather tensors and concatinate along the last dimension."""
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size == 1: if world_size == 1:
return input_ return input_
# Size and dimension. # Size and dimension.
last_dim = input_.dim() - 1 last_dim = input_.dim() - 1
rank = constants.model_parallel_rank() rank = constants.tensor_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_ tensor_list[rank] = input_
torch.distributed.all_gather( torch.distributed.all_gather(
tensor_list, input_, group=constants.model_parallel_group() tensor_list, input_, group=constants.tensor_parallel_group()
) )
# Note: torch.cat already creates a contiguous tensor. # Note: torch.cat already creates a contiguous tensor.
@ -90,7 +90,7 @@ def _gather_along_last_dim(input_):
def _gather_along_first_dim(input_): def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension.""" """Gather tensors and concatinate along the first dimension."""
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size == 1: if world_size == 1:
return input_ return input_
@ -102,7 +102,7 @@ def _gather_along_first_dim(input_):
dim_size, dtype=input_.dtype, device=constants.current_device() dim_size, dtype=input_.dtype, device=constants.current_device()
) )
torch.distributed._all_gather_base( torch.distributed._all_gather_base(
output, input_.contiguous(), group=constants.model_parallel_group() output, input_.contiguous(), group=constants.tensor_parallel_group()
) )
return output return output
@ -110,7 +110,7 @@ def _gather_along_first_dim(input_):
def _reduce_scatter_along_first_dim(input_): def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group.""" """Reduce-scatter the input tensor across model parallel group."""
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU. # Bypass the function if we are using only 1 GPU.
if world_size == 1: if world_size == 1:
return input_ return input_
@ -128,7 +128,7 @@ def _reduce_scatter_along_first_dim(input_):
dim_size, dtype=input_.dtype, device=constants.current_device() dim_size, dtype=input_.dtype, device=constants.current_device()
) )
torch.distributed._reduce_scatter_base( torch.distributed._reduce_scatter_base(
output, input_.contiguous(), group=constants.model_parallel_group() output, input_.contiguous(), group=constants.tensor_parallel_group()
) )
return output return output
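The scatter/gather mappings above can be understood without a process group. A local illustration of the split-along-last-dim and gather-along-last-dim semantics, with made-up sizes:

```python
import torch

world_size, rank = 4, 2
x = torch.randn(3, 8)

slices = list(torch.chunk(x, world_size, dim=-1))  # split_tensor_along_last_dim
local = slices[rank].contiguous()                  # what this rank keeps
print(local.shape)                                 # torch.Size([3, 2])

regathered = torch.cat(slices, dim=-1)             # what the all-gather reconstructs
assert torch.equal(regathered, x)
```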

View File

@ -44,7 +44,7 @@ except ImportError:
import realhf.base.logging as logging import realhf.base.logging as logging
logger = logging.getLogger("model_parallel.modules") logger = logging.getLogger("tensor_parallel.modules")
def get_activation_fn(activation_function: str) -> Callable: def get_activation_fn(activation_function: str) -> Callable:
@ -95,12 +95,12 @@ class ParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq = False self.scale_grad_by_freq = False
self.sparse = False self.sparse = False
self._weight = None self._weight = None
self.tensor_model_parallel_size = constants.model_parallel_world_size() self.tensor_model_parallel_size = constants.tensor_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension. # Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = ( self.vocab_start_index, self.vocab_end_index = (
VocabUtility.vocab_range_from_global_vocab_size( VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, self.num_embeddings,
constants.model_parallel_rank(), constants.tensor_parallel_rank(),
self.tensor_model_parallel_size, self.tensor_model_parallel_size,
) )
) )
@ -110,7 +110,7 @@ class ParallelEmbedding(torch.nn.Module):
logger.debug( logger.debug(
f"ParallelEmbedding: num_embeddings={num_embeddings}, per_partition={self.num_embeddings_per_partition}, embedding_dim={embedding_dim}," f"ParallelEmbedding: num_embeddings={num_embeddings}, per_partition={self.num_embeddings_per_partition}, embedding_dim={embedding_dim},"
f"tp_rank={constants.model_parallel_rank()},tp_world_size={constants.model_parallel_world_size()}" f"tp_rank={constants.tensor_parallel_rank()},tp_world_size={constants.tensor_parallel_world_size()}"
) )
# Allocate weights and initialize. # Allocate weights and initialize.
self.weight = Parameter( self.weight = Parameter(
@ -264,7 +264,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
assert ( assert (
not ctx.async_grad_allreduce not ctx.async_grad_allreduce
), "async_grad_allreduce and sequence_parallel can not be both True" ), "async_grad_allreduce and sequence_parallel can not be both True"
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size dim_size[0] = dim_size[0] * world_size
@ -272,7 +272,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size, input.dtype, "mpu" dim_size, input.dtype, "mpu"
) )
torch.distributed._all_gather_base( torch.distributed._all_gather_base(
all_gather_buffer, input, group=constants.model_parallel_group() all_gather_buffer, input, group=constants.tensor_parallel_group()
) )
total_input = all_gather_buffer total_input = all_gather_buffer
else: else:
@ -290,7 +290,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
use_bias = ctx.use_bias use_bias = ctx.use_bias
if ctx.sequence_parallel: if ctx.sequence_parallel:
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size dim_size[0] = dim_size[0] * world_size
@ -300,7 +300,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
handle = torch.distributed._all_gather_base( handle = torch.distributed._all_gather_base(
all_gather_buffer, all_gather_buffer,
input, input,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
async_op=True, async_op=True,
) )
@ -327,7 +327,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
# Asynchronous all-reduce # Asynchronous all-reduce
handle = torch.distributed.all_reduce( handle = torch.distributed.all_reduce(
grad_input, grad_input,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
async_op=True, async_op=True,
) )
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -346,7 +346,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
handle = torch.distributed._reduce_scatter_base( handle = torch.distributed._reduce_scatter_base(
sub_grad_input, sub_grad_input,
grad_input, grad_input,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
async_op=True, async_op=True,
) )
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -525,7 +525,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
assert ( assert (
not ctx.async_grad_allreduce not ctx.async_grad_allreduce
), "async_grad_allreduce and sequence_parallel can not be both True" ), "async_grad_allreduce and sequence_parallel can not be both True"
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size dim_size[0] = dim_size[0] * world_size
@ -533,7 +533,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
dim_size, input.dtype, "mpu" dim_size, input.dtype, "mpu"
) )
torch.distributed._all_gather_base( torch.distributed._all_gather_base(
all_gather_buffer, input, group=constants.model_parallel_group() all_gather_buffer, input, group=constants.tensor_parallel_group()
) )
total_input = all_gather_buffer total_input = all_gather_buffer
else: else:
@ -557,7 +557,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
is_w_parallel = ctx.is_w_parallel is_w_parallel = ctx.is_w_parallel
if ctx.sequence_parallel: if ctx.sequence_parallel:
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size()) dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size dim_size[0] = dim_size[0] * world_size
@ -567,7 +567,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
handle = torch.distributed._all_gather_base( handle = torch.distributed._all_gather_base(
all_gather_buffer, all_gather_buffer,
input, input,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
async_op=True, async_op=True,
) )
@ -578,7 +578,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
total_input = input total_input = input
grad_input = 0 grad_input = 0
for w, is_parallel, grad in zip(weights, is_w_parallel, grads): for w, is_parallel, grad in zip(weights, is_w_parallel, grads):
if is_parallel or constants.model_parallel_rank() == 0: if is_parallel or constants.tensor_parallel_rank() == 0:
grad_input = grad_input + grad.matmul(w) grad_input = grad_input + grad.matmul(w)
if ctx.sequence_parallel: if ctx.sequence_parallel:
@ -597,7 +597,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
# Asynchronous all-reduce # Asynchronous all-reduce
handle = torch.distributed.all_reduce( handle = torch.distributed.all_reduce(
grad_input, grad_input,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
async_op=True, async_op=True,
) )
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -616,7 +616,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
handle = torch.distributed._reduce_scatter_base( handle = torch.distributed._reduce_scatter_base(
sub_grad_input, sub_grad_input,
grad_input, grad_input,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
async_op=True, async_op=True,
) )
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -785,7 +785,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size self.output_size = output_size
self.gather_output = gather_output self.gather_output = gather_output
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size) self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.is_expert = is_expert self.is_expert = is_expert
@ -852,7 +852,7 @@ class ColumnParallelLinear(torch.nn.Module):
# in expert MLPs always behave as sequence parallel is not enabled. # in expert MLPs always behave as sequence parallel is not enabled.
sequence_parallel = constants.sequence_parallel() and not self.is_expert sequence_parallel = constants.sequence_parallel() and not self.is_expert
async_tensor_model_parallel_allreduce = ( async_tensor_model_parallel_allreduce = (
constants.model_parallel_world_size() > 1 and not sequence_parallel constants.tensor_parallel_world_size() > 1 and not sequence_parallel
) )
if sequence_parallel: if sequence_parallel:
@ -942,7 +942,7 @@ class RowParallelLinear(torch.nn.Module):
self.output_size = output_size self.output_size = output_size
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size) self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.gradient_accumulation_fusion = gradient_accumulation_fusion self.gradient_accumulation_fusion = gradient_accumulation_fusion
@ -1030,9 +1030,9 @@ def parallel_lm_logits(
bias=None, bias=None,
): ):
"""LM logits using word embedding weights.""" """LM logits using word embedding weights."""
model_parallel = constants.model_parallel_world_size() > 1 tensor_parallel = constants.tensor_parallel_world_size() > 1
sequence_parallel = constants.sequence_parallel() sequence_parallel = constants.sequence_parallel()
async_grad_allreduce = not sequence_parallel and model_parallel async_grad_allreduce = not sequence_parallel and tensor_parallel
# Parallel logits. # Parallel logits.
if sequence_parallel: if sequence_parallel:
input_parallel = input_ input_parallel = input_
@ -1066,7 +1066,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
torch.distributed.all_reduce( torch.distributed.all_reduce(
logits_max, logits_max,
op=torch.distributed.ReduceOp.MAX, op=torch.distributed.ReduceOp.MAX,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
) )
# Subtract the maximum value. # Subtract the maximum value.
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1) vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
@ -1074,8 +1074,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Get the partition's vocab indecies # Get the partition's vocab indecies
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1] partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = constants.model_parallel_rank() rank = constants.tensor_parallel_rank()
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range( vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, rank, world_size partition_vocab_size, rank, world_size
) )
@ -1101,7 +1101,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
torch.distributed.all_reduce( torch.distributed.all_reduce(
predicted_logits, predicted_logits,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
) )
# Sum of exponential of logits along vocab dimension across all GPUs. # Sum of exponential of logits along vocab dimension across all GPUs.
@ -1111,7 +1111,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
torch.distributed.all_reduce( torch.distributed.all_reduce(
sum_exp_logits, sum_exp_logits,
op=torch.distributed.ReduceOp.SUM, op=torch.distributed.ReduceOp.SUM,
group=constants.model_parallel_group(), group=constants.tensor_parallel_group(),
) )
# Loss = log(sum(exp(logits))) - predicted-logit. # Loss = log(sum(exp(logits))) - predicted-logit.

View File

@@ -6,11 +6,7 @@ from typing import List, Sequence
import numpy as np
import torch
- from realhf.base.constants import (
- model_parallel_group,
- model_parallel_rank,
- model_parallel_world_size,
- )
+ import realhf.base.constants as constants
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"tensor_model_parallel": False,
@@ -22,7 +18,7 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
def param_is_not_model_parallel_duplicate(param):
return (
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
- ) or (model_parallel_rank() == 0)
+ ) or (constants.tensor_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
@@ -110,8 +106,8 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
If False, returns a view into the existing Tensor.
Default is False
"""
- partition_size = torch.numel(tensor) // model_parallel_world_size()
- start_index = partition_size * model_parallel_rank()
+ partition_size = torch.numel(tensor) // constants.tensor_parallel_world_size()
+ start_index = partition_size * constants.tensor_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(
@@ -135,7 +131,7 @@ def gather_split_1d_tensor(tensor):
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
- numel_gathered = torch.numel(tensor) * model_parallel_world_size()
+ numel_gathered = torch.numel(tensor) * constants.tensor_parallel_world_size()
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
@@ -147,7 +143,9 @@ def gather_split_1d_tensor(tensor):
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
- torch.distributed._all_gather_base(gathered, tensor, group=model_parallel_group())
+ torch.distributed._all_gather_base(
+ gathered, tensor, group=constants.tensor_parallel_group()
+ )
return gathered
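A local sketch of the split/gather round trip above, simulated without a process group; the world size and tensor contents are arbitrary:

```python
import torch

# Each rank keeps a contiguous 1/world_size slice of the flattened tensor, and an
# all-gather concatenates the slices back together.
world_size = 4
flat = torch.arange(16.0)
partition_size = flat.numel() // world_size

chunks = [flat[r * partition_size:(r + 1) * partition_size] for r in range(world_size)]
gathered = torch.cat(chunks)  # what _all_gather_base reconstructs
assert torch.equal(gathered, flat)
```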

View File

@ -210,11 +210,11 @@ def gather_packed_shifted_log_probs(
""" """
labels = torch.nn.functional.pad(labels[1:], (0, 1), value=0) labels = torch.nn.functional.pad(labels[1:], (0, 1), value=0)
leave_one_indices = build_leave_one_indices(logits, cu_seqlens) leave_one_indices = build_leave_one_indices(logits, cu_seqlens)
if constants.model_parallel_world_size() > 1: if constants.tensor_parallel_world_size() > 1:
# NOTE: logprobs is extremely sensitive to input_ids. If the input sequence is a natural sequence, everything will be fine. # NOTE: logprobs is extremely sensitive to input_ids. If the input sequence is a natural sequence, everything will be fine.
# However, if we input random token IDs, parallel cross entropy can produce VERY different results than the normal # However, if we input random token IDs, parallel cross entropy can produce VERY different results than the normal
# torch.gather based version (e.g., the maximum absolute difference can reach ~50). # torch.gather based version (e.g., the maximum absolute difference can reach ~50).
from realhf.impl.model.parallelism.model_parallel.modules import ( from realhf.impl.model.parallelism.tensor_parallel.modules import (
vocab_parallel_cross_entropy, vocab_parallel_cross_entropy,
) )
@ -239,14 +239,16 @@ def gather_packed_shifted_log_probs(
def apply_logits_mask(logits: torch.HalfTensor, mask: torch.BoolTensor): def apply_logits_mask(logits: torch.HalfTensor, mask: torch.BoolTensor):
assert mask.shape[-1] == logits.shape[-1] * constants.model_parallel_world_size(), ( assert (
constants.model_parallel_world_size(), mask.shape[-1] == logits.shape[-1] * constants.tensor_parallel_world_size()
), (
constants.tensor_parallel_world_size(),
logits.shape, logits.shape,
mask.shape, mask.shape,
) )
parallel_vocab_size = logits.shape[-1] parallel_vocab_size = logits.shape[-1]
mp_rank = constants.model_parallel_rank() tp_rank = constants.tensor_parallel_rank()
mask = mask[:, mp_rank * parallel_vocab_size : (mp_rank + 1) * parallel_vocab_size] mask = mask[:, tp_rank * parallel_vocab_size : (tp_rank + 1) * parallel_vocab_size]
logits.masked_fill_(mask, torch.finfo(logits.dtype).min) logits.masked_fill_(mask, torch.finfo(logits.dtype).min)

View File
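Because each tensor-parallel rank stores only a slice of the vocabulary logits, apply_logits_mask has to cut the matching columns out of the full-vocabulary mask before calling masked_fill_. A small illustration with hypothetical sizes:

    import torch

    full_vocab, tp_world_size, tp_rank = 32768, 4, 2     # hypothetical numbers
    parallel_vocab_size = full_vocab // tp_world_size    # 8192 columns per rank
    mask = torch.zeros(3, full_vocab, dtype=torch.bool)  # [batch, full vocabulary]
    local_mask = mask[:, tp_rank * parallel_vocab_size : (tp_rank + 1) * parallel_vocab_size]
    # Rank 2 masks columns 16384..24575 of its local logits shard; masked entries
    # are filled with the dtype's minimum so they can never win the softmax.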

@ -250,7 +250,7 @@ def pad_sequence_parallel_input(
): ):
"""Sequence parallel requires packed_input_ids has a shape of 1 dimension """Sequence parallel requires packed_input_ids has a shape of 1 dimension
[total_seq_len], and total_seq_len should be divisible by [total_seq_len], and total_seq_len should be divisible by
model_parallel_world_size. This function is used to pad packed_input_ids to tensor_parallel_world_size. This function is used to pad packed_input_ids to
suitable length with an empty sequence, and return new packed_input_ids, suitable length with an empty sequence, and return new packed_input_ids,
cu_seqlens and max_seqlen. cu_seqlens and max_seqlen.
@ -262,10 +262,10 @@ def pad_sequence_parallel_input(
Returns: Returns:
(torch.Tensor, torch.Tensor, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size) (torch.Tensor, torch.Tensor, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size)
""" """
mp_world_size = constants.model_parallel_world_size() tp_world_size = constants.tensor_parallel_world_size()
pad_size = 0 pad_size = 0
if len(packed_input_ids) % mp_world_size != 0: if len(packed_input_ids) % tp_world_size != 0:
pad_size = mp_world_size - len(packed_input_ids) % mp_world_size pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
packed_input_ids = torch.nn.functional.pad( packed_input_ids = torch.nn.functional.pad(
packed_input_ids, (0, pad_size), value=1 packed_input_ids, (0, pad_size), value=1
) )
@ -281,9 +281,9 @@ def pad_sequence_parallel_generate_input(
): ):
"""Only for pipeline generate input when model+seq parallel is enabled. To """Only for pipeline generate input when model+seq parallel is enabled. To
make sure inputs for seq parallel model have a shape with first dimension make sure inputs for seq parallel model have a shape with first dimension
divisible by model_parallel_world_size, the packed_input_ids should have divisible by tensor_parallel_world_size, the packed_input_ids should have
length divisible by model_parallel_world_size, and contains number of length divisible by tensor_parallel_world_size, and contains number of
sequences divisible by model_parallel_world_size. sequences divisible by tensor_parallel_world_size.
Args: Args:
packed_input_ids (torch.Tensor): unpadded packed_input_ids packed_input_ids (torch.Tensor): unpadded packed_input_ids
@ -293,16 +293,16 @@ def pad_sequence_parallel_generate_input(
Returns: Returns:
(torch.Tensor, torch.Tensor, int, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size, pad_seq_size) (torch.Tensor, torch.Tensor, int, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size, pad_seq_size)
""" """
mp_world_size = constants.model_parallel_world_size() tp_world_size = constants.tensor_parallel_world_size()
pad_size, pad_seq_size = 0, 0 pad_size, pad_seq_size = 0, 0
if ( if (
len(packed_input_ids) % mp_world_size != 0 len(packed_input_ids) % tp_world_size != 0
or (len(cu_seqlens) - 1) % mp_world_size != 0 or (len(cu_seqlens) - 1) % tp_world_size != 0
): ):
pad_size = mp_world_size - len(packed_input_ids) % mp_world_size pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
pad_seq_size = mp_world_size - (len(cu_seqlens) - 1) % mp_world_size pad_seq_size = tp_world_size - (len(cu_seqlens) - 1) % tp_world_size
if pad_size < pad_seq_size: if pad_size < pad_seq_size:
pad_size += mp_world_size pad_size += tp_world_size
pad_cu_seqlens = torch.tensor(list(range(1, pad_seq_size)) + [pad_size]) + len( pad_cu_seqlens = torch.tensor(list(range(1, pad_seq_size)) + [pad_size]) + len(
packed_input_ids packed_input_ids
) )

View File
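The padding rules above exist because sequence parallelism splits the packed token dimension evenly across tensor-parallel ranks. A worked example of the pad_size arithmetic (illustrative only):

    # 10 packed tokens with 4-way tensor parallelism.
    tp_world_size, packed_len = 4, 10
    pad_size = 0
    if packed_len % tp_world_size != 0:
        pad_size = tp_world_size - packed_len % tp_world_size   # 4 - 2 = 2
    # Two pad tokens (value 1) are appended as an extra "empty" sequence and a new
    # cu_seqlens entry covers them, so the padded length of 12 splits evenly across
    # the 4 ranks; the generate variant additionally pads the number of sequences.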

@ -55,6 +55,7 @@ def actor_loss_fn(
eps_clip: float, eps_clip: float,
loss_mask: Optional[torch.BoolTensor] = None, loss_mask: Optional[torch.BoolTensor] = None,
c_clip: Optional[float] = None, c_clip: Optional[float] = None,
proximal_logprobs: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, Dict]: ) -> Tuple[torch.Tensor, Dict]:
"""Compute PPO actor loss function. """Compute PPO actor loss function.
@ -83,13 +84,22 @@ def actor_loss_fn(
old_logprobs = old_logprobs.clone() old_logprobs = old_logprobs.clone()
if advantages.is_inference(): if advantages.is_inference():
advantages = advantages.clone() advantages = advantages.clone()
if proximal_logprobs is not None:
if loss_mask is not None: assert proximal_logprobs.dtype == torch.float32
loss_mask_count = loss_mask.count_nonzero() or 1 if proximal_logprobs.is_inference():
# For numerical stability. proximal_logprobs = proximal_logprobs.clone()
ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0) denorm_logprobs = proximal_logprobs
else: else:
ratio = torch.exp(logprobs - old_logprobs) denorm_logprobs = old_logprobs
# create mask
if loss_mask is None:
loss_mask = torch.ones_like(logprobs, dtype=torch.bool)
loss_mask: torch.BoolTensor
loss_mask_count = loss_mask.count_nonzero() or 1
# For numerical stability.
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss1 = -advantages * ratio pg_loss1 = -advantages * ratio
@ -104,24 +114,34 @@ def actor_loss_fn(
pg_loss = torch.min(pg_loss, pg_loss3) pg_loss = torch.min(pg_loss, pg_loss3)
else: else:
dual_clip_mask = torch.zeros_like(clip_mask) dual_clip_mask = torch.zeros_like(clip_mask)
if proximal_logprobs is not None:
behav_kl = proximal_logprobs - old_logprobs
behav_imp_weight = behav_kl.exp()
if c_clip is not None:
behav_mask = (behav_imp_weight <= c_clip).logical_and(loss_mask)
else:
behav_mask = loss_mask
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
pg_loss = pg_loss * behav_imp_weight
logging_loss = pg_loss.detach() logging_loss = pg_loss.detach()
if loss_mask is not None: pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
else:
pg_loss = pg_loss.mean()
if loss_mask is not None: clip_mask.logical_and_(loss_mask)
clip_mask.logical_and_(loss_mask) dual_clip_mask.logical_and_(loss_mask)
dual_clip_mask.logical_and_(loss_mask)
# Remain torch.CudaTensor here for all-reduce after train step. # Remain torch.CudaTensor here for all-reduce after train step.
stat = dict( stat = dict(
loss=logging_loss, loss=logging_loss,
importance_weight=ratio.detach(), importance_weight=ratio.detach(),
approx_kl=(logprobs - old_logprobs).detach(), approx_kl=(logprobs - denorm_logprobs).detach(),
clip_mask=clip_mask, clip_mask=clip_mask,
dual_clip_mask=dual_clip_mask, dual_clip_mask=dual_clip_mask,
) )
if proximal_logprobs is not None:
stat["behave_imp_weight"] = behav_imp_weight
stat["behave_approx_kl"] = behav_kl
stat["behave_mask"] = behav_mask
return pg_loss, stat return pg_loss, stat

View File
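The new proximal_logprobs argument implements the decoupled PPO loss: the PPO ratio is taken against a proximal (recent) policy instead of the behavior policy that actually generated the rollout, and the mismatch between the two is corrected by a separate behavior importance weight. A stripped-down sketch of the math, ignoring dual clipping and token masking (illustrative, not the exact code path):

    import torch

    def decoupled_ppo_loss_sketch(logprobs, old_logprobs, proximal_logprobs,
                                  advantages, eps_clip=0.2, c_clip=None):
        # Ratio against the proximal policy, not the behavior policy.
        ratio = torch.exp(logprobs - proximal_logprobs)
        clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
        pg_loss = torch.max(-advantages * ratio, -advantages * clipped)
        # Off-policy correction: importance weight of the proximal policy
        # relative to the behavior policy that produced the samples.
        behav_imp_weight = torch.exp(proximal_logprobs - old_logprobs)
        if c_clip is not None:
            # Zero out samples whose behavior importance weight exceeds c_clip.
            behav_imp_weight = torch.where(
                behav_imp_weight <= c_clip, behav_imp_weight,
                torch.zeros_like(behav_imp_weight),
            )
        return (pg_loss * behav_imp_weight).mean()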

@ -13,7 +13,7 @@ from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable from torch.utils.checkpoint import detach_variable
import realhf.base.constants as constants import realhf.base.constants as constants
from realhf.impl.model.parallelism.model_parallel.utils import ( from realhf.impl.model.parallelism.tensor_parallel.utils import (
divide, divide,
gather_split_1d_tensor, gather_split_1d_tensor,
safely_set_viewless_tensor_data, safely_set_viewless_tensor_data,
@ -169,10 +169,10 @@ def model_parallel_cuda_manual_seed(seed):
tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used for example for dropout in model parallel regions. tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used for example for dropout in model parallel regions.
""" """
# 2718 is just for fun and any POSITIVE value will work. # 2718 is just for fun and any POSITIVE value will work.
model_parallel_rank = constants.model_parallel_rank() tensor_parallel_rank = constants.tensor_parallel_rank()
expert_parallel_rank = 0 expert_parallel_rank = 0
offset = seed + 2718 offset = seed + 2718
tensor_model_parallel_seed = offset + model_parallel_rank tensor_model_parallel_seed = offset + tensor_parallel_rank
# Data parallel gets the original seed. # Data parallel gets the original seed.
data_parallel_seed = seed data_parallel_seed = seed
@ -187,7 +187,7 @@ def model_parallel_cuda_manual_seed(seed):
) )
expert_parallel_seed = ( expert_parallel_seed = (
seed + 1024 + 100 * expert_parallel_rank + model_parallel_rank seed + 1024 + 100 * expert_parallel_rank + tensor_parallel_rank
) )
_CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed) _CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
@ -331,8 +331,8 @@ def _initialize_affine_weight_cpu(
weight_list = torch.split( weight_list = torch.split(
master_weight, per_partition_per_stride_size, dim=partition_dim master_weight, per_partition_per_stride_size, dim=partition_dim
) )
rank = constants.model_parallel_rank() rank = constants.tensor_parallel_rank()
world_size = constants.model_parallel_world_size() world_size = constants.tensor_parallel_world_size()
my_weight_list = weight_list[rank::world_size] my_weight_list = weight_list[rank::world_size]
with torch.no_grad(): with torch.no_grad():

View File
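The seed bookkeeping above gives every tensor-parallel rank its own CUDA RNG stream, so dropout patterns differ across the partitions of one layer, while data-parallel replicas keep the base seed. A worked example with a hypothetical base seed:

    seed = 1234                       # base seed passed to model_parallel_cuda_manual_seed
    offset = seed + 2718
    for tensor_parallel_rank in range(4):
        tensor_model_parallel_seed = offset + tensor_parallel_rank   # 3952, 3953, 3954, 3955
        data_parallel_seed = seed                                    # identical on every rank
    # The expert-parallel seed additionally folds in the expert-parallel rank:
    # seed + 1024 + 100 * expert_parallel_rank + tensor_parallel_rank.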

@ -5,13 +5,18 @@
import fcntl import fcntl
import os import os
import re import re
import select
import subprocess import subprocess
import threading
import time import time
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple from typing import Dict, List, Literal, Optional, Tuple
import colorama
import realhf.base.logging as logging import realhf.base.logging as logging
from realhf.base.cluster import spec as cluster_spec from realhf.base.cluster import spec as cluster_spec
from realhf.base.constants import LOG_ROOT
from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME
from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient
from realhf.scheduler.evaluator import AutomaticEvaluator from realhf.scheduler.evaluator import AutomaticEvaluator
@ -29,6 +34,49 @@ SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5 SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
def monitor_log(
job_name: str, log_path: str, output_file: str, stop_event: threading.Event
):
"""Monitor a log file and write its contents to the output file with job name prefix."""
# Wait for log file to be created
while not os.path.exists(log_path) and not stop_event.is_set():
time.sleep(0.1)
if stop_event.is_set():
return
# Open the log file and follow it
with open(log_path, "r") as log_file, open(output_file, "a") as out_file:
# Store last position
position = 0
line_pos = 0
while not stop_event.is_set():
log_file.seek(position)
try:
new_lines = log_file.readlines()
except UnicodeDecodeError:
time.sleep(0.5)
continue
if new_lines:
# Update position
position = log_file.tell()
worker_type = job_name.split(":")[1]
# Write new lines to output file with job name prefix
for line in new_lines:
if line.strip(): # Skip empty lines
out_file.write(
f"{colorama.Fore.YELLOW + colorama.Style.DIM}({worker_type} Line {line_pos}){colorama.Style.RESET_ALL} {line}"
)
line_pos += 1
out_file.flush()
# Sleep briefly to avoid CPU spinning
time.sleep(0.1)
class SlurmSchedulerClient(SchedulerClient): class SlurmSchedulerClient(SchedulerClient):
"""Uses Slurm (https://slurm.schedmd.com/overview.html).""" """Uses Slurm (https://slurm.schedmd.com/overview.html)."""
@ -248,6 +296,26 @@ class SlurmSchedulerClient(SchedulerClient):
# before wait, commit all remaining pending jobs # before wait, commit all remaining pending jobs
# TODO: grab global file lock to avoid multi-experiment deadlocks # TODO: grab global file lock to avoid multi-experiment deadlocks
self.__allocate_and_commit_pending_jobs() self.__allocate_and_commit_pending_jobs()
# Start monitoring threads
threads = []
stop_events = []
merged_log_path = os.path.join(
LOG_ROOT, self.expr_name, self.trial_name, "main.log"
)
for job_name, launch_info in self.__committed_jobs.items():
stop_event = threading.Event()
stop_events.append(stop_event)
# Thread for monitoring the log file
log_thread = threading.Thread(
target=monitor_log,
args=(job_name, launch_info.log_path, merged_log_path, stop_event),
)
threads.append(log_thread)
log_thread.start()
# begin wait # begin wait
deadline = None if timeout is None else time.time() + timeout deadline = None if timeout is None else time.time() + timeout
left = set(self.__committed_jobs.keys()) left = set(self.__committed_jobs.keys())
@ -256,44 +324,52 @@ class SlurmSchedulerClient(SchedulerClient):
f"Waiting for {num_jobs_left} jobs. Jobs IDs: " f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
f"{','.join(sorted([x.job_info.slurm_id for x in self.__committed_jobs.values()]))}." f"{','.join(sorted([x.job_info.slurm_id for x in self.__committed_jobs.values()]))}."
) )
while len(left) > 0: logger.info(
if len(left) < num_jobs_left: f"All slurm logs will be merged. To check the real-time output, "
num_jobs_left = len(left) f"run\n\t`tail -f {merged_log_path}`."
logger.info(f"Waiting for {num_jobs_left} jobs.") )
if self.__evaluator is not None: try:
self.__evaluator.step() while len(left) > 0:
if deadline is not None and time.time() > deadline: if len(left) < num_jobs_left:
raise TimeoutError( num_jobs_left = len(left)
f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}" logger.info(f"Waiting for {num_jobs_left} jobs.")
) if self.__evaluator is not None:
try: self.__evaluator.step()
self.__update_all() if deadline is not None and time.time() > deadline:
except subprocess.CalledProcessError: raise TimeoutError(
logger.warning( f"Timeout waiting for {self.run_name}: {', '.join(sorted(left))}"
"Calling squeue failed. Check slurm manually if you continue to see this warning." )
) try:
time.sleep(30) self.__update_all()
continue except subprocess.CalledProcessError:
for job_slurm_name in list(left): logger.warning(
launch_info = self.__committed_jobs[job_slurm_name] "Calling squeue failed. Check slurm manually if you continue to see this warning."
if launch_info.slurm_id is None: )
time.sleep(30)
continue continue
if launch_info.job_info.state in check_status: for job_slurm_name in list(left):
launch_info.show_log() launch_info = self.__committed_jobs[job_slurm_name]
raise JobException( if launch_info.slurm_id is None:
run_name=self.run_name, continue
worker_type=launch_info.worker_type, if launch_info.job_info.state in check_status:
host=launch_info.job_info.host, launch_info.show_log()
reason=launch_info.job_info.state, raise JobException(
) run_name=self.run_name,
if launch_info.job_info.state in remove_status: worker_type=launch_info.worker_type,
logger.info( host=launch_info.job_info.host,
f"Job {launch_info.slurm_name} is {launch_info.job_info.state}.(Removed)" reason=launch_info.job_info.state,
) )
left.remove(job_slurm_name) if launch_info.job_info.state in remove_status:
if update: logger.info(
self.__committed_jobs.pop(job_slurm_name) f"Job {launch_info.slurm_name} is {launch_info.job_info.state}.(Removed)"
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL) )
left.remove(job_slurm_name)
if update:
self.__committed_jobs.pop(job_slurm_name)
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
finally:
[s.set() for s in stop_events]
[t.join() for t in threads]
def __update_all(self): def __update_all(self):
states = [] states = []

View File
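Each committed Slurm job gets a monitor_log thread that follows its log file and appends prefixed lines to the merged main.log; the threads are torn down via their stop events in the finally block. A minimal usage sketch of that lifecycle (paths and the job name are hypothetical; it assumes the monitor_log function from the diff above is in scope):

    import threading

    stop_event = threading.Event()
    t = threading.Thread(
        target=monitor_log,
        args=("trial:master_worker",       # job name; the worker type follows the colon
              "/tmp/master_worker-0.log",  # per-job Slurm log to follow
              "/tmp/main.log",             # merged output, follow with `tail -f`
              stop_event),
    )
    t.start()
    # ... wait for the job ...
    stop_event.set()   # ask the monitor loop to exit
    t.join()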

@ -294,21 +294,29 @@ class SlurmLaunchInfo:
@property @property
def multiprog_path(self) -> str: def multiprog_path(self) -> str:
return os.path.join( path = os.path.join(
LOG_ROOT, LOG_ROOT,
self.exper_name, self.exper_name,
self.trial_name, self.trial_name,
"slurm",
"multiprog",
f"{self.worker_type}-{self.worker_submission_idx}.multiprog", f"{self.worker_type}-{self.worker_submission_idx}.multiprog",
) )
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
@property @property
def hostfile_path(self) -> str: def hostfile_path(self) -> str:
return os.path.join( path = os.path.join(
LOG_ROOT, LOG_ROOT,
self.exper_name, self.exper_name,
self.trial_name, self.trial_name,
"slurm",
"hostfile",
f"{self.worker_type}-{self.worker_submission_idx}.hostfile", f"{self.worker_type}-{self.worker_submission_idx}.hostfile",
) )
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
def show_log(self): def show_log(self):
try: try:

View File
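With the added "slurm/multiprog" and "slurm/hostfile" subdirectories, the per-submission scheduler artifacts are grouped under the trial's log root instead of sitting next to the logs, e.g. (file names illustrative):

    {LOG_ROOT}/{exper_name}/{trial_name}/slurm/multiprog/master_worker-0.multiprog
    {LOG_ROOT}/{exper_name}/{trial_name}/slurm/hostfile/master_worker-0.hostfile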

@ -1,12 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
def import_profiler_registers():
import realhf.search_engine.enumerate
import realhf.search_engine.estimate
import realhf.search_engine.layers
import realhf.search_engine.param_realloc
import realhf.search_engine.search
import realhf.search_engine.utils

View File

@ -1,158 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
from typing import List
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
from realhf.api.core.dfg import build_graph as build_dfg
from realhf.api.quickstart.device_mesh import DeviceMesh, find_parallel_strategies
from realhf.api.quickstart.search import MFCDef, RPCExecution, RPCInstance
from realhf.search_engine.estimate import (
estimate_rpc_memory_cost,
estimate_rpc_time_cost,
)
MEM_INDEX = 1.0 # heuristic value to scale estimated memory
def enumerate_rpc_executions(
rpc: MFCDef,
device_mesh: DeviceMesh,
seq_len: int,
num_gen_tokens: int,
n_ppo_minibatches: int,
gradient_checkpointing: bool,
) -> List[RPCExecution]:
sub_device_meshes = device_mesh.sub_device_meshes()
import pprint
feasible = []
for sub_device_mesh in sub_device_meshes:
ps = find_parallel_strategies(sub_device_mesh)
for parallel in ps:
num_dp = parallel.data_parallel_size
num_pp = parallel.pipeline_parallel_size
num_mp = parallel.model_parallel_size
bs = rpc.n_seqs
# seq_len = seq_len
min_bs = (
2 * num_dp * num_pp * n_ppo_minibatches
if rpc.interface_type == ModelInterfaceType.TRAIN_STEP
else num_dp * num_pp
)
if min_bs > bs:
# batch size too small
continue
# heuristic to filter out inherently slow configurations
if (
num_mp * num_dp > device_mesh.n_gpus_per_node
and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
):
continue
if num_mp > 8:
continue
if num_pp > max(device_mesh.n_nodes, 8):
continue
# memory and time estimation
mem_cost, static_mem = estimate_rpc_memory_cost(
rpc,
parallel,
bs,
seq_len,
gradient_checkpointing=gradient_checkpointing,
n_ppo_minibatches=n_ppo_minibatches,
num_gen_tokens=num_gen_tokens,
offload=rpc.model_name.role in ["ref", "reward"],
)
mem_cost = int(mem_cost * MEM_INDEX)
static_mem = int(static_mem * MEM_INDEX)
time_cost = estimate_rpc_time_cost(
rpc,
parallel,
bs=bs,
seq_len=seq_len,
num_gen_tokens=num_gen_tokens,
gradient_checkpointing=gradient_checkpointing,
n_ppo_minibatches=n_ppo_minibatches,
)
time_cost = int(time_cost)
if mem_cost < device_mesh.gpu_memory_capacity:
feasible.append(
RPCExecution(
rpc,
sub_device_mesh,
parallel,
time_cost,
mem_cost,
static_mem,
)
)
return feasible
def build_graph(
rpcs: List[MFCDef],
num_epoch: int = 5,
epoch_dependency_interval: int = 1,
if_print=False,
) -> List[RPCInstance]:
"""Build model function call graph of multiple training epochs,
args:
rpcs: List[MFCDef], the list of model function calls
num_epoch: int, number of training epochs
epoch_dependency_interval: int, the interval of epoch dependency,
e.g. if epoch_dependency_interval = 2, then the graph will have
edges between epoch i and epoch i+2, i+4, ...
returns:
rpc_instances: List[RPCInstance], the list of RPCInstance objects
"""
# one epoch dependency graph
rpcs, edges = build_dfg(rpcs)
rpc_names_mapping = {rpc.name: rpc for rpc in rpcs}
rpc_instances = []
# multi epoch graph
for epoch_id in range(num_epoch):
for rpc in rpcs:
children = []
parents = []
if rpc.is_src and epoch_id >= epoch_dependency_interval:
for other in rpcs:
if other.is_dst and other.model_name.role == rpc.model_name.role:
parents.append(
RPCInstance(
rpc,
epoch_id - epoch_dependency_interval,
[],
[],
)
)
if rpc.is_dst and rpc.model_name.role == rpc.model_name.role:
for other in rpcs:
if (
other.is_src
and epoch_id + epoch_dependency_interval < num_epoch
):
children.append(
RPCInstance(
rpc,
epoch_id + epoch_dependency_interval,
[],
[],
)
)
for parent in rpc.parents:
p = rpc_names_mapping[parent]
parents.append(RPCInstance(p, epoch_id, [], []))
for child in rpc.children:
c = rpc_names_mapping[child]
children.append(RPCInstance(c, epoch_id, [], []))
rpc_instance = RPCInstance(rpc, epoch_id, parents, children)
rpc_instances.append(rpc_instance)
if if_print:
for ri in rpc_instances:
print(ri)
return rpc_instances

View File

@ -1,499 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
# Estimate a function-call level execution time for device-mesh enumeration pruning
# assume one batch of data passes through all rpcs once
import argparse
import getpass
import itertools
import os
import pickle
from collections import defaultdict
from typing import Optional
import numpy as np
import pandas as pd
import realhf.base.cluster
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.api.cli_args import ParallelismConfig
from realhf.api.core.dfg import MFCDef, ModelFamily, ModelInterfaceType
from realhf.api.core.model_api import ReaLModelConfig
from realhf.search_engine.param_realloc import estimate_param_realloc_time_cost
from realhf.search_engine.utils import load_model_config
logger = logging.getLogger("estimate", "benchmark")
PROFILE_RESULT_PATH = os.path.join(
realhf.base.cluster.spec.fileroot,
"logs",
getpass.getuser(),
"profile",
"profile",
"layer_stats",
)
def get_param_realloc_stats(
model_family: ModelFamily,
model_path: str,
n_nodes: int,
use_cache: bool = True,
):
non_critic = ModelFamily(model_family._class, model_family.size, False)
table_path = os.path.join(
constants.PROFILER_CACHE_PATH,
"param_realloc",
f"prtc_{non_critic}_n{n_nodes}.pkl",
)
if not os.path.exists(table_path):
print(
f"Calculating estimation of param realloc time cost for {model_family} at {model_path}"
)
estimate_param_realloc_time_cost(n_nodes, {non_critic: model_path})
print(f"Loading param realloc stats from {table_path}")
return pickle.load(open(table_path, "rb"))
def get_organized_op_stats(
model_family: ModelFamily, model_path: str, use_cache: bool = True
):
non_critic = ModelFamily(model_family._class, model_family.size, False)
# parse raw stats into list of OpInfo used for estimation
cache_path = os.path.join(
constants.PROFILER_CACHE_PATH, "organized_stats", f"{non_critic}.pkl"
)
if use_cache and os.path.exists(cache_path):
with open(cache_path, "rb") as f:
return pickle.load(f)
raw_result_path = os.path.join(
constants.PROFILER_CACHE_PATH, "layer_stats", str(non_critic)
)
if not os.path.exists(raw_result_path):
from realhf.apps.main import _main_profile_layers
_main_profile_layers(non_critic, model_path)
raw_stats_list = []
for fn in os.listdir(raw_result_path):
if not fn.startswith("layer-stats"):
continue
num_mp = int(fn.replace(".pkl", "").split("_")[1])
rank = int(fn.replace(".pkl", "").split("_")[2])
with open(os.path.join(raw_result_path, fn), "rb") as f:
stats = pickle.load(f)
if isinstance(stats, dict):
stats = pd.DataFrame(stats)
elif isinstance(stats, pd.DataFrame):
pass
else:
raise ValueError(f"Unsupported stats type {type(stats)}")
stats["num_mp"] = num_mp
stats["rank"] = rank
raw_stats_list.append(stats)
raw_stats = pd.concat(raw_stats_list)
bs_list = raw_stats["bs"].unique()
seq_len_list = raw_stats["seq_len"].unique()
op_name_list = raw_stats["op_name"].unique()
layer_name_list = raw_stats["layer_name"].unique()
num_mp_list = raw_stats["num_mp"].unique()
organized_stats = defaultdict(list)
for op_name, bs, seq_len, layer_name, num_mp in itertools.product(
op_name_list, bs_list, seq_len_list, layer_name_list, num_mp_list
):
filter_cond = (
(raw_stats["op_name"] == op_name)
& (raw_stats["bs"] == bs)
& (raw_stats["seq_len"] == seq_len)
& (raw_stats["layer_name"] == layer_name)
& (raw_stats["num_mp"] == num_mp)
)
avg_time_ns = raw_stats[filter_cond]["time_ns"].mean()
x = int(bs) if op_name == "fwd_gen_1" else int(bs * seq_len)
organized_stats["op_name"].append(op_name)
organized_stats["layer_name"].append(layer_name)
organized_stats["bs"].append(bs)
organized_stats["seq_len"].append(seq_len)
organized_stats["num_mp"].append(num_mp)
organized_stats["avg_time_ns"].append(avg_time_ns)
organized_stats["x"].append(x)
df = pd.DataFrame(organized_stats)
if use_cache:
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
with open(cache_path, "wb") as f:
pickle.dump(df, f)
return df
def computation_instruction_time_cost(
op_stats: pd.DataFrame,
op_name: str,
num_layers: int,
parallel_strategy: ParallelismConfig,
bs: int,
seqlen: int,
):
# inst cost unit: ns
layer_names = ["embedding_layer", "block_0", "head"]
num_pp = parallel_strategy.pipeline_parallel_size
num_mp = parallel_strategy.model_parallel_size
op_stats = op_stats[
(op_stats["op_name"] == op_name) & (op_stats["num_mp"] == num_mp)
]
op_cost = {}
embed_stats = op_stats[op_stats["layer_name"] == "embedding_layer"]
if embed_stats[
(embed_stats["bs"] == bs) & (embed_stats["seq_len"] == seqlen)
].empty:
# do linear interpolation for data points that do not exist
for layer_name in layer_names:
layer_stats = op_stats[op_stats["layer_name"] == layer_name]
assert layer_stats[
(layer_stats["bs"] == bs) & (layer_stats["seq_len"] == seqlen)
].empty
assert not layer_stats.empty, (
layer_name,
op_name,
num_mp,
op_stats,
)
xs = layer_stats["x"]
ys = layer_stats["avg_time_ns"]
x = int(bs) if op_name == "fwd_gen_1" else int(bs * seqlen)
y = np.interp(x, xs, ys)
if max(xs) < x or min(xs) > x:
logger.warning(
f"Interpolated value outside profiling range, "
f"parallel strategy {parallel_strategy}: "
f"{x} in {sorted(list(set(xs)))}"
)
# estimate using largest or smallest value
if max(xs) < x:
y = ys.max() * (x / xs[ys.idxmax()])
else:
y = ys.min() * (x / xs[ys.idxmin()])
op_cost[layer_name] = y
else:
for layer_name in layer_names:
assert not op_stats[op_stats["layer_name"] == layer_name].empty
required_stats = op_stats[
(op_stats["layer_name"] == layer_name)
& (op_stats["bs"] == bs)
& (op_stats["seq_len"] == seqlen)
]
assert required_stats.shape[0] == 1
op_cost[layer_name] = required_stats["avg_time_ns"].values[0]
embedding_layer_cost = op_cost["embedding_layer"]
block_0_cost = op_cost["block_0"]
head_cost = op_cost["head"]
cost = (embedding_layer_cost + num_layers * block_0_cost + head_cost) / num_pp
return cost
def communication_instruction_time_cost(comm_stats, size, comm_type):
return size / comm_stats[comm_type] # unit: ns
def estimate_instruction_time_costs(
model_family: ModelFamily,
model_path: str,
num_layers: int, # model configuration, num transformers layer
parallel_strategy: ParallelismConfig,
hidden_dim: int,
batch_size: int,
seq_len: int,
n_ppo_minibatches: int = 1,
):
comm_stats = default_communication_stats()
op_stats = get_organized_op_stats(model_family, model_path, use_cache=True)
num_mp = parallel_strategy.model_parallel_size
num_pp = parallel_strategy.pipeline_parallel_size
num_dp = parallel_strategy.data_parallel_size
num_gpus = num_dp * num_mp * num_pp
train_mbs = (
batch_size / (2 * num_pp * num_dp * n_ppo_minibatches)
if num_pp > 1
else batch_size / (num_dp * n_ppo_minibatches)
)
gen_mbs = batch_size / (num_pp * num_dp)
# pprint.pprint(op_cost, indent=4)
inst_keys = [
"gen_fwd_0",
"gen_fwd_1",
"inf_fwd",
"train_fwd",
"train_bwd",
"train_opt",
]
op_names = ["fwd_gen_0", "fwd_gen_1", "fwd", "fwd", "bwd", "opt"]
inst_stats = {}
for inst_key, op_name in zip(inst_keys, op_names):
mbs = train_mbs if "train" in inst_key else gen_mbs
inst_stats[inst_key] = computation_instruction_time_cost(
op_stats, op_name, num_layers, parallel_strategy, mbs, seq_len
)
comm_type = "remote_send" if num_gpus // num_pp >= 8 else "local_send"
inst_stats["act_p2p"] = communication_instruction_time_cost(
comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type
)
inst_stats["grad_p2p"] = communication_instruction_time_cost(
comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type
)
inst_stats["gen_act_p2p"] = communication_instruction_time_cost(
comm_stats, 2 * hidden_dim * gen_mbs, comm_type
)
return inst_stats
def _estimate_rpc_time_cost(
inst_stats,
parallel_strategy: ParallelismConfig,
model_interface_type: ModelInterfaceType,
# model function call args
num_gen_tokens: int,
gradient_checkpointing: bool,
n_ppo_minibatches: int = 1,
):
# TODO: improve/remove heuristic
num_pp = parallel_strategy.pipeline_parallel_size
num_dp = parallel_strategy.data_parallel_size
num_mp = parallel_strategy.model_parallel_size
if model_interface_type == ModelInterfaceType.INFERENCE:
num_micro_batches = num_pp
compute_cost = inst_stats["inf_fwd"] * (num_pp + num_micro_batches - 1)
comm_cost = inst_stats["act_p2p"] * (num_pp + num_micro_batches - 2) * 2
if num_mp > 1: # and num_pp * num_dp > 1:
compute_cost = compute_cost * (1 - num_mp * 0.03)
elif model_interface_type == ModelInterfaceType.TRAIN_STEP:
# TODO: add reduce grads, add ppo micro batches
num_micro_batches = num_pp * 2 if num_pp > 1 else 1
compute_cost = (inst_stats["train_fwd"] + inst_stats["train_bwd"]) * (
num_pp + num_micro_batches - 1
) + inst_stats["train_opt"]
if gradient_checkpointing:
compute_cost += inst_stats["train_fwd"] * (num_pp + num_micro_batches - 1)
comm_cost = (
(inst_stats["grad_p2p"] + inst_stats["act_p2p"])
* (num_pp + num_micro_batches - 2)
* 2
)
compute_cost = compute_cost * n_ppo_minibatches
comm_cost = comm_cost * n_ppo_minibatches
if num_pp * num_dp <= 1:
compute_cost = compute_cost * (1 - num_mp * 0.04)
if num_mp > 1: # and num_pp * num_dp > 1:
compute_cost = compute_cost * (1 - num_mp * 0.03)
elif model_interface_type == ModelInterfaceType.GENERATE:
num_micro_batches = num_pp
num_gen_tokens = num_gen_tokens
compute_cost = (
inst_stats["gen_fwd_0"] * (num_pp + num_micro_batches - 1)
+ inst_stats["gen_fwd_1"] * (num_gen_tokens - 1) * num_micro_batches
)
if num_dp * num_mp > 1:
compute_cost = compute_cost * (1 - min(num_dp * num_mp, 8) * 0.03)
comm_cost = 0
# dirty heuristic
if num_pp > 8:
compute_cost = compute_cost * (1 + num_pp * 0.01)
# FIXME: disable comm cost because it is not accurate
comm_cost = 0
return compute_cost + comm_cost
def estimate_rpc_time_cost(
rpc: MFCDef,
parallel_strategy: ParallelismConfig,
bs: int,
seq_len: int,
num_gen_tokens: int = 256,
gradient_checkpointing: bool = False,
n_ppo_minibatches: int = 1,
):
# time unit: milliseconds
# FIXME: n_ppo_minibatches > 1 will result in bad estimation
# when batch size is large enough, n_ppo_minibatches > 1 will not affect the estimation
n_ppo_minibatches = 1
model_type = rpc.model_type
model_path = rpc.model_path
model_config = load_model_config(rpc.model_type._class, rpc.model_path)
inst_cost = estimate_instruction_time_costs(
model_type,
model_path,
model_config.n_layers,
parallel_strategy,
model_config.hidden_dim,
bs,
seq_len,
n_ppo_minibatches=n_ppo_minibatches,
)
return (
_estimate_rpc_time_cost(
inst_cost,
parallel_strategy,
rpc.interface_type,
num_gen_tokens=num_gen_tokens,
gradient_checkpointing=gradient_checkpointing,
n_ppo_minibatches=n_ppo_minibatches,
)
) / 1e6
def default_communication_stats(if_print=False):
# use default comm stats of cluster
r = dict(
# between GPUs on the same node
local_send=170, # unit: GB/s
local_recv=170,
# IB between GPUs on different nodes
remote_send=20, # unit: GB/s
remote_recv=20,
)
if if_print:
print(r)
return r
def estimate_model_size(model_config: ReaLModelConfig):
h = model_config.hidden_dim
i = model_config.intermediate_dim
v = model_config.vocab_size
L = model_config.n_layers
# for llama actor only
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
return 2 * n_params
def estimate_rpc_memory_cost(
rpc: MFCDef,
parallel_strategy: ParallelismConfig,
batch_size: int,
seq_len: int,
offload: bool = False,
gradient_checkpointing: bool = False,
offload_optimizer: bool = False,
n_ppo_minibatches: int = 1,
num_gen_tokens: int = 128,
):
# TODO: improve heuristic
interface_type = rpc.interface_type
model_config = load_model_config(rpc.model_type._class, rpc.model_path)
h = model_config.hidden_dim
i = model_config.intermediate_dim
v = model_config.vocab_size
s = seq_len
gs = num_gen_tokens
b = batch_size
L = model_config.n_layers
# for llama actor only
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
param_mem = 2 * n_params
grad_mem = 2 * n_params
optimizer_mem = 20 * n_params if not offload_optimizer else 0
num_pp = parallel_strategy.pipeline_parallel_size
num_mp = parallel_strategy.model_parallel_size
num_dp = parallel_strategy.data_parallel_size
# zero1, pp and mp divide evenly
# enable sequence parallel
if interface_type == ModelInterfaceType.TRAIN_STEP:
# gradient checkpointing is always enabled for flash attn
static_mem = (param_mem + grad_mem) // (num_pp * num_mp) + optimizer_mem // (
num_pp * num_dp * num_mp
)
micro_bs = b // (2 * num_pp * num_dp) if num_pp > 0 else b // (num_dp)
if gradient_checkpointing:
active_mem = (micro_bs * s * h * num_pp * 2) // (num_pp * num_mp)
else:
# FIXME: calculate other memory entries
active_mem = (micro_bs * s * h * num_pp * 2) * 2 * L // (num_pp * num_mp)
return static_mem + active_mem, static_mem
elif interface_type == ModelInterfaceType.INFERENCE:
static_mem = int(2 * param_mem // (num_pp * num_mp))
# if num_dp > 4:
# static_mem = static_mem * 1.25
if offload:
return static_mem, 0 # assume offload
else:
return static_mem, static_mem
elif interface_type == ModelInterfaceType.GENERATE:
static_mem = int(2 * param_mem // (num_pp * num_mp))
if num_dp > 4 and num_dp * num_mp * num_pp <= 16:
static_mem = static_mem * 1.25
if num_mp == 0 and num_pp == 0:
static_mem = static_mem * 1.25
active_mem = (
2 * (2 * b * (gs + s) * h) * L // (num_pp * num_mp * num_dp)
) # kv cache
return static_mem + active_mem, static_mem
def example(rpcs):
if args.model_size == 7:
n_nodes = 1
elif args.model_size == 13:
n_nodes = 2
elif args.model_size == 34:
n_nodes = 4
elif args.model_size == 70:
n_nodes = 8
expr_name = f"profile-s{args.model_size}p{n_nodes}m1d8"
rollout, inf, train = rpcs
bs = 128
seq_len = 1024
p1 = ParallelismConfig(
pipeline_parallel_size=1, model_parallel_size=4, data_parallel_size=8
)
rpc_cost = estimate_rpc_time_cost(
train,
p1,
gradient_checkpointing=True,
num_gen_tokens=896,
bs=bs,
seq_len=seq_len,
n_ppo_minibatches=4,
)
mem_cost, static_mem = estimate_rpc_memory_cost(rollout, p1, bs, seq_len)
print(f"{p1} rpc cost {rpc_cost:.2f} seconds mem cost {mem_cost/(1024**3):.2f} GB")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a profiling experiment.")
parser.add_argument(
"-s",
"--model_size",
type=int,
default=7,
)
args = parser.parse_args()
example(args)

View File

@ -1,310 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import os
import pickle
import time
from collections import defaultdict
from typing import List, Optional, Union
import pandas as pd
import torch
import torch.distributed as dist
import transformers
import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_package
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.utils.padding import unpad_input
logger = logging.getLogger("profile layers", "system")
def make_layers(config: ReaLModelConfig, dtype, device):
from realhf.impl.model.nn.real_llm_base import (
OutputHead,
ReaLModelBlock,
VocabPositionEmbedding,
)
embedding_layer = VocabPositionEmbedding(
config,
dtype=dtype,
device=device,
)
real_model_blocks = [
ReaLModelBlock(
config,
layer_index=i,
output_layernorm=(i == 1),
dtype=dtype,
device=device,
)
for i in range(1)
]
head = OutputHead(
config.hidden_dim,
1 if config.is_critic else config.vocab_size,
bias=False,
device=device,
dtype=dtype,
)
layer_names = ["embedding_layer", "block_0", "head"]
return [embedding_layer] + real_model_blocks + [head], layer_names
class ProfileLayers:
def __init__(
self,
model_name: str,
config: ReaLModelConfig,
tokenizer: transformers.PreTrainedTokenizerFast = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
):
self.model_name = model_name
self.config = config
self.backend_config = config_package.ModelBackend(
type_="deepspeed",
args=dict(
optimizer_name="adam",
optimizer_config=dict(lr=1e-5, weight_decay=0.0, betas=(0.9, 0.95)),
warmup_steps_proportion=0.0,
min_lr_ratio=0.0,
zero_stage=1,
bf16=False,
),
)
self.dtype = dtype
self.device = device
self.layers, self.layer_names = make_layers(config, dtype, device)
self.hidden_dim = config.hidden_dim
self.head_dim = config.head_dim
self.max_new_tokens = 128 # only useful in kv cache memory alloc
self.min_new_tokens = 128
self.stats = defaultdict(list)
self.num_layers = len(self.layers)
self.layers = [
model_api.Model(name, layer, tokenizer, device=device, dtype=dtype)
for layer, name in zip(self.layers, self.layer_names)
]
self.backend = model_api.make_backend(self.backend_config)
ft_spec = model_api.FinetuneSpec(10, 100, 10)
self.layers = [self.backend.initialize(layer, ft_spec) for layer in self.layers]
self.stats = defaultdict(list)
def reset_stats(self):
self.stats = defaultdict(list)
def insert_data_point(self, layer_name, name, bs, seq_len, time_ns):
self.stats["layer_name"].append(layer_name)
self.stats["op_name"].append(name)
self.stats["bs"].append(bs)
self.stats["seq_len"].append(seq_len)
self.stats["time_ns"].append(time_ns)
@torch.no_grad()
def fwd_gen(self, bs, seq_len):
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
input_ids = torch.randint(
0,
self.config.vocab_size,
(bs, seq_len),
dtype=torch.long,
device=self.device,
)
attention_mask = torch.ones_like(input_ids, device=self.device)
# fwd_gen_0
packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
input_ids, attention_mask
)
cu_seqlens = cu_seqlens.to(device=self.device)
packed_input_ids = packed_input_ids.to(device=self.device)
x = PipeTransferData(
cu_seqlens=cu_seqlens,
max_seqlen=int(max_seqlen),
store_kv_cache=True,
)
ys = [PipeCacheData() for _ in range(self.num_layers)]
ys[0].packed_input_ids = packed_input_ids
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
st = time.monotonic_ns()
x: PipeTransferData = layer.module(x, y)
x.pp_input = x.pp_output
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "fwd_gen_0", bs, seq_len, time.monotonic_ns() - st
)
prompt_logits = x.pp_output
logits = prompt_logits[cu_seqlens[1:] - 1]
input_lens = cu_seqlens[1:] - cu_seqlens[:-1]
cache_seqlens = input_lens.clone().to(dtype=torch.int32)
layer_indices = range(len(ys))
for y, layer_idx in zip(ys[1:-1], layer_indices[1:-1]):
assert (
y.k_cache is not None
and y.v_cache is not None
and y.cache_seqlens is not None
)
kvcache_seqlen = max(
max_seqlen + self.max_new_tokens,
self.hidden_dim // self.head_dim + 10,
)
# fix of a flash attention bug
k_cache = torch.zeros(
(bs, kvcache_seqlen, *y.k_cache.shape[1:]),
dtype=y.k_cache.dtype,
device=self.device,
)
v_cache = torch.zeros_like(k_cache)
indices = (
torch.arange(
kvcache_seqlen,
device=constants.current_device(),
dtype=torch.long,
)[None, :]
< input_lens[:, None]
)
k_cache[indices] = y.k_cache
v_cache[indices] = y.v_cache
y.k_cache = k_cache
y.v_cache = v_cache
y.cache_seqlens = cache_seqlens
x = PipeTransferData(store_kv_cache=True)
ys[0].cache_seqlens = cache_seqlens
# fwd_gen_1
new_tokens = torch.randint(
0,
self.config.vocab_size,
(bs,),
dtype=torch.long,
device=self.device,
)
ys[0].packed_input_ids = new_tokens
ys[0].packed_position_ids = None
x.cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=self.device)
x.max_seqlen = 1
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
st = time.monotonic_ns()
x = layer.module(x, y)
x.pp_input = x.pp_output
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "fwd_gen_1", bs, seq_len, time.monotonic_ns() - st
)
def fwd_bwd_opt(self, bs, seq_len):
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
input_ids = torch.randint(
0,
self.config.vocab_size,
(bs, seq_len),
dtype=torch.long,
device=self.device,
)
attention_mask = torch.ones_like(input_ids, device=self.device)
packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
input_ids, attention_mask
)
cu_seqlens = cu_seqlens.to(device=self.device)
packed_input_ids = packed_input_ids.to(device=self.device)
x = PipeTransferData(
cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, store_kv_cache=False
)
ys = [PipeCacheData() for _ in range(self.num_layers)]
ys[0].packed_input_ids = packed_input_ids
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
# fwd
st = time.monotonic_ns()
x: PipeTransferData = layer.module(x, y)
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "fwd", bs, seq_len, time.monotonic_ns() - st
)
# bwd
r = torch.rand(
*x.pp_output.shape,
device=x.pp_output.device,
dtype=x.pp_output.dtype,
)
loss = torch.max(x.pp_output * r)
st = time.monotonic_ns()
layer.module.backward(loss)
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "bwd", bs, seq_len, time.monotonic_ns() - st
)
# opt
st = time.monotonic_ns()
layer.module.step()
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "opt", bs, seq_len, time.monotonic_ns() - st
)
x.pp_input = x.pp_output.clone().detach()
def make_dataframe_and_print(self):
df = pd.DataFrame(self.stats)
logger.info(f"Current Stats: \nstr{df}")
def dump_stats(self, world_size):
rank = dist.get_rank()
# dump full stats
dump_dir = os.path.join(
constants.PROFILER_CACHE_PATH,
"layer_stats",
)
dump_path = os.path.join(
dump_dir, self.model_name, f"layer-stats_{world_size}_{rank}.pkl"
)
os.makedirs(os.path.dirname(dump_path), exist_ok=True)
with open(dump_path, "wb") as f:
df = pd.DataFrame(self.stats)
pickle.dump(df, f)
def make_profile_layers(
device: torch.device,
model_path: str,
model_name: str,
dtype: Optional[str] = None,
hf_model_type: str = "llama",
):
from realhf.impl.model.nn.real_llm_api import ReaLModel
if dtype == "fp16" or dtype == None:
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp32":
dtype = torch.float32
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
tokenizer = None
config: ReaLModelConfig = getattr(ReaLModel, f"config_from_{hf_model_type}")(
model_path=model_path,
)
if tokenizer is None:
tokenizer = model_api.load_hf_tokenizer(model_path)
profile_layers = ProfileLayers(
model_name, config, tokenizer=tokenizer, dtype=dtype, device=device
)
return profile_layers

View File

@ -1,272 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import itertools
import json
import time
from collections import defaultdict
from typing import *
import torch
import torch.distributed
import realhf.api.core.system_api as system_api
import realhf.base.constants as constants
import realhf.base.topology as topology
from realhf.api.core.config import ModelFamily, ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base.topology import decompose_to_three_factors
def bcast_cost(
param_size: float, bw: float, src: int, dsts: List[int], n_nodes_per_gpu=8
):
src_node = src // n_nodes_per_gpu
dst_nodes = [dst // n_nodes_per_gpu for dst in dsts]
if src_node == dst_nodes[0] and all(
dst_node == dst_nodes[0] for dst_node in dst_nodes
):
return param_size * 2 * 8 / (1800 * 1024**3)
# return 0.0
else:
# param size is in float16, bw is in Gbps
return param_size * 2 * 8 / (bw * 1024**3) * len(set(dst_nodes))
def compute_cost(
world_size: int,
from_model_name: ModelName,
to_model_name: ModelName,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
model_config: ReaLModelConfig,
bw: float, # Gbps
set_interval_cost: float,
) -> int:
from realhf.impl.model.comm.param_realloc import (
ParamReallocInfo,
ReparallelizeReceiverStep,
ReparallelizeSenderStep,
_create_param_realloc_groups,
_derive_reparallelize_comm_plan,
)
param_sync_groups = {}
param_sync_src_ranks = {}
param_sync_dst_ranks = {}
msid2mwid = {}
for i in range(from_topo.world_size()):
msid2mwid[
system_api.ModelShardID.from_parallelism_rank(from_model_name, from_topo, i)
] = i
for i in range(to_topo.world_size()):
msid2mwid[
system_api.ModelShardID.from_parallelism_rank(to_model_name, to_topo, i)
] = (i + world_size - to_topo.world_size())
_create_param_realloc_groups(
from_topo,
to_topo,
from_model_name,
to_model_name,
msid2mwid,
param_sync_groups,
param_sync_src_ranks,
param_sync_dst_ranks,
)
pg_info = ParamReallocInfo(
param_sync_groups,
param_sync_src_ranks,
param_sync_dst_ranks,
)
comm_plan = _derive_reparallelize_comm_plan(
from_model_name,
to_model_name,
from_topo,
to_topo,
model_config,
model_config,
pg_info,
)
# Run broadcast!
max_cost = max_comm_volume = max_bcast_cnt = 0
for _rank in range(world_size):
cost = comm_volume = bcast_cnt = 0
for step in comm_plan:
if isinstance(step, ReparallelizeReceiverStep) and step.rank == _rank:
if step.rank != step.src:
cost += bcast_cost(step.param_size, bw, step.src, step.dst_ranks)
comm_volume += step.param_size
bcast_cnt += 1
cost += set_interval_cost
if isinstance(step, ReparallelizeSenderStep) and step.rank == _rank:
if step.group is not None:
cost += bcast_cost(step.param_size, bw, step.rank, step.dst_ranks)
bcast_cnt += 1
max_cost = max(max_cost, cost)
max_comm_volume = max(max_comm_volume, comm_volume)
max_bcast_cnt = max(max_bcast_cnt, bcast_cnt)
return max_cost
def dump_table(
n_nodes: int,
model_family: ModelFamily,
model_path: str,
rank: int = 0,
parallel: int = 1,
):
from_model_name = ModelName("actor", 0)
to_model_name = ModelName("actor", 1)
def hash_tuple_into_str(t) -> str:
return ",".join([str(i) for i in t])
import tqdm
res = {}
device_mesh_sizes = [4] + [8 * i for i in range(1, n_nodes + 1)]
space = list(itertools.product(device_mesh_sizes, device_mesh_sizes))
sub_space = space[rank::parallel]
# for a, b in set(small_spaces + large_spaces):
for a, b in sub_space:
mtik = time.perf_counter()
all_configs = list(
itertools.product(
decompose_to_three_factors(a), decompose_to_three_factors(b)
)
)
all_configs = list(filter(lambda x: x[0][1] <= 8 and x[1][1] <= 8, all_configs))
all_configs = list(filter(lambda x: x[0][2] <= 8 and x[1][2] <= 8, all_configs))
all_configs = list(
filter(
lambda x: x[0][1] in [1, 2, 4, 8] and x[1][1] in [1, 2, 4, 8],
all_configs,
)
)
all_configs = list(
filter(lambda x: x[0][0] <= 16 and x[1][0] <= 16, all_configs)
)
all_configs = list(
filter(
lambda x: x[0][1] % x[1][1] == 0 or x[1][1] % x[0][1] == 0,
all_configs,
)
)
for config_id, (from_pp_mp_dp, to_pp_mp_dp) in tqdm.tqdm(
enumerate(all_configs)
):
world_size = max(a, b)
from_topo = topology.PipeDataModelParallelTopology(
*from_pp_mp_dp, False, False
)
to_topo = topology.PipeDataModelParallelTopology(*to_pp_mp_dp, False, False)
assert world_size >= from_topo.world_size()
assert world_size >= to_topo.world_size()
from realhf.search_engine.utils import load_model_config
mconfig = load_model_config(model_family._class, model_path)
cost = compute_cost(
world_size,
from_model_name,
to_model_name,
from_topo,
to_topo,
mconfig,
bw=200.0,
set_interval_cost=0.03,
)
res[
hash_tuple_into_str((model_family.size, *from_pp_mp_dp, *to_pp_mp_dp))
] = int(cost * 1000 * 1000)
print(
f"Time for model size {model_family.size} {a} -> {b} {rank}/{parallel}: "
f"{time.perf_counter() - mtik:.4f}, num res entries {len(res)}"
)
print(f"Rank {rank} of model {model_family} finished, res size {len(res)}.")
import os
import pickle
dump_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc")
fn = f"prtc_{model_family}_n{n_nodes}_{rank}_{parallel}.pkl"
if not os.path.exists(dump_path):
os.makedirs(dump_path, exist_ok=True)
with open(os.path.join(dump_path, fn), "wb") as f:
pickle.dump(res, f)
print(f"dumped table with {len(res)} entries to {model_family}-{rank}-{parallel}.")
def dump_table_parallel(
n_nodes: int,
model_family_to_path: Dict[ModelFamily, str],
parallel: int = 4,
):
import torch.multiprocessing as mp
mp.set_start_method("spawn", force=True)
rq = mp.Queue()
ps = []
for model_family, model_path in model_family_to_path.items():
for rank in range(parallel):
ps.append(
mp.Process(
target=dump_table,
args=(n_nodes, model_family, model_path, rank, parallel),
)
)
for p in ps:
p.start()
for p in ps:
p.join()
def merge_tables(
n_nodes: int,
model_family_to_path: Dict[ModelFamily, str],
parallel: int = 4,
):
import os
import pickle
res_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc")
for model_family in model_family_to_path.keys():
prefix = f"prtc_{model_family}_n{n_nodes}"
r = {}
counter = 0
for fn in os.listdir(res_path):
if fn.endswith(".pkl") and fn.startswith(prefix):
counter += 1
path = os.path.join(res_path, fn)
with open(path, "rb") as f:
r.update(pickle.load(f))
os.remove(path)
if counter < parallel:
raise RuntimeError(
"missing sub-tables, probably some sub-processes failed "
"during param realloc time cost estimation."
)
with open(os.path.join(res_path, f"{prefix}.pkl"), "wb") as f:
pickle.dump(r, f)
print(f"merged parallel tables into {prefix}.pkl, total entries {len(r)}")
def estimate_param_realloc_time_cost(
n_nodes: int,
model_family_to_path: Dict[ModelFamily, str],
parallel: int = 4,
):
dump_table_parallel(n_nodes, model_family_to_path, parallel)
merge_tables(n_nodes, model_family_to_path, parallel)

View File

@ -1,232 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import argparse
import functools
import json
import os
import pickle
import pprint
import re
from typing import Any, Dict, List, Literal, Optional
import numpy as np
try:
import realhf._C.mdm_search as mdm_search
except ModuleNotFoundError:
mdm_search = None
import realhf.base.constants as constants
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.core.config import ModelInterfaceType
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.device_mesh import DeviceMesh, RPCAllocation
from realhf.api.quickstart.search import RPCExecution
def search_rpc_allocations(
device_mesh: DeviceMesh,
rpcs: List[MFCDef],
num_gen_tokens: int = 256,
n_ppo_minibatches: int = 1,
seq_len: int = 256,
gradient_checkpointing: bool = True,
use_cache: bool = False,
) -> List[RPCAllocation]:
from realhf.search_engine.enumerate import build_graph
from realhf.search_engine.estimate import get_param_realloc_stats
from_file = os.environ.get("REAL_IS_REMOTE", "0") == "1"
dump_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"device_mapping.pkl",
)
log_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"device_mapping",
)
rs_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"raw_search_result",
)
rpc_exe_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"rpc_exe_info",
)
if from_file or (use_cache and os.path.exists(dump_dir)):
with open(dump_dir, "r") as f:
s = json.load(f)
rpc_allocs = [RPCAllocation.from_dict(d) for d in s]
return rpc_allocs
else:
os.makedirs(os.path.dirname(dump_dir), exist_ok=True)
n_nodes = device_mesh.n_nodes
table = {}
for rpc in rpcs:
print(f"Getting param realloc stats for {rpc.model_type} at {rpc.model_path}")
t = get_param_realloc_stats(rpc.model_type, rpc.model_path, n_nodes, True)
table.update(t)
rpc_exe_list = make_rpc_exe_list(
rpcs,
device_mesh,
num_gen_tokens=num_gen_tokens,
n_ppo_minibatches=n_ppo_minibatches,
seq_len=seq_len,
gradient_checkpointing=gradient_checkpointing,
log_dir=rpc_exe_dir,
if_print=False,
)
graph = build_graph(rpcs, 5, 1, if_print=False)
model_size_dict = make_model_size_dict(rpcs, if_print=False)
n_nodes = device_mesh.n_nodes
search_time = 120
rs: List[Dict[str, List]] = mdm_search.multi_mcmc_search(
rpcs,
rpc_exe_list,
graph,
table,
model_size_dict,
0.001, # beta min
0.002, # beta max
0.001, # beta step
search_time, # time limit for each search
1, # repeat
)
if not from_file:
with open(rs_dir, "w") as f:
pprint.pprint(rs, stream=f)
r: Dict[str, Dict[str, Any]] = rs[-1]
pprint.pprint(r)
rpc_name_to_rpcs = {rpc.name: rpc for rpc in rpcs}
rpc_allocs = []
for rpc_name, alloc_info in r.items():
if rpc_name in ["end_time", "mem_cost"]:
continue
# rpc = rpc_dict[rpc_name]
rpc = rpc_name_to_rpcs[rpc_name]
parallel = ParallelismConfig(
pipeline_parallel_size=alloc_info["num_pp"],
data_parallel_size=alloc_info["num_dp"],
model_parallel_size=alloc_info["num_mp"],
use_sequence_parallel=(
alloc_info["num_mp"] > 1
and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
),
)
sub_device_mesh = DeviceMesh(
n_nodes=device_mesh.n_nodes,
n_gpus_per_node=device_mesh.n_gpus_per_node,
mapping=alloc_info["device_mesh_mapping"],
name=alloc_info["device_mesh_name"],
global_mesh_name=device_mesh.global_mesh_name,
)
rpc_alloc = RPCAllocation(
rpc=rpc,
device_mesh=sub_device_mesh,
parallel=parallel,
)
rpc_allocs.append(rpc_alloc)
if not from_file:
with open(dump_dir, "w") as f:
json.dump([rpc_alloc.to_dict() for rpc_alloc in rpc_allocs], f, indent=4)
with open(log_dir, "w") as f:
pprint.pprint(rpc_allocs, stream=f)
return rpc_allocs
def make_rpc_exe_list(
rpcs: List[MFCDef],
device_mesh: DeviceMesh,
num_gen_tokens: int,
n_ppo_minibatches: int,
seq_len: int,
gradient_checkpointing: bool,
if_print: bool = False,
log_dir: Optional[str] = None,
) -> List[RPCExecution]:
from realhf.search_engine.enumerate import enumerate_rpc_executions
rpc_exe_list = []
log_flag = False
for rpc in rpcs:
# real_model_config = load_model_config(rpc)
feasible = enumerate_rpc_executions(
rpc,
device_mesh,
seq_len=seq_len,
num_gen_tokens=num_gen_tokens,
n_ppo_minibatches=n_ppo_minibatches,
gradient_checkpointing=gradient_checkpointing,
)
rpc_exe_list.extend(feasible)
if log_dir is not None:
mode = "w" if not log_flag else "a"
with open(log_dir, mode) as f:
f.write(f"{rpc.name} feasible: {len(feasible)}\n")
feasible.sort(key=lambda x: x.time_cost)
# feasible = feasible[:30]
for i, rpc_exe in enumerate(feasible):
f.write(
f"{i}: time_cost: {rpc_exe.time_cost} ms, "
f"sub_device_mesh: {rpc_exe.device_mesh}, "
f"parallel_strategy: {rpc_exe.parallel_strategy}, "
f"mem_cost: {rpc_exe.mem/(1024*1024*1024):.2f} GB, "
f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):.2f} GB\n"
)
f.write("\n")
log_flag = True
if if_print:
print(f"{rpc.name} feasible: {len(feasible)}")
feasible.sort(key=lambda x: x.time_cost)
# feasible = feasible[:10]
for i, rpc_exe in enumerate(feasible):
print(
f"{i}: time_cost: {rpc_exe.time_cost} ms, "
f"sub_device_mesh: {rpc_exe.device_mesh}, "
f"parallel_strategy: {rpc_exe.parallel_strategy}, "
f"mem_cost: {rpc_exe.mem/(1024*1024*1024):.2f} GB, "
f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):.2f} GB"
)
return rpc_exe_list
def make_model_size_dict(rpcs: List[MFCDef], if_print: bool = False) -> Dict[str, int]:
model_size_dict = {}
for rpc in rpcs:
if rpc.model_name.role in model_size_dict:
continue
# model_configs = load_model_config(rpc)
# model_size_dict[rpc.model_name.role] = estimate_model_size(real_model_config)
model_size_dict[rpc.model_name.role] = rpc.model_type.size
if if_print:
print(
f"model_name: {rpc.model_name.role}, "
f"model_size: {rpc.model_type.size}"
)
return model_size_dict
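
For readers skimming the search code above, a hedged illustration of how one entry of the search result is turned into a parallelism plan (the numbers are made up; the key names follow the loop earlier in this file):

alloc_info = {"num_pp": 2, "num_dp": 4, "num_mp": 2}  # hypothetical search output entry
parallel = ParallelismConfig(
    pipeline_parallel_size=alloc_info["num_pp"],
    data_parallel_size=alloc_info["num_dp"],
    model_parallel_size=alloc_info["num_mp"],
    # sequence parallelism is only enabled for TRAIN_STEP RPCs with num_mp > 1
    use_sequence_parallel=alloc_info["num_mp"] > 1,
)
# This function call would then occupy 2 (pp) x 4 (dp) x 2 (tp) = 16 GPUs of the device mesh.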

View File

@ -1,28 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
from realhf.api.core.model_api import ReaLModelConfig
def find_factors(n):
factors = []
for i in range(1, n + 1):
if n % i == 0:
factors.append(i)
return factors
def make_stats_key(rpc_name, bs, seq_len):
return f"{rpc_name}|{bs}|{seq_len}"
def parse_stats_key(key):
rpc_name, bs, seq_len = key.split("|")
return rpc_name, int(bs), int(seq_len)
def load_model_config(model_class: str, model_path: str) -> ReaLModelConfig:
from realhf.impl.model.nn.real_llm_api import ReaLModel
return getattr(ReaLModel, f"config_from_{model_class}")(model_path=model_path)

View File

@ -478,9 +478,11 @@ def run_ray_worker(
     # NOTE: Importing these will initialize DeepSpeed/CUDA devices.
     # profiler.import_profiler_registers()
-    import realhf.impl.dataset
-    import realhf.impl.model
-    import realhf.system
+    if worker_type != "master_worker":
+        # For master_worker, there could be errors while importing and it is not necessary.
+        import realhf.impl.dataset
+        import realhf.impl.model
+        import realhf.system
     worker_name = f"{worker_type}/{idx}"
     server = worker_control.make_server(

View File

@ -24,6 +24,17 @@ SCATTER_GROUPS = {}
 logger = logging.getLogger("data_manager", "system")
def find_minimal_superset(A: List[Set[int]], B: Set[int]) -> Set[int] | None:
min_size = float("inf")
result = None
for S in A:
if B.issubset(S):
if len(S) < min_size:
min_size = len(S)
result = S
return result
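
For intuition, a small usage sketch of the helper above (made-up rank sets): it returns the smallest known process group that contains all requested ranks, or None if no group qualifies.

groups = [{0, 1, 2, 3}, {0, 1}, {2, 3}]
assert find_minimal_superset(groups, {0, 1}) == {0, 1}
assert find_minimal_superset(groups, {1, 2}) == {0, 1, 2, 3}
assert find_minimal_superset(groups, {4}) is None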
 class DataManager:
     def __init__(
@ -52,7 +63,7 @@ class DataManager:
         mw_ranks: Dict[ModelName, List[int]] = {}
-        # Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name.
+        # Stores the dp_head (i.e., tp_rank=0, pp_rank=-1) ranks given a model_name.
         mw_dp_head_ranks: Dict[ModelName, List[int]] = defaultdict(list)
         assert msid2mwid is not None
@ -67,7 +78,7 @@ class DataManager:
                 topo,
                 msid2mwid,
                 pipe=topo.get_dim("pipe") - 1,
-                model=0,
+                tensor=0,
             )
             dp_size = topo.get_dim("data")
             for dp_i in range(dp_size):
@ -87,11 +98,12 @@ class DataManager:
                 list(ranks), backend="nccl" if constants.use_cuda() else "gloo"
             )
-            scatter_ranks = tuple(sorted(set([ranks[0]] + mw_ranks[dst])))
-            SCATTER_GROUPS[scatter_ranks] = new_or_get_group(
-                list(scatter_ranks),
-                backend="nccl" if constants.use_cuda() else "gloo",
-            )
+            for rank in ranks:
+                scatter_ranks = tuple(sorted(set([rank] + mw_ranks[dst])))
+                SCATTER_GROUPS[scatter_ranks] = new_or_get_group(
+                    list(scatter_ranks),
+                    backend="nccl" if constants.use_cuda() else "gloo",
+                )
         # Construct all src-dst pairs, from any src dp rank to any dst dp rank.
         # Note that a dp rank corresponds to multiple parameter shards (TP+PP),
@ -228,7 +240,20 @@ class DataManager:
     def _run_gather(
         self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
     ):
-        if dist.get_rank() not in step.srcs:
+        # It's possible that some DP rank is not involved.
+        # Create dummy data to make the gather happy.
+        gather_ranks = find_minimal_superset(
+            [set(k) for k in GATHER_GROUPS.keys()], set(step.srcs)
+        )
+        assert gather_ranks is not None, (
+            set(step.srcs),
+            [set(k) for k in GATHER_GROUPS.keys()],
+        )
+        gather_ranks = sorted(list(gather_ranks))
+        pgroup = GATHER_GROUPS[tuple(gather_ranks)]
+        if dist.get_rank() not in gather_ranks:
             return
         maxlen = 0
@ -249,36 +274,47 @@ class DataManager:
                 torch.empty(
                     maxlen, device=constants.current_device(), dtype=torch.float32
                 )
-                for _ in range(len(step.srcs))
+                for _ in gather_ranks
             ]
+            is_valid_gather = [i in step.srcs for i in gather_ranks]
         else:
             gather_list = None
-        local_gather_idx = step.srcs.index(dist.get_rank())
-        ids = step.ids[local_gather_idx]
-        for i in ids:
-            self.storage[i].to_device(constants.current_device())
-        samples = [self.storage[i] for i in ids]
-        data = torch.cat(
-            [
-                sample.data[key].float().flatten()
-                for sample in samples
-                for key in step.keys
-            ]
-        )
-        data = self._pad_data(data, maxlen)
+        if dist.get_rank() in step.srcs:
+            local_gather_idx = step.srcs.index(dist.get_rank())
+            ids = step.ids[local_gather_idx]
+            for i in ids:
+                self.storage[i].to_device(constants.current_device())
+            samples = [self.storage[i] for i in ids]
+            data = torch.cat(
+                [
+                    sample.data[key].float().flatten()
+                    for sample in samples
+                    for key in step.keys
+                ]
+            )
+            data = self._pad_data(data, maxlen)
+        else:
+            data = torch.empty(
+                maxlen, device=constants.current_device(), dtype=torch.float32
+            )
         dist.gather(
             data,
             gather_list,
             dst=step.root,
-            group=GATHER_GROUPS[tuple(sorted(step.srcs))],
+            group=pgroup,
         )
         if dist.get_rank() != step.root:
+            del data
             return
-        for ids, buf in zip(step.ids, gather_list):
+        cnt = 0
+        for is_valid, buf in zip(is_valid_gather, gather_list):
+            if not is_valid:
+                continue
+            ids = step.ids[cnt]
             offset = 0
             for i in ids:
                 for key in step.keys:
@ -302,6 +338,9 @@ class DataManager:
                     self.storage[i].update_(s)
                 else:
                     self.storage[i] = s
+            cnt += 1
+        assert cnt == len(step.srcs) == len(step.ids)
+        del data
     def _run_scatter(
         self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
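
As a side note, a minimal standalone sketch (not AReaL code; assumes torch.distributed is already initialized) of the trick used in `_run_gather` above: every rank of the chosen group must join the collective, so ranks holding no real data contribute a dummy buffer that the root later skips.

import torch
import torch.distributed as dist

def gather_with_dummies(payload, srcs, root, maxlen):
    # payload: 1-D float tensor on ranks in `srcs`, None elsewhere.
    rank = dist.get_rank()
    data = torch.zeros(maxlen)
    if payload is not None:
        data[: payload.numel()] = payload.flatten()
    gather_list = (
        [torch.empty(maxlen) for _ in range(dist.get_world_size())]
        if rank == root
        else None
    )
    dist.gather(data, gather_list, dst=root)
    if rank != root:
        return None
    # Keep only the buffers that came from ranks holding real data.
    return [buf for i, buf in enumerate(gather_list) if i in srcs]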

View File

@ -27,7 +27,7 @@ class FunctionExecutor:
         rpcs: List[MFCDef],
         msid2mwid: Dict[ModelShardID, int],
         stream: NameResolvingRequestClient,
-        buffer: AsyncIOSequenceBuffer,
+        buffers: List[AsyncIOSequenceBuffer],
         model_topos: Dict[str, ProcessTopology],
         model_configs: Dict[str, None | ReaLModelConfig],
         ctrl: RPCCorountineControl,
@ -58,14 +58,15 @@ class FunctionExecutor:
                 model_topos=model_topos,
                 model_configs=model_configs,
                 ctrl=ctrl,
-                buffer=buffer,
+                buffers=buffers,
                 redistrib_planner=self.redistrib_planner,
                 summary_writer=summary_writer,
             )
             self.func_calls[rpc.name] = func_call
         self.stream = stream
-        self.buffer = buffer
+        self.buffers = buffers
+        self.buffer_id = 0
         self.data_loading_dp_idx = -1
         self.shuffle_dataset = shuffle_dataset
@ -111,18 +112,17 @@ class FunctionExecutor:
         self.ctrl.ids_to_clear.clear()
-    async def load_data(self):
-        buffer = self.buffer
+    async def load_data(self, buffer_id: int):
+        buffer = self.buffers[buffer_id]
         ctrl = self.ctrl
         received_ids = set()
-        while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
+        while buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
             resps = await self.stream.call_async(
                 handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)],
                 handle_type="fetch",
-                datas=[None for _ in range(self.src_dp_size)],
+                datas=[buffer_id for _ in range(self.src_dp_size)],
                 verbose=False,
             )
@ -182,10 +182,13 @@ class FunctionExecutor:
logger.info("Waiting for the finish of the execution graph.") logger.info("Waiting for the finish of the execution graph.")
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [ tasks = [
loop.create_task(fc.run(self.buffer_id)) for fc in self.func_calls.values()
] + [
loop.create_task(self.flush_calls()), loop.create_task(self.flush_calls()),
loop.create_task(self.load_data()), loop.create_task(self.load_data(self.buffer_id)),
loop.create_task(self.finish_traverse()), loop.create_task(self.finish_traverse()),
] ]
loop.run_until_complete(asyncio.gather(*tasks)) loop.run_until_complete(asyncio.gather(*tasks))
self.buffer_id = (self.buffer_id + 1) % len(self.buffers)

View File

@ -1,14 +1,111 @@
 import os
+import subprocess
+import sys
 import time
+from pathlib import Path
+import requests
 from realhf.api.cli_args import SGLangConfig
 from realhf.api.core.system_api import GenerationServer as GenerationServerConfig
-from realhf.base import gpu_utils, logging, name_resolve, names, network, seeding
+from realhf.base import (
+    gpu_utils,
+    logging,
+    name_resolve,
+    names,
+    network,
+    pkg_version,
+    seeding,
+)
+from realhf.base.cluster import spec as cluster_spec
 from realhf.system.worker_base import PollResult, Worker
 logger = logging.getLogger(__name__)
def execute_shell_command(command: str) -> subprocess.Popen:
"""
Execute a shell command and return its process handle.
"""
# Replace newline continuations and split the command string.
command = command.replace("\\\n", " ").replace("\\", " ")
parts = command.split()
return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)
def launch_server_cmd(command: str, port: int = 30000):
"""
Launch the server using the given command.
If no port is specified, a free port is reserved.
"""
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path:
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
assert port is not None
full_command = f"{command} --port {port}"
process = execute_shell_command(full_command)
return process, port
def terminate_process(process, port=None):
"""
Terminate the process and, if a port was reserved, release it.
"""
from sglang.srt.utils import kill_process_tree
kill_process_tree(process.pid)
def wait_for_server(base_url: str, timeout: int = None) -> None:
"""Wait for the server to be ready by polling the /v1/models endpoint.
Args:
base_url: The base URL of the server
timeout: Maximum time to wait in seconds. None means wait forever.
"""
start_time = time.time()
while True:
try:
response = requests.get(
f"{base_url}/v1/models",
headers={"Authorization": "Bearer None"},
)
if response.status_code == 200:
time.sleep(5)
break
if timeout and time.time() - start_time > timeout:
raise TimeoutError("Server did not become ready within timeout period")
except requests.exceptions.RequestException:
time.sleep(1)
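
For reference, a hedged usage sketch of the helpers above (the command line, model path, and port are hypothetical):

cmd = "python3 -m sglang.launch_server --model-path /path/to/model --tp-size 1"
proc, port = launch_server_cmd(cmd, port=30000)
wait_for_server(f"http://localhost:{port}", timeout=600)  # block until /v1/models responds
# ... issue generation requests against the server ...
terminate_process(proc)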
PORT_CLEARANCE_PERIOD = 90
 class GenerationServer(Worker):
     def _configure(self, config: GenerationServerConfig):
         self.config = config
@ -36,20 +133,37 @@ class GenerationServer(Worker):
         config = self.config
         assert config.backend_type == "sglang"
+        host_ip = network.gethostip()
+        host = "localhost" if not config.backend_args.enable_metrics else host_ip
+        # NOTE: Ports returned by `find_multiple_free_ports` are unique,
+        # but SGLang servers still encounter conflicts.
+        # Use a clearance period to hack over this issue.
+        servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size
+        idx_on_this_node = self.worker_index % servers_per_node
+        time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node)
+        ports = network.find_multiple_free_ports(
+            2,
+            low=10000,
+            high=60000,
+            experiment_name=self.experiment_name,
+            trial_name=self.trial_name,
+        )
+        server_port = ports[0]
+        nccl_port = ports[1]
         cmd = SGLangConfig.build_cmd(
             config.backend_args,
             config.model_path,
             tp_size=config.tp_size,
             server_index=self.worker_index,
             base_gpu_id=self.base_gpu_id,
+            dist_init_addr=f"{host}:{nccl_port}",
         )
-        from sglang.utils import launch_server_cmd, wait_for_server
-        host_ip = network.gethostip()
-        host = "localhost" if not config.backend_args.enable_metrics else host_ip
-        # TODO: handle launching error and retry
-        self.server_process, self.server_port = launch_server_cmd(cmd)
+        self.server_process, self.server_port = launch_server_cmd(cmd, port=server_port)
         self.server_addr = f"http://{host}:{self.server_port}"
         wait_for_server(self.server_addr)
@ -80,6 +194,5 @@ class GenerationServer(Worker):
     def _exit_hook(self, exit_status):
         if self.server_process is not None and self.config.backend_type == "sglang":
-            from sglang.utils import terminate_process
             terminate_process(self.server_process)

View File

@ -6,20 +6,46 @@ import shutil
 import threading
 import time
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import List
 import aiohttp
+import numpy as np
 from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq
 from realhf.api.core.system_api import GserverManager as GserverManagerConfig
 from realhf.base import constants, logging, name_resolve, names, network, recover
-from realhf.system.worker_base import AsyncWorker, PollResult, Worker
+from realhf.system.worker_base import PollResult, Worker
-logger = logging.getLogger("Generation Manager", "colored")
+logger = logging.getLogger("Generation Manager", "system")
 STALENESS_WARNED = defaultdict(lambda: False)
@dataclass
class RolloutStat:
submit: int = 0
accepted: int = 0
running: int = 0
def inc(self):
self.submit += 1
self.accepted += 1
self.running += 1
def accept(self):
self.running -= 1
def reject(self):
self.running -= 1
self.accepted -= 1
@dataclass
class AllocateRolloutInput:
qid: str
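
For intuition, a tiny sketch of how these counters evolve over two hypothetical rollouts:

stat = RolloutStat()
stat.inc()     # rollout A submitted and running
stat.inc()     # rollout B submitted and running
stat.accept()  # A finishes and is kept for training
stat.reject()  # B finishes but is dropped
assert (stat.submit, stat.accepted, stat.running) == (2, 1, 0)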
class GserverManager(Worker): class GserverManager(Worker):
"""This worker has the following functionalities: """This worker has the following functionalities:
1. As a router, it schedules generation requests and returns the 1. As a router, it schedules generation requests and returns the
@ -37,23 +63,30 @@ class GserverManager(Worker):
         assert self.config.worker_info.worker_count == 1
-        self.async_lock = asyncio.Lock()
         self.threading_lock = threading.Lock()
-        self.n_total_rollouts = 0
-        self.n_running_rollouts = 0
-        self.accepted_rollouts = 0
+        self.rollout_stat = RolloutStat()
         self.schedule_policy = config.schedule_policy
         self._last_param_realloc_step = 0
+        self._qid_to_server_url = {}
+        self._server_token_usage = defaultdict(float)
+        self._server_request_counts = defaultdict(int)
+        self._last_thpt_output_time = time.time()
+        self._gen_tokens = 0
         self.experiment_name = config.worker_info.experiment_name
         self.trial_name = config.worker_info.trial_name
         # manager server
-        self.server = None
+        self.manager_http_server = None
         self.thread = None
+        self.server_urls = []
         # recover info
         self.__recover_run, self.__recover_info = recover.load_recover_info()
         if self.__recover_run:
@ -67,10 +100,12 @@ class GserverManager(Worker):
             name_resolve.add(name, self.__recover_info.last_step_info.global_step)
             self._loaded_recover_weights = False
-            self.n_total_rollouts = self.accepted_rollouts = (
+            hist_rollouts = (
                 self.config.train_batch_size
                 * self.__recover_info.last_step_info.global_step
             )
+            self.rollout_stat.submit = hist_rollouts
+            self.rollout_stat.accepted = hist_rollouts
         return config.worker_info
@ -84,7 +119,7 @@ class GserverManager(Worker):
             if cnt >= timeout:
                 raise TimeoutError("Waiting generation servers timeout.")
             urls = name_resolve.get_subtree(name)
-        assert len(set(urls)) == len(urls), urls
+        assert len(set(urls)) == len(urls), (len(urls), len(set(urls)), urls)
         return urls
def _get_recover_ckpt_path(self, role: str): def _get_recover_ckpt_path(self, role: str):
@ -146,49 +181,34 @@ class GserverManager(Worker):
     async def flush_requests_and_update_weights(
         self, server_url, new_param_path, update_weights_retries=5
     ):
+        # HACK: urls are designed for SGLang
         server_index = self.server_urls.index(server_url)
-        async with aiohttp.ClientSession(server_url) as session:
-            running_requests = None
-            tik = time.perf_counter()
-            while running_requests is None or running_requests > 0:
-                if time.perf_counter() - tik > self.config.flush_request_timeout:
-                    raise RuntimeError(
-                        f"Waiting for flush requests failed. {running_requests} requests "
-                        f"remain after {self.config.flush_request_timeout} secs waiting. "
-                        f"Please try to reduce `new_tokens_per_chunk`."
-                    )
-                if running_requests is not None and running_requests > 0:
-                    logger.info(
-                        f"Waiting for {running_requests} requests on gen server {server_index}... "
-                        f"Time taken so far: {time.perf_counter() - tik:.4f}s"
-                    )
-                    await asyncio.sleep(0.5)
-                async with session.get(f"/metrics") as resp:
-                    resp.raise_for_status()
-                    text = await resp.text()
-                    for line in text.split("\n"):
-                        if line.startswith("sglang:num_running_reqs"):
-                            running_requests = float(line.split(" ")[1])
-                            break
-            success = False
-            for _ in range(update_weights_retries):
+        success = False
+        for _ in range(update_weights_retries):
+            async with aiohttp.ClientSession(
+                server_url,
+                timeout=aiohttp.ClientTimeout(
+                    total=self.config.flush_request_timeout, sock_connect=30
+                ),
+            ) as session:
                 async with session.post(
                     f"/update_weights_from_disk",
-                    json=dict(model_path=new_param_path),
+                    json=dict(model_path=new_param_path, allow_interrupt=True),
                 ) as resp:
                     if resp.status == 200:
                         res = await resp.json()
                         success = res["success"]
                         if success:
+                            logger.info(
+                                f"{res['num_paused_requests']} requests are interrupted "
+                                f"during updating weights for server {server_index}: {server_url}"
+                            )
                             return
                         logger.warning(
                             f"Update weights failed: {res['message']}. Retrying."
                         )
                     logger.warning(f"Update weights failed: {resp.reason}. Retrying.")
-                time.sleep(0.1)
-            raise RuntimeError("Update weights failed.")
+            time.sleep(0.1)
+        raise RuntimeError("Update weights failed.")
def _round_robin_schedule(self, req_meta: GenReqMeta) -> int: def _round_robin_schedule(self, req_meta: GenReqMeta) -> int:
if not hasattr(self, "round_robin_idx"): if not hasattr(self, "round_robin_idx"):
@ -198,6 +218,16 @@ class GserverManager(Worker):
self.round_robin_idx %= self.config.n_servers self.round_robin_idx %= self.config.n_servers
return r return r
def _least_requests_schedule(self, req_meta: GenReqMeta) -> int:
counts = [
self._server_request_counts[server_url] for server_url in self.server_urls
]
return int(np.argmin(counts))
def _least_token_usage_schedule(self, req_meta: GenReqMeta) -> int:
url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
return self.server_urls.index(url)
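
For illustration, a standalone sketch (made-up counters, not live worker state) of what the two new policies would pick:

server_urls = ["http://a:30000", "http://b:30000", "http://c:30000"]
request_counts = {"http://a:30000": 3, "http://b:30000": 1, "http://c:30000": 2}
token_usage = {"http://a:30000": 900.0, "http://b:30000": 1500.0, "http://c:30000": 400.0}
least_requests = min(server_urls, key=lambda u: request_counts[u])  # "http://b:30000"
least_tokens = min(server_urls, key=lambda u: token_usage[u])       # "http://c:30000"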
def _poll(self): def _poll(self):
if not self.thread: if not self.thread:
# Find addresses of generation servers # Find addresses of generation servers
@ -228,6 +258,23 @@ class GserverManager(Worker):
            loop.run_until_complete(asyncio.gather(*tasks))
            logger.info(f"Generation server updated weights from: {new_param_path}")
tasks = [
self._get_server_token_usage(server_url) for server_url in self.server_urls
]
loop = asyncio.get_event_loop()
token_usages = loop.run_until_complete(asyncio.gather(*tasks))
with self.threading_lock:
for server_url, token_usage in zip(self.server_urls, token_usages):
self._server_token_usage[server_url] = token_usage
if time.time() - self._last_thpt_output_time > 30:
interval = time.time() - self._last_thpt_output_time
logger.info(
f"Generation throughput: {self._gen_tokens / interval:.2f} tokens/s"
)
self._last_thpt_output_time = time.time()
self._gen_tokens = 0
# clear old weights # clear old weights
realloc_root = os.path.join( realloc_root = os.path.join(
constants.PARAM_REALLOC_PATH, constants.PARAM_REALLOC_PATH,
@ -237,9 +284,11 @@ class GserverManager(Worker):
) )
if os.path.exists(realloc_root): if os.path.exists(realloc_root):
for realloc_version in os.listdir(realloc_root): for realloc_version in os.listdir(realloc_root):
# Lock-free is safe here.
# Remain one checkpoint for recover.
                 if (
                     os.path.isdir(os.path.join(realloc_root, realloc_version))
-                    and int(realloc_version) < self._last_param_realloc_step
+                    and int(realloc_version) < self._last_param_realloc_step - 1
                 ):
                     shutil.rmtree(os.path.join(realloc_root, realloc_version))
                     logger.info(
@ -247,29 +296,56 @@ class GserverManager:
                         f"checkpoint: {os.path.join(realloc_root, realloc_version)}"
                     )
-        # TODO: we may want to update server status
-        # in the main thread.
-        time.sleep(1)
+        time.sleep(5)
         return PollResult(0, 0)
-    async def is_staled(self):
-        global_sample_cnt = self.n_total_rollouts
-        expected_version = global_sample_cnt // self.config.train_batch_size
-        staled = (
-            expected_version
-            > self.config.max_head_offpolicyness + self._last_param_realloc_step
-        )
+    async def _get_server_token_usage(self, server_url):
+        async with aiohttp.ClientSession(
+            server_url,
+            timeout=aiohttp.ClientTimeout(
+                total=self.config.flush_request_timeout, sock_connect=30
+            ),
+        ) as session:
+            async with session.get("/metrics") as resp:
+                resp.raise_for_status()
+                text = await resp.text()
+                for l in text.split("\n"):
+                    if l.startswith("sglang:num_used_tokens"):
+                        return float(l.split(" ")[1])
+        raise RuntimeError(f"Failed to get token usage metrics from {server_url}")
+
+    async def _get_server_num_running_requests(self, server_url):
+        async with aiohttp.ClientSession(
+            server_url,
+            timeout=aiohttp.ClientTimeout(
+                total=self.config.flush_request_timeout, sock_connect=30
+            ),
+        ) as session:
+            async with session.get(f"/metrics") as resp:
+                resp.raise_for_status()
+                text = await resp.text()
+                for line in text.split("\n"):
+                    if line.startswith("sglang:num_running_reqs"):
+                        return float(line.split(" ")[1])
+        raise RuntimeError(
+            f"Failed to get num running requests metrics from {server_url}"
+        )
+
+    def is_staled(self):
+        global_sample_cnt = self.rollout_stat.accepted
+        expected_version = global_sample_cnt // self.config.train_batch_size
+        version = self._last_param_realloc_step
+        staled = expected_version > self.config.max_head_offpolicyness + version
         global STALENESS_WARNED
-        if staled and not STALENESS_WARNED[self._last_param_realloc_step]:
+        if staled and not STALENESS_WARNED[version]:
             logger.warning(
                 f"expected version ({expected_version}) = "
                 f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
-                f"current version {self._last_param_realloc_step}, "
+                f"current latest version {version}, "
                 f"offpolicyness {self.config.max_head_offpolicyness}. Staled? {staled}"
             )
-            STALENESS_WARNED[self._last_param_realloc_step] = True
+            STALENESS_WARNED[version] = True
         return staled
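
A worked example of the staleness rule above, with made-up numbers:

train_batch_size, max_head_offpolicyness = 256, 1
accepted, current_version = 1024, 2
expected_version = accepted // train_batch_size                        # 4
staled = expected_version > max_head_offpolicyness + current_version   # 4 > 3 -> True
# New rollouts are then rejected until the trainer publishes newer weights.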
def _run_routing_service(self): def _run_routing_service(self):
@ -282,60 +358,94 @@ class GserverManager(Worker):
@self.app.post("/schedule_request") @self.app.post("/schedule_request")
async def schedule_request(req_meta: GenReqMeta): async def schedule_request(req_meta: GenReqMeta):
with self.threading_lock: with self.threading_lock:
async with self.async_lock: if (
version = self._last_param_realloc_step req_meta.previous_server_url
# FIXME: We only implement a round-robin scheduler that and req_meta.previous_version == self._last_param_realloc_step
# ignores server status and request metadata ):
return dict(
url=req_meta.previous_server_url,
version=req_meta.previous_version,
)
if self.schedule_policy == "round_robin":
server_idx = self._round_robin_schedule(req_meta) server_idx = self._round_robin_schedule(req_meta)
return dict(url=self.server_urls[server_idx], version=max(0, version)) elif self.schedule_policy == "least_token_usage":
server_idx = self._least_token_usage_schedule(req_meta)
elif self.schedule_policy == "least_requests":
server_idx = self._least_requests_schedule(req_meta)
else:
raise NotImplementedError(
f"Unknown schedule policy {self.schedule_policy}"
)
server_url = self.server_urls[server_idx]
# qid prompt (n samples) use the same dst server
self._qid_to_server_url[req_meta.qid] = server_url
self._server_request_counts[server_url] += 1
self._server_token_usage[server_url] += (
req_meta.prompt_len
+ req_meta.new_token_budget * req_meta.group_size * 0.4
)
version = self._last_param_realloc_step
return dict(url=server_url, version=version)
@self.app.post("/get_model_version") @self.app.post("/get_model_version")
async def get_model_version(req: ModelVersionReq): async def get_model_version(req: ModelVersionReq):
with self.threading_lock: with self.threading_lock:
async with self.async_lock: # FIXME: we may have different versions for different servers
# FIXME: we may have different versions for different servers version = self._last_param_realloc_step
version = self._last_param_realloc_step
return dict(version=version) return dict(version=version)
@self.app.get("/allocate_rollout") @self.app.post("/allocate_rollout")
async def allocate_rollout(): async def allocate_rollout(req: AllocateRolloutInput):
with self.threading_lock: with self.threading_lock:
async with self.async_lock: has_capacity = (
has_capacity = ( self.rollout_stat.running < self.config.max_concurrent_rollouts
self.n_running_rollouts < self.config.max_concurrent_rollouts )
) is_staled = self.is_staled()
is_staled = await self.is_staled() reason = ""
reason = "" if has_capacity and not is_staled:
if has_capacity and not is_staled: self.rollout_stat.inc()
self.n_running_rollouts += 1 return dict(success=True, reason=reason)
self.n_total_rollouts += 1 else:
return dict(success=True, reason=reason) if not has_capacity:
else: reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}"
if not has_capacity: if is_staled:
reason += f"capacity: {self.n_running_rollouts} >= {self.config.max_concurrent_rollouts}" global_sample_cnt = self.rollout_stat.accepted
if is_staled: expected_version = (
global_sample_cnt = self.n_total_rollouts global_sample_cnt // self.config.train_batch_size
expected_version = ( )
global_sample_cnt // self.config.train_batch_size version = self._last_param_realloc_step
) reason += (
reason += ( f" and staled: expected version ({expected_version}) = "
f" and staled: expected version ({expected_version}) = " f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), " f"current latest version {version}, "
f"current version {self._last_param_realloc_step}, " f"offpolicyness {self.config.max_head_offpolicyness}."
f"offpolicyness {self.config.max_head_offpolicyness}." )
) return dict(success=False, reason=reason)
return dict(success=False, reason=reason)
@self.app.post("/finish_rollout") @self.app.post("/finish_rollout")
async def finish_rollout(resp_meta: GenRespMeta): async def finish_rollout(resp_meta: GenRespMeta):
with self.threading_lock: with self.threading_lock:
async with self.async_lock: server_url = self._qid_to_server_url[resp_meta.qid]
self.n_running_rollouts -= 1 self._server_request_counts[server_url] -= 1
if resp_meta.accepted: assert (
self.accepted_rollouts += 1 self._server_request_counts[server_url] >= 0
return dict(success=True) ), "server request count < 0"
self._qid_to_server_url.pop(resp_meta.qid)
self._gen_tokens += resp_meta.n_tokens
if resp_meta.accepted:
self.rollout_stat.accept()
else:
self.rollout_stat.reject()
return dict(success=True)
-        self.manager_addr = f"{network.gethostip()}:{network.find_free_port()}"
+        port = network.find_free_port(
+            experiment_name=self.experiment_name,
+            trial_name=self.trial_name,
+        )
+        self.manager_addr = f"{network.gethostip()}:{port}"
config = uvicorn.Config( config = uvicorn.Config(
self.app, self.app,
@ -343,12 +453,12 @@ class GserverManager(Worker):
             port=int(self.manager_addr.split(":")[1]),
             log_level="warning",
         )
-        self.server = uvicorn.Server(config)
-        self.server.run()
+        self.manager_http_server = uvicorn.Server(config)
+        self.manager_http_server.run()
     def _exit_hook(self, exit_status):
-        if self.server:
-            self.server.should_exit = True
+        if self.manager_http_server:
+            self.manager_http_server.should_exit = True
         if self.thread:
             self.thread.join(timeout=3)
         logger.info("Server stopped")

View File

@ -145,6 +145,7 @@ class MasterWorker(worker_base.Worker):
# for benchmark # for benchmark
self.e2e_time_history = [] self.e2e_time_history = []
self.__benchmark_steps = config.exp_ctrl.benchmark_steps self.__benchmark_steps = config.exp_ctrl.benchmark_steps
self.__benchmark_n_seqs = config.exp_ctrl.benchmark_n_seqs
return config.worker_info return config.worker_info
@ -210,13 +211,14 @@ class MasterWorker(worker_base.Worker):
src_rpc_dp_size = src_rpc_topo.get_dim("data") src_rpc_dp_size = src_rpc_topo.get_dim("data")
# Request training specification from data workers. # Request training specification from data workers.
-        self._dataset_size = sum(
-            self.__stream.call(
-                handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
-                datas=[None for i in range(src_rpc_dp_size)],
-                handle_type="spec",
-            ),
-        )
+        specs = self.__stream.call(
+            handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
+            datas=[None for i in range(src_rpc_dp_size)],
+            handle_type="spec",
+        )
+        assert all(x["n_datasets"] == specs[0]["n_datasets"] for x in specs), specs
+        self._dataset_size = sum(x["dataset_size"] for x in specs)
+        self._n_datasets = specs[0]["n_datasets"]
self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
@ -239,7 +241,7 @@ class MasterWorker(worker_base.Worker):
src_rpc_dp_size = src_rpc_topo.get_dim("data") src_rpc_dp_size = src_rpc_topo.get_dim("data")
src_rpc_pp_size = src_rpc_topo.get_dim("pipe") src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
for i in range(src_rpc_dp_size): for i in range(src_rpc_dp_size):
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0) rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, tensor=0)
handler_routing[f"__data{i}__"] = self.config.msid2mwid[ handler_routing[f"__data{i}__"] = self.config.msid2mwid[
config_pkg.ModelShardID.from_parallelism_rank( config_pkg.ModelShardID.from_parallelism_rank(
model_name=src_rpc.model_name, model_name=src_rpc.model_name,
@ -263,10 +265,13 @@ class MasterWorker(worker_base.Worker):
self.initialize_models() self.initialize_models()
-        self.__seqbuffer = AsyncIOSequenceBuffer(
-            self.__model_rpcs,
-            max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
-        )
+        self.__seqbuffers = [
+            AsyncIOSequenceBuffer(
+                self.__model_rpcs,
+                max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
+            )
+            for _ in range(self._n_datasets)
+        ]
# wandb init, connect to remote wandb host # wandb init, connect to remote wandb host
wandb.login() wandb.login()
@ -300,7 +305,7 @@ class MasterWorker(worker_base.Worker):
rpcs=self.__model_rpcs, rpcs=self.__model_rpcs,
msid2mwid=self.config.msid2mwid, msid2mwid=self.config.msid2mwid,
stream=self.__stream, stream=self.__stream,
buffer=self.__seqbuffer, buffers=self.__seqbuffers,
model_topos=self.__model_topos, model_topos=self.__model_topos,
model_configs=self.__model_configs, model_configs=self.__model_configs,
ctrl=self.__rpc_ctrl, ctrl=self.__rpc_ctrl,
@ -395,20 +400,33 @@ class MasterWorker(worker_base.Worker):
# Pause the worker if experiment or system-wise benchmark completes. # Pause the worker if experiment or system-wise benchmark completes.
if ( if (
self.__benchmark_steps is not None (
and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps self.__benchmark_steps is not None
) or ( and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs )
>= self.__total_train_epochs * self._dataset_size or (
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
>= self.__total_train_epochs * self._dataset_size
)
or (
self.__benchmark_n_seqs is not None
and self.__rpc_ctrl.step_info.global_step
* self._ft_spec.train_batch_size
>= self.__benchmark_n_seqs
)
): ):
# We don't know whether it is the last step of the current epoch, # We don't know whether it is the last step of the current epoch,
# so we exit at the first step of the next epoch. # so we exit at the first step of the next epoch.
if self.__benchmark_steps is not None: if (
self.__benchmark_steps is not None
or self.__benchmark_n_seqs is not None
):
logger.info( logger.info(
f"Finished benchmark {self.__benchmark_steps}. " f"Finished benchmark {self.__benchmark_steps}. "
f"Time consumption of this setup: {time_since_configure:.3f}" f"Time consumption of this setup: {time_since_configure:.3f}"
) )
logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*") logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
# TODO: inform generation workers to exit
return self.experiment_complete_exit() return self.experiment_complete_exit()
return worker_base.PollResult(sample_count=1, batch_count=1) return worker_base.PollResult(sample_count=1, batch_count=1)
@ -439,9 +457,7 @@ class MasterWorker(worker_base.Worker):
s += f"(global step {global_step}) finishes. " s += f"(global step {global_step}) finishes. "
s += f"#End to end# execution time: *{e2e_time:.3f}*s. " s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
s += f"Total time consumption: {time_since_configure:.3f}s. " s += f"Total time consumption: {time_since_configure:.3f}s. "
-        logging.log_wandb_tensorboard(
-            {"timeperf/e2e": e2e_time}, step=self.__rpc_ctrl.step_info.global_step
-        )
+        logging.log_wandb_tensorboard({"timeperf/e2e": e2e_time})
if len(self.e2e_time_history) > 2: if len(self.e2e_time_history) > 2:
remaining_steps = self._steps_per_epoch - epoch_step remaining_steps = self._steps_per_epoch - epoch_step
remaining_epochs = self.__total_train_epochs - epoch remaining_epochs = self.__total_train_epochs - epoch

View File

@ -63,7 +63,7 @@ class ModelFunctionCall:
model_topos: Dict[str, topology.ProcessTopology], model_topos: Dict[str, topology.ProcessTopology],
model_configs: Dict[str, None | ReaLModelConfig], model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl, ctrl: RPCCorountineControl,
buffer: AsyncIOSequenceBuffer, buffers: List[AsyncIOSequenceBuffer],
redistrib_planner: RedistribPlanner, redistrib_planner: RedistribPlanner,
summary_writer: SummaryWriter | None, summary_writer: SummaryWriter | None,
): ):
@ -89,7 +89,7 @@ class ModelFunctionCall:
) )
self.rpc_ctrl = ctrl self.rpc_ctrl = ctrl
self.buffer = buffer self.buffers = buffers
self.redistrib_planner = redistrib_planner self.redistrib_planner = redistrib_planner
self.summary_writer = summary_writer self.summary_writer = summary_writer
@ -306,7 +306,7 @@ class ModelFunctionCall:
).partitions ).partitions
return buf_indices, sample, partitions return buf_indices, sample, partitions
async def run_step(self, buf_indices, sample): async def run_step(self, buf_indices, sample, buffer_id: int):
rpc = self.rpc rpc = self.rpc
topo = self.model_topos[rpc.model_name] topo = self.model_topos[rpc.model_name]
ctrl = self.rpc_ctrl ctrl = self.rpc_ctrl
@ -317,7 +317,7 @@ class ModelFunctionCall:
] ]
dp_head_indices = [ dp_head_indices = [
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0) topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, tensor=0)
for i in range(self.dp_size) for i in range(self.dp_size)
] ]
@ -348,12 +348,7 @@ class ModelFunctionCall:
if i not in dests: if i not in dests:
dests[i] = [] dests[i] = []
-        # NOTE: The data loaded from the dataset may be unevenly distributed across DP ranks.
-        # Only bcast works in this case.
-        if rpc.is_src:
-            pattern = "bcast"
-        else:
-            pattern = "gather-scatter"
+        pattern = "gather-scatter"
data_transfer_plan = self.redistrib_planner.derive_plan( data_transfer_plan = self.redistrib_planner.derive_plan(
dests, dests,
keys=rpc.input_keys, keys=rpc.input_keys,
@ -362,14 +357,14 @@ class ModelFunctionCall:
blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.") blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.")
# Update storage tracker for transferred data. # Update storage tracker for transferred data.
-        if rpc.is_src:
+        if pattern == "bcast":
             # NOTE: since the data we loaded may be unevenly distributed across DP ranks,
             # we should change the owner of the data to the src RPC.
             for i in range(topo.world_size()):
                 h = ModelShardID.from_parallelism_rank(
                     model_name=rpc.model_name, topo=topo, parallelism_rank=i
                 )
-                is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
+                is_dp_head = h.tp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
gpu_id = self.msid2mwid[h] gpu_id = self.msid2mwid[h]
for key in rpc.input_keys: for key in rpc.input_keys:
await self.redistrib_planner.storage_tracker.add_data( await self.redistrib_planner.storage_tracker.add_data(
@ -414,13 +409,13 @@ class ModelFunctionCall:
responses, time_records = list(zip(*[responses[i] for i in dp_head_indices])) responses, time_records = list(zip(*[responses[i] for i in dp_head_indices]))
# If the returned data is a SequenceSample, it is the data returned by # If the returned data is a SequenceSample, it is the data returned by
# model function calls. The data shoulbe be amended into buffer. # model function calls. The data should be amended into buffer.
# Otherwise, it's the train statistics and should be reduced and logged. # Otherwise, it's the train statistics and should be reduced and logged.
if isinstance(responses[-1], data_api.SequenceSample): if isinstance(responses[-1], data_api.SequenceSample):
# Update storage tracker for generated data. # Update storage tracker for generated data.
for dp_rank, x in enumerate(responses): for dp_rank, x in enumerate(responses):
pp_size = topo.get_dim("pipe") pp_size = topo.get_dim("pipe")
ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, model=0) ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, tensor=0)
for rank in ranks: for rank in ranks:
h = config_pkg.ModelShardID.from_parallelism_rank( h = config_pkg.ModelShardID.from_parallelism_rank(
model_name=rpc.model_name, topo=topo, parallelism_rank=rank model_name=rpc.model_name, topo=topo, parallelism_rank=rank
@ -434,8 +429,14 @@ class ModelFunctionCall:
is_owner=True, is_owner=True,
) )
             res = data_api.SequenceSample.gather(responses)
-        else:
+        elif isinstance(responses[0], dict):
             res = data_api.gather_stat(responses)
+        else:
+            assert isinstance(responses[0], list)
+            res = [
+                data_api.gather_stat([r[i] for r in responses])
+                for i in range(len(responses[0]))
+            ]
if rpc.log_return_value: if rpc.log_return_value:
if isinstance(res, dict): if isinstance(res, dict):
@ -447,6 +448,17 @@ class ModelFunctionCall:
step=ctrl.step_info.global_step, step=ctrl.step_info.global_step,
summary_writer=self.summary_writer, summary_writer=self.summary_writer,
) )
elif isinstance(res, list):
for j, r in enumerate(res):
logger.info(
f"RPC name {rpc.name} returns ({j}/{len(res)})\n{data_api.tabulate_stats(r)}"
)
offset = len(res) * ctrl.step_info.global_step
logging.log_wandb_tensorboard(
r,
step=offset + j,
summary_writer=self.summary_writer,
)
else: else:
logger.info(f"RPC name {rpc.name} returns\n{res}") logger.info(f"RPC name {rpc.name} returns\n{res}")
@ -456,7 +468,6 @@ class ModelFunctionCall:
time_stats = stats_tracker.export() time_stats = stats_tracker.export()
logging.log_wandb_tensorboard( logging.log_wandb_tensorboard(
time_stats, time_stats,
step=ctrl.step_info.global_step,
summary_writer=self.summary_writer, summary_writer=self.summary_writer,
) )
@ -475,7 +486,7 @@ class ModelFunctionCall:
await ctrl.train_count.put(1) await ctrl.train_count.put(1)
else: else:
logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}") logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
await self.buffer.amend_batch(buf_indices, res.unpack()) await self.buffers[buffer_id].amend_batch(buf_indices, res.unpack())
# Wait for all side-effect requests to finish. # Wait for all side-effect requests to finish.
# Side-effect or empty requests are required for data transfer # Side-effect or empty requests are required for data transfer
@ -483,20 +494,20 @@ class ModelFunctionCall:
        # Wait for them after the main request to log the correct MFC time.
await self.stream.gather_async(other_req_ids) await self.stream.gather_async(other_req_ids)
-    async def run(self):
+    async def run(self, buffer_id: int):
         rpc = self.rpc
         topo = self.model_topos[rpc.model_name]
         logger.info(
             f"Running Model RPC, interface_type=#{rpc.interface_type}# "
-            f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*"
+            f"(dp,tp,pp) = *({topo.get_dim('data')},{topo.get_dim('tensor')},{topo.get_dim('pipe')})*"
         )
         consumed = 0
         while True:
-            buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc)
-            await self.run_step(buf_indices, sample)
+            buf_indices, sample = await self.buffers[buffer_id].get_batch_for_rpc(rpc)
+            await self.run_step(buf_indices, sample, buffer_id)
consumed += sample.bs consumed += sample.bs
# Ensure that parent RPCs will not be over-consumed. # Ensure that parent RPCs will not be over-consumed.

View File

@ -153,7 +153,7 @@ class ModelWorker(worker_base.Worker):
] ]
for s in self.config.shards: for s in self.config.shards:
_pp_size = s.id.topo.get_dim("pipe") _pp_size = s.id.topo.get_dim("pipe")
if not (s.id.mp_rank == 0 and s.id.pp_rank == _pp_size - 1): if not (s.id.tp_rank == 0 and s.id.pp_rank == _pp_size - 1):
continue continue
if src_rpc.model_name == s.id.model_name: if src_rpc.model_name == s.id.model_name:
self.__has_dataset = True self.__has_dataset = True
@ -195,8 +195,8 @@ class ModelWorker(worker_base.Worker):
return None return None
@property @property
def _mp_rank(self) -> int: def _tp_rank(self) -> int:
return constants.model_parallel_rank() return constants.tensor_parallel_rank()
@property @property
def _pp_rank(self) -> int: def _pp_rank(self) -> int:
@ -211,8 +211,8 @@ class ModelWorker(worker_base.Worker):
return constants.pipe_parallel_world_size() return constants.pipe_parallel_world_size()
@property @property
def _mp_size(self) -> int: def _tp_size(self) -> int:
return constants.model_parallel_world_size() return constants.tensor_parallel_world_size()
@property @property
def _dp_size(self) -> int: def _dp_size(self) -> int:
@ -220,7 +220,7 @@ class ModelWorker(worker_base.Worker):
@property @property
def _is_dp_head(self) -> bool: def _is_dp_head(self) -> bool:
return self._mp_rank == 0 and self._pp_rank == self._pp_size - 1 return self._tp_rank == 0 and self._pp_rank == self._pp_size - 1
@property @property
def _model(self) -> model_api.Model: def _model(self) -> model_api.Model:
@ -302,6 +302,7 @@ class ModelWorker(worker_base.Worker):
constants.set_grid(model_name_, grid) constants.set_grid(model_name_, grid)
# Set up training dataset for source RPCs. # Set up training dataset for source RPCs.
self.__datasets = []
if self.__has_dataset: if self.__has_dataset:
datasets = [ datasets = [
data_api.make_dataset( data_api.make_dataset(
@ -321,31 +322,34 @@ class ModelWorker(worker_base.Worker):
) )
for d in self.config.datasets for d in self.config.datasets
] ]
if len(self.config.datasets) == 1: self.__datasets = datasets
self.__dataset = datasets[0]
else:
self.__dataset = torch.utils.data.ConcatDataset(datasets)
g = torch.Generator() self.__dataloaders: List[
g.manual_seed(seeding.get_seed()) torch.utils.data.DataLoader[data_api.SequenceSample]
dataloader_kwargs = dict( ] = []
shuffle=self.config.shuffle_dataset, for i, d in enumerate(self.__datasets):
generator=g, g = torch.Generator()
) g.manual_seed(
if not isinstance(self.__dataset, PullerStreamDataset): self.config.base_seed + seeding._seed_from_key(f"__dataloader{i}__")
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather )
# NOTE: This is *NOT* the actual batch size for training. dataloader_kwargs = dict(
# It is just a proper size to load data to workers. shuffle=self.config.shuffle_dataset,
dataloader_kwargs["batch_size"] = 10240 generator=g,
else: )
dataloader_kwargs["batch_size"] = None if not isinstance(d, PullerStreamDataset):
self.__dataloader = torch.utils.data.DataLoader( dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
self.__dataset, **dataloader_kwargs # NOTE: This is *NOT* the actual batch size for training.
) # It is just a proper size to load data to workers.
dataloader_kwargs["batch_size"] = 10240
else:
dataloader_kwargs["batch_size"] = None
self.__dataloaders.append(
torch.utils.data.DataLoader(d, **dataloader_kwargs)
)
self.dataset_size = len(self.__dataset) self.dataset_size = sum(len(d) for d in self.__datasets)
self.__data_generator = enumerate(self.__dataloader) self.__data_generators = [enumerate(d) for d in self.__dataloaders]
self.__models: Dict[ModelName, model_api.Model] = dict() self.__models: Dict[ModelName, model_api.Model] = dict()
self.__model_is_handle: Dict[ModelName, bool] = dict() self.__model_is_handle: Dict[ModelName, bool] = dict()
@ -377,25 +381,26 @@ class ModelWorker(worker_base.Worker):
) )
# Recover indices for dynamic dataset # Recover indices for dynamic dataset
if ( for i, d in enumerate(self.__datasets):
s.id.model_name == self.src_rpc.model_name if (
and self.__has_dataset s.id.model_name == self.src_rpc.model_name
and hasattr(self.__dataset, "filter") and self.__has_dataset
): and hasattr(d, "filter")
dataset_indices_path = os.path.join( ):
constants.MODEL_SAVE_ROOT, dataset_indices_path = os.path.join(
constants.experiment_name(), constants.MODEL_SAVE_ROOT,
constants.trial_name(), constants.experiment_name(),
"dataset_indices", constants.trial_name(),
f"{self._dp_rank}.npy", "dataset_indices",
) f"{self._dp_rank}_{i}.npy",
if os.path.exists(dataset_indices_path):
indices = np.load(dataset_indices_path).tolist()
logger.info(
f"DP rank {self._dp_rank} updating dataset indices upon recover, "
f"size {len(self.__dataset.active_indices)} -> {len(indices)}"
) )
self.__dataset.active_indices = indices if os.path.exists(dataset_indices_path):
indices = np.load(dataset_indices_path).tolist()
logger.info(
f"DP rank {self._dp_rank} updating dataset indices upon recover, "
f"size {len(d.active_indices)} -> {len(indices)}"
)
d.active_indices = indices
if constants.parallelism_rank() == 0: if constants.parallelism_rank() == 0:
self.logger.info( self.logger.info(
@ -537,9 +542,13 @@ class ModelWorker(worker_base.Worker):
cache = [] cache = []
while True: while True:
try: try:
request, data, handled, res, time_record = ( (
self.__request_queue.get_nowait() request,
) data,
handled,
res,
time_record,
) = self.__request_queue.get_nowait()
request: request_reply_stream.Payload request: request_reply_stream.Payload
if not handled: if not handled:
while len(request.pre_hooks) > 0: while len(request.pre_hooks) > 0:
@ -582,9 +591,13 @@ class ModelWorker(worker_base.Worker):
elif request.handle_name == "fetch": elif request.handle_name == "fetch":
dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1)) dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1))
assert self.__has_dataset assert self.__has_dataset
assert isinstance(request.data, int), request.data
dataset_id = request.data
# Fetch. # Fetch.
try: try:
self.__dataset_batch_counter, cur_sample = next(self.__data_generator) self.__dataset_batch_counter, cur_sample = next(
self.__data_generators[dataset_id]
)
except StopIteration: except StopIteration:
# Upon the first fetch request, filter dataset and create dataloader. # Upon the first fetch request, filter dataset and create dataloader.
eval_scores_path = os.path.join( eval_scores_path = os.path.join(
@ -598,39 +611,43 @@ class ModelWorker(worker_base.Worker):
constants.experiment_name(), constants.experiment_name(),
constants.trial_name(), constants.trial_name(),
"dataset_indices", "dataset_indices",
f"{dp_rank}.npy", f"{dp_rank}_{dataset_id}.npy",
) )
os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True) os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True)
if hasattr(self.__dataset, "filter") and os.path.exists( if hasattr(self.__datasets[dataset_id], "filter") and os.path.exists(
eval_scores_path eval_scores_path
): ):
# Don't filter dataset on the first poll after recover. # Don't filter dataset on the first poll after recover.
with open(eval_scores_path, "r", encoding="utf-8") as f: with open(eval_scores_path, "r", encoding="utf-8") as f:
dataset_eval_scores = json.load(f) dataset_eval_scores = json.load(f)
self.__dataset.filter(dataset_eval_scores) self.__datasets[dataset_id].filter(dataset_eval_scores)
# Save the dataset indices after filtering # Save the dataset indices after filtering
np.save( np.save(
dataset_indices_path, dataset_indices_path,
self.__dataset.active_indices, self.__datasets[dataset_id].active_indices,
) )
g = torch.Generator() g = torch.Generator()
g = g.set_state(self.__dataloader.generator.get_state()) g = g.set_state(self.__dataloaders[dataset_id].generator.get_state())
dataloader_kwargs = dict( dataloader_kwargs = dict(
shuffle=self.config.shuffle_dataset, shuffle=self.config.shuffle_dataset,
generator=g, generator=g,
) )
if not isinstance(self.__dataset, PullerStreamDataset): if not isinstance(self.__datasets[dataset_id], PullerStreamDataset):
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
# NOTE: This is *NOT* the actual batch size for training. # NOTE: This is *NOT* the actual batch size for training.
# It is just a proper size to load data to workers. # It is just a proper size to load data to workers.
dataloader_kwargs["batch_size"] = 10240 dataloader_kwargs["batch_size"] = 10240
else: else:
dataloader_kwargs["batch_size"] = None dataloader_kwargs["batch_size"] = None
self.__dataloader = torch.utils.data.DataLoader( self.__dataloaders[dataset_id] = torch.utils.data.DataLoader(
self.__dataset, **dataloader_kwargs self.__datasets[dataset_id], **dataloader_kwargs
)
self.__data_generators[dataset_id] = enumerate(
self.__dataloaders[dataset_id]
)
self.__dataset_batch_counter, cur_sample = next(
self.__data_generators[dataset_id]
) )
self.__data_generator = enumerate(self.__dataloader)
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
if isinstance(cur_sample, data_api.SequenceSample): if isinstance(cur_sample, data_api.SequenceSample):
samples = cur_sample.unpack() samples = cur_sample.unpack()
@ -663,7 +680,10 @@ class ModelWorker(worker_base.Worker):
) )
elif request.handle_name == "spec": elif request.handle_name == "spec":
# Raw dataset without filtering. # Raw dataset without filtering.
-            res = self.dataset_size
+            res = {
+                "n_datasets": len(self.__datasets),
+                "dataset_size": self.dataset_size,
+            }
elif request.handle_name == "clear_data_cache": elif request.handle_name == "clear_data_cache":
with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc): with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc):
ids = request.data ids = request.data
@ -772,8 +792,10 @@ class ModelWorker(worker_base.Worker):
if hook == "evaluate": if hook == "evaluate":
assert request.handle_name == "train_step", request.handle_name assert request.handle_name == "train_step", request.handle_name
                 assert isinstance(ret, dict), ret
-                assert isinstance(res, dict), res
-                res.update(ret)
+                if isinstance(res, dict):
+                    res.update(ret)
+                else:
+                    res[0].update(ret)
time_record[ time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/post-{hook}" f"timeperf/{request.handler.model_name.role}_{request.handle_name}/post-{hook}"
] += (time.perf_counter() - tik) ] += (time.perf_counter() - tik)
@ -803,13 +825,7 @@ class ModelWorker(worker_base.Worker):
with constants.model_scope(model_name): with constants.model_scope(model_name):
dist.barrier(group=constants.cpu_parallelism_group()) dist.barrier(group=constants.cpu_parallelism_group())
if constants.parallelism_rank() == 0: if constants.parallelism_rank() == 0:
name_resolve.add( name_resolve.add(name, str(global_step), replace=True)
name,
str(global_step),
delete_on_exit=False,
keepalive_ttl=30,
replace=True,
)
time_record[ time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/param-sync-save" f"timeperf/{request.handler.model_name.role}_{request.handle_name}/param-sync-save"
] += (time.perf_counter() - tik) ] += (time.perf_counter() - tik)
@ -867,7 +883,7 @@ class ModelWorker(worker_base.Worker):
if len(self.__performance_recorder) == 0: if len(self.__performance_recorder) == 0:
self.__performance_recorder["info"] = { self.__performance_recorder["info"] = {
"pipeline_size": self._pp_size, "pipeline_size": self._pp_size,
"model_size": self._mp_size, "model_size": self._tp_size,
"data_size": self._dp_size, "data_size": self._dp_size,
"rank": constants.parallelism_rank(), "rank": constants.parallelism_rank(),
"sequence_parallel_enabled": constants.sequence_parallel(), "sequence_parallel_enabled": constants.sequence_parallel(),
@ -1374,9 +1390,13 @@ class ModelWorker(worker_base.Worker):
rescheduled_requests = [] rescheduled_requests = []
other_requests = [] other_requests = []
for _ in range(self.__request_queue.qsize()): for _ in range(self.__request_queue.qsize()):
request, data, handled, res, time_record = ( (
self.__request_queue.get_nowait() request,
) data,
handled,
res,
time_record,
) = self.__request_queue.get_nowait()
if request.handle_name not in ["inference", "generate", "train_step"]: if request.handle_name not in ["inference", "generate", "train_step"]:
other_requests.append((request, data, handled, res, time_record)) other_requests.append((request, data, handled, res, time_record))
else: else:
@ -1399,9 +1419,13 @@ class ModelWorker(worker_base.Worker):
# we can correctly log the time consumption in the master worker. # we can correctly log the time consumption in the master worker.
while True: while True:
try: try:
request, data, handled, res, time_record = ( (
self.__request_queue.get_nowait() request,
) data,
handled,
res,
time_record,
) = self.__request_queue.get_nowait()
self.handle_blocking_request( self.handle_blocking_request(
request, data, handled, res, time_record request, data, handled, res, time_record
) )
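Not from the repository: a small standalone sketch of the drain-and-requeue pattern that the two hunks above only reformat. Everything currently in the queue is popped with get_nowait(), compute requests (inference, generate, train_step) are split out, and the remaining requests are put back unchanged.

import queue

# Toy request tuples: (handle_name, payload). The real entries carry
# (request, data, handled, res, time_record).
request_queue = queue.Queue()
for item in [("generate", 1), ("save", 2), ("train_step", 3)]:
    request_queue.put_nowait(item)

compute_requests, other_requests = [], []
for _ in range(request_queue.qsize()):
    handle_name, payload = request_queue.get_nowait()
    if handle_name not in ["inference", "generate", "train_step"]:
        other_requests.append((handle_name, payload))
    else:
        compute_requests.append((handle_name, payload))

# Non-compute requests go back untouched; compute requests get (re)scheduled.
for item in other_requests:
    request_queue.put_nowait(item)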


@@ -103,9 +103,14 @@ class PartialRolloutManager:
     ):
         from realhf.impl.model.backend.sglang import SGLangAPIClient

+        max_new_tokens = min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk)
+        max_new_tokens = min(
+            max_new_tokens,
+            raw_gconfig.max_new_tokens - len(input_ids) + len(prompt_ids),
+        )
         gconfig = raw_gconfig.new(
             n=1,
-            max_new_tokens=min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk),
+            max_new_tokens=max_new_tokens,
         )
         assert self.tokenizer.pad_token_id is not None
         assert self.tokenizer.eos_token_id is not None
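A worked toy example (not in the repo) of the per-chunk budget computed by the added lines: the next chunk is capped both by the chunk size and by the portion of max_new_tokens that is still unused, where len(input_ids) - len(prompt_ids) counts the tokens generated in earlier chunks.

def next_chunk_budget(max_new_tokens, new_tokens_per_chunk, prompt_ids, input_ids):
    # Tokens already generated in previous chunks.
    generated_so_far = len(input_ids) - len(prompt_ids)
    budget = min(max_new_tokens, new_tokens_per_chunk)
    return min(budget, max_new_tokens - generated_so_far)


prompt_ids = list(range(10))
input_ids = prompt_ids + [0] * 120                                # 120 tokens already generated
assert next_chunk_budget(128, 64, prompt_ids, input_ids) == 8     # only 8 left in the budget
assert next_chunk_budget(128, 64, prompt_ids, prompt_ids) == 64   # fresh request: full chunk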
@@ -130,6 +135,7 @@ class PartialRolloutManager:
                     group_idx=group_idx,
                     raw_gconfig=raw_gconfig,
                     server_url=url,
+                    version=cur_server_version,
                 ),
             ),
             stream=False,
@@ -190,6 +196,7 @@ class PartialRolloutManager:
             s: APIGenerateOutput = await task
             group_idx = s.metadata["group_idx"]
             raw_gconfig = s.metadata["raw_gconfig"]
+            previous_version = s.metadata["version"]

             assert s.group_size == 1
             no_eos = s.no_eos[0]
@@ -202,20 +209,27 @@ class PartialRolloutManager:
             if no_eos and gen_len < raw_gconfig.max_new_tokens:
                 # Unfinished request due to chunked generation.
                 # Send it back to continue.
-                async with aiohttp.ClientSession() as session:
-                    async with session.post(
-                        f"http://{self.gserver_manager_addr}/get_model_version",
-                        json=dict(server_url=s.metadata["server_url"]),
-                        timeout=ClientTimeout(total=self.timeout, sock_connect=30),
-                    ) as resp:
-                        resp.raise_for_status()
-                        cur_version = (await resp.json())["version"]
+                req_meta = GenReqMeta(
+                    qid=s.qid,
+                    prompt_len=s.prompt_len,
+                    group_size=raw_gconfig.n,
+                    new_token_budget=raw_gconfig.max_new_tokens,
+                    predicted_new_tokens=None,
+                    previous_server_url=s.metadata["server_url"],
+                    previous_version=previous_version,
+                )
+                info = await self._schedule_request(req_meta)
+                cur_version = info["version"]
+                server_url = info["url"]
                 if len(s.output_logprobs) > 0:
                     prev_logprobs = s.prev_logprobs + s.output_logprobs[0]
                 else:
-                    prev_logprobs = []
+                    prev_logprobs = s.prev_logprobs
+                    if prev_logprobs is None:
+                        prev_logprobs = []
+
                 await self._issue_generation(
-                    s.metadata["server_url"],
+                    server_url,
                     s.qid,
                     group_idx,
                     s.prompt_ids,
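Not the library's actual definitions: the stub below is inferred only from the fields used in this hunk, and sketches how an unfinished chunked request is rescheduled through the generation-server manager. GenReqMetaStub, schedule_request_stub, and the returned dict shape are assumptions standing in for GenReqMeta and PartialRolloutManager._schedule_request.

import asyncio
from dataclasses import dataclass
from typing import Optional


@dataclass
class GenReqMetaStub:
    qid: str
    prompt_len: int
    group_size: int
    new_token_budget: int
    predicted_new_tokens: Optional[int] = None
    previous_server_url: Optional[str] = None
    previous_version: Optional[int] = None


async def schedule_request_stub(req_meta: GenReqMetaStub) -> dict:
    # A real scheduler would pick a server (e.g. least-request scheduling) and
    # report the weight version that server currently serves.
    return {
        "url": req_meta.previous_server_url or "http://localhost:30000",
        "version": 7,
    }


async def resubmit_unfinished(qid, prompt_len, group_size, budget, prev_url, prev_version):
    req_meta = GenReqMetaStub(
        qid=qid,
        prompt_len=prompt_len,
        group_size=group_size,
        new_token_budget=budget,
        predicted_new_tokens=None,
        previous_server_url=prev_url,
        previous_version=prev_version,
    )
    info = await schedule_request_stub(req_meta)
    # The caller continues generation on info["url"] and can compare
    # info["version"] with prev_version to detect a weight update.
    return info["url"], info["version"]


url, version = asyncio.run(
    resubmit_unfinished("qid-0", 32, 8, 1024, "http://localhost:30001", 6)
)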
@@ -240,9 +254,10 @@ class PartialRolloutManager:
             try:
                 qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
                 req_meta = GenReqMeta(
+                    qid=qid,
                     prompt_len=len(prompt_token_ids),
                     group_size=gconfig.n,
-                    new_token_budget=self.new_tokens_per_chunk,
+                    new_token_budget=gconfig.max_new_tokens,
                     predicted_new_tokens=None,
                 )
                 dst_server_info = await self._schedule_request(req_meta)


@@ -171,7 +171,9 @@ class NameResolvingZmqPuller(ZMQJsonPuller):
         name = names.push_pull_stream(
             experiment_name, trial_name, stream_name=f"puller{puller_index}"
         )
-        host, port = network.gethostip(), network.find_free_port()
+        host, port = network.gethostip(), network.find_free_port(
+            experiment_name=experiment_name, trial_name=trial_name
+        )
         addr = f"{host}:{port}"
         name_resolve.add(name, addr)
         super().__init__(host, port, **kwargs)


@@ -189,10 +189,11 @@ class RolloutWorker(AsyncWorker):
         assert data_id not in self.rollout_tasks
         return cur_sample

-    async def allocate_new_rollout(self) -> bool:
+    async def allocate_new_rollout(self, qid) -> bool:
         async with aiohttp.ClientSession() as session:
-            async with session.get(
+            async with session.post(
                 f"http://{self.gserver_manager_addr}/allocate_rollout",
+                json=dict(qid=qid),
                 timeout=ClientTimeout(
                     total=self.config.rollout_request_timeout, sock_connect=30
                 ),
@@ -231,10 +232,10 @@ class RolloutWorker(AsyncWorker):
             self._cur_data = self.load_next_data()
         if self._cur_data is not None:
-            can_rollout = await self.allocate_new_rollout()
+            data = self._cur_data
+            qid = data.ids[0]
+            can_rollout = await self.allocate_new_rollout(qid)
             if can_rollout:
-                data = self._cur_data
-                qid = data.ids[0]
                 self.act_queues[qid] = asyncio.Queue(1024)
                 task = asyncio.create_task(self.rollout_task(qid, data))
@@ -265,7 +266,11 @@ class RolloutWorker(AsyncWorker):
                     accepted = True
                     self.push_stream.push([traj.as_json_compatible() for traj in trajs])

-                info = dict(qid=qid, accepted=accepted)
+                n_tokens = 0
+                for traj in trajs:
+                    seqlens = [sum(datapack.flat2d(ss)) for ss in traj.seqlens.values()]
+                    n_tokens += max(seqlens)
+                info = dict(qid=qid, accepted=accepted, n_tokens=n_tokens)
                 async with aiohttp.ClientSession(
                     f"http://{self.gserver_manager_addr}"
                 ) as session:
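A toy illustration (not repository code) of the token accounting added above. flat2d here is a stand-in for realhf's datapack.flat2d, assumed to flatten one level of nesting, and the nested dict mimics a trajectory's per-key seqlens; the largest per-key total is taken as the trajectory's token count.

from itertools import chain


def flat2d(list_of_lists):
    # Flatten one level of nesting, e.g. [[300], [250]] -> [300, 250].
    return list(chain.from_iterable(list_of_lists))


# seqlens: key -> list (one entry per sequence) of lists of part lengths.
toy_traj_seqlens = {
    "packed_input_ids": [[300], [250]],
    "rewards": [[1], [1]],
}

n_tokens = 0
for traj_seqlens in [toy_traj_seqlens]:
    seqlens = [sum(flat2d(ss)) for ss in traj_seqlens.values()]
    # The largest key (typically the packed token ids) determines the count.
    n_tokens += max(seqlens)

assert n_tokens == 550
info = dict(qid="qid-0", accepted=True, n_tokens=n_tokens)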


@@ -36,7 +36,6 @@ ray
 redis
 scipy
 seaborn
-setuptools>=61.0
 tqdm
 networkx==3.3
 matplotlib
@@ -59,3 +58,5 @@ protobuf<3.21
 rich
 orjson>=3.10.16
 flask
+setuptools>=62.3.0,<75.9
+func_timeout


@@ -246,30 +246,6 @@ if not no_ext and _is_cuda():
     ext_modules.append(interval_op_cuda)

 if not no_ext:
-    search_extension = setuptools.Extension(
-        name="realhf._C.mdm_search",
-        sources=[
-            "csrc/search/search.cpp",
-            "csrc/search/rpc.cpp",
-            "csrc/search/device_mesh.cpp",
-            "csrc/search/simulate.cpp",
-        ],
-        language="c++",
-        extra_compile_args=[
-            "-O3",
-            "-Wall",
-            "-shared",
-            "-std=c++11",
-            "-fPIC",
-            "-std=c++17",
-        ],
-        include_dirs=[
-            os.path.join(os.path.abspath(os.path.dirname(__file__)), "csrc", "search"),
-            get_pybind11_include_path(),
-        ],
-    )
-    ext_modules.append(search_extension)
-
     interval_extension = setuptools.Extension(
         name="realhf._C.interval_op",
         sources=[
