Support asynchronous RL training, Qwen3, and the latest SGLang (#47)

* feat: one buffer for each task

* feat: support "one buffer for each task" for async

* make kv_cache_dtype configurable

Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>

* style: use plural form

fix: use _seed_from_key to set different seeds for data loaders
fix: call load_data for one buffer each time
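
A minimal sketch of the per-loader seeding idea (the actual `_seed_from_key` helper may differ; names here are illustrative):

import hashlib

def _seed_from_key(base_seed: int, key: str) -> int:
    # Derive a deterministic, loader-specific seed from the experiment seed and a
    # string key (e.g. the task/buffer name), so each data loader shuffles
    # differently but reproducibly.
    digest = hashlib.sha256(f"{base_seed}/{key}".encode()).digest()
    return int.from_bytes(digest[:8], "little") % (2**31)

# Hypothetical usage: one loader per task buffer.
# loaders = {name: make_loader(ds, seed=_seed_from_key(cfg.seed, name))
#            for name, ds in datasets.items()}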

* PullRequest: 125 Support running async experiments in the 2407 image.

Merge branch fw/async2407 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/125

Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* fix: handle multiple datasets in recover indices
fix: `isinstance(self.__datasets, PullerStreamDataset)`
feat: use the "spec" request to obtain the number of datasets
fix: revert rollout worker

* fix: revert async_rl_exp.py

* fix flag for list (cuda_graph_bs)

* format

* [FIX] fix async task reward [sglang bf16 -> fp16]

* fix: define `self.__datasets` in advance

* PullRequest: 130 [Refactor] Remove deprecated search related code

Merge branch mzy/remove-search of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/130

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* remove search-related code

* PullRequest: 131 [Refactor] Change terminology "model parallel" into "tensor parallel" to align with megatron.

Merge branch mzy/mp-to-tp of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/131?tab=comment

Signed-off-by: 博惟 <bowei.fw@antgroup.com>


* change mp to tp
* .
* .

* PullRequest: 142 Fix an error for megatron backend destroy

Merge branch fw/fix-meagatron-destroy of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/142

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .

* PullRequest: 143 Fix the port conflict issue of generation servers

Merge branch fw/fix-gen-port of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/143?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* somehow fix the port issue
* add clearance period
* .
* .

* PullRequest: 145 Add code environment

Merge branch fw/code-env of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/145?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* add code env
* somehow fix the port issue
* fix

* PullRequest: 144 Add decoupled PPO loss

Merge branch fw/decoupled-ppo-loss of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/144?tab=comment

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix ppo step logging, nan in stats tracker, and add decoupled loss (see the sketch below)
* .
* somehow fix the port issue
* fix typo
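
A minimal sketch of a decoupled PPO objective of the kind added above (assumed form, not the exact implementation): the clipped ratio is taken against a recent "proximal" policy, while a detached importance weight corrects for the staler behavior policy that generated the rollout.

import torch

def decoupled_ppo_loss(logp, prox_logp, behav_logp, adv, eps=0.2):
    # pi_theta / pi_prox drives the usual clipped surrogate ...
    ratio = torch.exp(logp - prox_logp)
    # ... while pi_prox / pi_behav reweights stale off-policy samples.
    behav_weight = torch.exp(prox_logp - behav_logp).detach()
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps)
    pg = torch.min(ratio * adv, clipped * adv)
    return -(behav_weight * pg).mean()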

* PullRequest: 146 Merge SLURM logs and save experiment configs in yaml format.

Merge branch fw/better-logging of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/146

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* merge all slurm logs into one
* write config to yaml

* PullRequest: 141 Merge changes during NeurIPS submission

Merge branch fw/async-dev of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/141

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* .
* .
* .
* .
* .
* .
* .
* .
* .
* update script
* .
* .
* .
* .
* [ADD] add least-request scheduling (see the sketch after this list)
* fix test genreq
* .
* .
* fix stats tracker nan
* .
* .
* .
* .
* .
* .
* .
* upper clip decoupled objective
* add throughput exp script
* .
* remove behav upper clip param
* .
* .
* .
* plot curve
* update thpt script
* .
* master worker raise error when exiting
* update script
* add gen throughput logging
* .
* .
* add decoupled wandb data
* .
* fix port issue and add no training option
* .
* enlarge ttl
* remove gserver manager await staled
* update weights in groups
* .
* .
* .
* add port clearance period
* .
* .
* .
* add plot script
* add sft throughput eval
* .
* log tokens in null interface
* ablation experiments and interruptible generation
* plotting scripts / run scripts / data results
* .
* remove scripts
* add port test
* remove force_sync_reward
* revert some changes
* .
* revert
* revert fix
* fix
* revert
* fix typo
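
The least-request scheduling added above routes each new rollout request to the generation server with the fewest in-flight requests. A minimal sketch (hypothetical names; a real scheduler would also track completions and server health):

from dataclasses import dataclass

@dataclass
class GenServer:
    url: str
    inflight: int = 0

def pick_server(servers: list[GenServer]) -> GenServer:
    # Least-request policy: choose the server currently serving the fewest requests.
    chosen = min(servers, key=lambda s: s.inflight)
    chosen.inflight += 1  # caller decrements when the request completes
    return chosen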

* support qwen3 training

* PullRequest: 147 Support interruption in SGLang and fix a KeyError in gather-scatter communication

Merge branch fw/sglang046-with-abort-request of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/147?tab=diff

Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>


* fix ppo step logging, nan in stats tracker, and add decoupled loss
* .
* somehow fix the port issue
* initial commit
* add interrupt request (see the sketch below)
* fix data transfer issue
* max concurrent rollouts defaults to train batch size
* merge main
* add patch
* fix patch typo
* revert sglang
* fix typo
* fix minor typo
* .
* pip show editable sglang path
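
A sketch of the interruption flow, with assumed endpoint and payload names (the actual SGLang patch may expose a different API): when fresh weights arrive, in-flight generations are aborted and their prompts re-enqueued so decoding resumes under the new weights.

import requests

def interrupt_rollouts(server_url: str, rids: list[str]) -> None:
    # Ask the generation server to abort the listed request IDs before the
    # weight update; interrupted prompts are re-queued by the rollout worker,
    # reusing the tokens generated so far as a prefix.
    for rid in rids:
        requests.post(f"{server_url}/abort_request", json={"rid": rid}, timeout=5)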

* PullRequest: 149 fix: code faas max_retries

Merge branch xss/fix_code_verifier of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/149

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix: code faas max_retries

* PullRequest: 150 [Bug Fix] Fix key errors in `_run_scatter` in data transfer

Merge branch mzy/fix-scatter-groups of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/150

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* fix scatter groups key error

* fix test

* .

* PullRequest: 151 Fix Qwen3 import error when using transformers with a lower version

Merge branch fw/fix-qwen3 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/151

Reviewed-by: 温差 <xushusheng.xss@antgroup.com>


* merge all slurm logs into one
* write config to yaml
* .
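
The Qwen3 import fix in this PR is presumably a guard of the following kind (assumed pattern; the actual change may differ), so that older transformers releases without Qwen3 still import cleanly:

try:
    from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
except ImportError:  # transformers too old to ship Qwen3
    Qwen3Config = None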

* PullRequest: 152 Support sglang 0.4.6 and fix master_worker import error

Merge branch adopt_sglang046 of git@code.alipay.com:inclusionAI/AReaL.git into main
https://code.alipay.com/inclusionAI/AReaL/pull_requests/152

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* Support sglang 0.4.6 and fix master_worker import error
* remove disable_mla option

---------

Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: wanghuaijie.whj <wanghuaijie.whj@antgroup.com>
Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: kira.gw <kira.gw@antgroup.com>
Co-authored-by: shenxujie.sxj <shenxujie.sxj@antgroup.com>
Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: sam.gjx <sam.gjx@antgroup.com>
Co-authored-by: 温差 <xushusheng.xss@antgroup.com>
Co-authored-by: 履渊 <yuhong.gyh@antgroup.com>
Committed by Wei Fu on 2025-05-26 09:45:13 +08:00 (via GitHub)
parent f7ab31d050 · commit c60d128b14
112 changed files with 2361 additions and 4559 deletions


@@ -1,87 +0,0 @@
#include <device_mesh.hpp>
#include <cassert>
#include <iostream>
// DeviceMesh::DeviceMesh()
// : device_mesh_name(""), n_nodes(0), n_gpus(0), node_names({}), gpu_ids({}) {
// };
DeviceMesh::DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector<std::vector<int>> mapping,
std::string global_mesh_name, std::string name)
: n_nodes(n_nodes),
n_gpus_per_node(n_gpus_per_node),
mapping(mapping),
global_mesh_name(global_mesh_name),
name(name) {
assert(n_nodes == static_cast<int>(mapping.size()));
for (int i = 0; i < n_nodes; i++) {
assert(n_gpus_per_node == static_cast<int>(mapping[i].size()));
}
};
bool is_all_overlap(std::vector<DeviceMesh *> device_meshes, DeviceMesh device_mesh) {
for (DeviceMesh *other : device_meshes) {
if (!device_mesh.overlap(*other)) return false;
}
return true;
};
bool is_all_overlap(std::unordered_set<DeviceMesh *> device_meshes, DeviceMesh device_mesh) {
for (DeviceMesh *other : device_meshes) {
if (!device_mesh.overlap(*other)) return false;
}
return true;
};
bool DeviceMesh::contain(const DeviceMesh &other) {
// check whether one device mapping is contained by another by
// checking 1. whether global_mesh_name is identical
// 2. whether mapping of one device mesh is contained by the other one
if (global_mesh_name != other.global_mesh_name) return false;
for (int i = 0; i < n_nodes; i++) {
for (int j = 0; j < n_gpus_per_node; j++) {
if (mapping[i][j] == 0 && other.mapping[i][j] == 1) return false;
}
}
return true;
};
bool DeviceMesh::contained_by(const DeviceMesh &other) {
if (global_mesh_name != other.global_mesh_name) return false;
for (int i = 0; i < n_nodes; i++) {
for (int j = 0; j < n_gpus_per_node; j++) {
if (mapping[i][j] == 1 && other.mapping[i][j] == 0) return false;
}
}
return true;
};
bool DeviceMesh::overlap(const DeviceMesh &other) {
if (global_mesh_name != other.global_mesh_name) return false;
for (int i = 0; i < n_nodes; i++) {
for (int j = 0; j < n_gpus_per_node; j++) {
if (mapping[i][j] == 1 && other.mapping[i][j] == 1) return true;
}
}
return false;
};
ModelParallelStrategy::ModelParallelStrategy(int num_pp, int num_dp, int num_mp)
: num_pp(num_pp), num_dp(num_dp), num_mp(num_mp) {};
bool ModelParallelStrategy::operator==(const ModelParallelStrategy &other) const {
return num_pp == other.num_pp && num_dp == other.num_dp && num_mp == other.num_mp;
};
bool DeviceMesh::operator==(const DeviceMesh &other) const {
return name == other.name && global_mesh_name == other.global_mesh_name;
};
std::string ModelParallelStrategy::to_string() {
return "num_pp:" + std::to_string(num_pp) + ";" + "num_dp:" + std::to_string(num_dp) + ";"
+ "num_mp:" + std::to_string(num_mp);
};
std::string ModelParallelStrategy::to_key() {
return std::to_string(num_pp) + "," + std::to_string(num_mp) + "," + std::to_string(num_dp);
}


@@ -1,49 +0,0 @@
#ifndef DEVICE_MESH_HPP
#define DEVICE_MESH_HPP
#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
// #include <rpc.hpp>
class RPCInstance;
class DeviceMesh {
public:
int n_nodes;
int n_gpus_per_node;
std::vector<std::vector<int>> mapping;
std::string global_mesh_name;
std::string name;
RPCInstance *pre_task = nullptr;
// DeviceMesh();
DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector<std::vector<int>> mapping,
std::string global_mesh_name, std::string name);
bool overlap(const DeviceMesh &other);
bool contain(const DeviceMesh &other);
bool contained_by(const DeviceMesh &other);
bool operator==(const DeviceMesh &other) const;
};
bool is_all_overlap(std::vector<DeviceMesh *> device_meshes, DeviceMesh device_mesh);
bool is_all_overlap(std::unordered_set<DeviceMesh *> device_meshes, DeviceMesh device_mesh);
class ModelParallelStrategy {
public:
int num_pp, num_dp, num_mp;
ModelParallelStrategy(int num_pp, int num_dp, int num_mp);
bool operator==(const ModelParallelStrategy &other) const;
std::string to_string();
std::string to_key();
};
class ModelDeviceMapping {};
#endif // DEVICE_MESH_HPP


@@ -1,233 +0,0 @@
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <numeric>
#include <algorithm>
#include <iostream>
#include <iomanip>
RPC::RPC(std::string model_name, std::string rpc_name, std::string interface_type)
: model_name(model_name), rpc_name(rpc_name), interface_type(interface_type) {};
RPCExecution::RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh,
ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost,
uint64_t mem, uint64_t static_mem)
: rpc_ptr(rpc_ptr),
device_mesh(device_mesh),
model_parallel_strategy(model_parallel_strategy),
time_cost(time_cost),
mem(mem),
static_mem(static_mem) {};
bool OverlapGroup::maybe_add(RPCExecution *rpc_exe) {
if (rpc_executions.empty()) {
rpc_executions.insert(rpc_exe);
device_meshes.insert(&rpc_exe->device_mesh);
mem_static = rpc_exe->static_mem;
mem_active = rpc_exe->mem - rpc_exe->static_mem;
return true;
}
if (is_all_overlap(device_meshes, rpc_exe->device_mesh)) {
rpc_executions.insert(rpc_exe);
// bool dm_in_group = device_meshes.find(&rpc_exe -> device_mesh) != device_meshes.end();
device_meshes.insert(&rpc_exe->device_mesh);
mem_static += rpc_exe->static_mem;
mem_active = std::max(mem_active, rpc_exe->mem - rpc_exe->static_mem);
return true;
}
return false;
};
void DeviceMeshGroup::add_to_groups(RPCExecution *rpc_exe) {
if (overlap_groups.empty()) {
OverlapGroup *og = new OverlapGroup();
og->maybe_add(rpc_exe);
overlap_groups.push_back(og);
return;
}
std::vector<OverlapGroup *> tmp_new_ogs;
for (OverlapGroup *og : overlap_groups) {
// OverlapGroup og_copy = *og;
bool update = og->maybe_add(rpc_exe);
if (!update) {
tmp_new_ogs.push_back(new OverlapGroup());
tmp_new_ogs.back()->maybe_add(rpc_exe);
for (RPCExecution *rpc_exe : og->rpc_executions) { tmp_new_ogs.back()->maybe_add(rpc_exe); }
}
tmp_new_ogs.push_back(og);
}
overlap_groups.clear();
for (OverlapGroup *og : tmp_new_ogs) { overlap_groups.push_back(og); }
};
void GroupedRPCExecutions::resolve(RPCExecution *rpc_exe) {}
void GroupedRPCExecutions::add(RPCExecution *rpc_exe) { group.add_to_groups(rpc_exe); };
void GroupedRPCExecutions::offload(std::string model_name) {};
uint64_t GroupedRPCExecutions::total_mem_cost() {
uint64_t max_mem = 0;
for (auto &og : group.overlap_groups) {
// double og_ma = (og -> mem_active/(1024*1024))/1024.0;
// double og_ms = (og -> mem_static/(1024*1024))/1024.0;
// std::cout << "og size " << og -> rpc_executions.size()
// << " mem active " << og_ma << " GB"
// << " mem static " << og_ms << " GB" << std::endl;
if (og->mem_active + og->mem_static > max_mem) { max_mem = og->mem_active + og->mem_static; }
}
return max_mem;
};
RPCInstance::RPCInstance(RPC *rpc_ptr, int id, std::string name)
: rpc_ptr(rpc_ptr), id(id), name(name) {};
void RPCInstance::remove_parent(RPCInstance *parent) {
auto it = std::find(parents.begin(), parents.end(), parent);
if (it != parents.end()) { parents.erase(it); }
};
void RPCInstance::remove_child(RPCInstance *child) {
auto it = std::find(children.begin(), children.end(), child);
if (it != children.end()) { children.erase(it); }
};
void RPCInstance::add_parent(RPCInstance *parent) { parents.push_back(parent); };
void RPCInstance::add_child(RPCInstance *child) { children.push_back(child); };
void RPCInstance::remove_tmp_child(RPCInstance *child) {
auto it = std::find(tmp_children.begin(), tmp_children.end(), child);
if (it != tmp_children.end()) { tmp_children.erase(it); }
};
void RPCInstance::remove_tmp_parent(RPCInstance *parent) {
auto it = std::find(tmp_parents.begin(), tmp_parents.end(), parent);
if (it != tmp_parents.end()) { tmp_parents.erase(it); }
};
void RPCInstance::add_tmp_parent(RPCInstance *parent) { tmp_parents.push_back(parent); };
void RPCInstance::add_tmp_child(RPCInstance *child) { tmp_children.push_back(child); };
uint64_t parameter_sync_cost(uint64_t model_size, RPCExecution *src, RPCExecution *dst,
std::unordered_map<std::string, uint64_t> &cost_table) {
// 7b size 13738442752 Bytes
// double size_multiplier = double(model_size) / 13738442752.0;
std::string model_key = std::to_string(model_size);
std::string src_key = src->model_parallel_strategy.to_key();
std::string dst_key = dst->model_parallel_strategy.to_key();
if (src_key == dst_key) return 0;
std::string key = model_key + "," + src_key + "," + dst_key;
if (cost_table.find(key) == cost_table.end()) {
// std::cout << "key " << key << " not found" << std::endl;
return 0;
}
return cost_table[key];
}
void RPCInstance::resolve_parameter_sync(std::vector<RPCInstance *> tmp_graph,
std::unordered_map<std::string, uint64_t> &cost_table) {
// add parameter synchronization edges
if (!param_sync) return;
// dst to train
uint64_t from_cost =
parameter_sync_cost(param_sync_size, param_sync_rpc_exe_ptr, rpc_exe_ptr, cost_table);
uint64_t to_cost =
parameter_sync_cost(param_sync_size, rpc_exe_ptr, param_sync_rpc_exe_ptr, cost_table);
// if (param_sync_cost > 0)
// std::cout << "Param sync cost " << param_sync_cost << " from "
// << param_sync_rpc_exe_ptr -> rpc_ptr -> rpc_name << " to "
// << rpc_exe_ptr -> rpc_ptr -> rpc_name << std::endl;
// add param sync from src to dst
RPCExecution *from_src_exe =
new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh,
param_sync_rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0);
RPCInstance *from_src = new RPCInstance(rpc_ptr, id, name + ":from_src");
from_src->rpc_exe_ptr = from_src_exe;
RPCExecution *from_dst_exe = new RPCExecution(
rpc_ptr, rpc_exe_ptr->device_mesh, rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0);
RPCInstance *from_dst = new RPCInstance(rpc_ptr, id, name + ":from_dst");
from_dst->rpc_exe_ptr = from_dst_exe;
// bool overlap = src -> rpc_exe_ptr -> device_mesh.overlap(
// dst -> rpc_exe_ptr -> device_mesh);
for (RPCInstance *parent : parents) {
parent->remove_tmp_child(this);
from_src->add_tmp_parent(parent);
from_dst->add_tmp_parent(parent);
parent->add_tmp_child(from_src);
parent->add_tmp_child(from_dst);
}
this->tmp_parents.clear();
from_src->add_tmp_child(this);
from_dst->add_tmp_child(this);
this->add_tmp_parent(from_src);
this->add_tmp_parent(from_dst);
tmp_graph.push_back(from_src);
tmp_graph.push_back(from_dst);
tmp_ris.push_back(from_src);
tmp_ris.push_back(from_dst);
tmp_exes.push_back(from_src_exe);
tmp_exes.push_back(from_dst_exe);
// add param sync from dst to src
RPCExecution *to_src_exe = new RPCExecution(rpc_ptr, rpc_exe_ptr->device_mesh,
rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0);
RPCInstance *to_src = new RPCInstance(rpc_ptr, id, name + ":to_src");
to_src->rpc_exe_ptr = to_src_exe;
RPCExecution *to_dst_exe =
new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh,
param_sync_rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0);
RPCInstance *to_dst = new RPCInstance(rpc_ptr, id, name + ":to_dst");
to_dst->rpc_exe_ptr = to_dst_exe;
for (RPCInstance *child : children) {
child->remove_tmp_parent(this);
to_src->add_tmp_child(child);
to_dst->add_tmp_child(child);
child->add_tmp_parent(to_src);
child->add_tmp_parent(to_dst);
}
this->tmp_children.clear();
to_src->add_tmp_parent(this);
to_dst->add_tmp_parent(this);
this->add_tmp_child(to_src);
this->add_tmp_child(to_dst);
tmp_graph.push_back(to_src);
tmp_graph.push_back(to_dst);
tmp_ris.push_back(to_src);
tmp_ris.push_back(to_dst);
tmp_exes.push_back(to_src_exe);
tmp_exes.push_back(to_dst_exe);
}
CommStats::CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send,
uint64_t remote_recv, uint64_t offload_store, uint64_t offload_load)
: local_send(local_send),
local_recv(local_recv),
remote_send(remote_send),
remote_recv(remote_recv),
offload_store(offload_store),
offload_load(offload_load) {};
// ModelConfig::ModelConfig(std::string model_name,
// uint64_t param_size_bytes)
// : model_name(model_name), param_size_bytes(param_size_bytes) {
// };
std::string RPCExecution::to_string() {
return rpc_ptr->rpc_name + " on " + device_mesh.name
+ ", parallel strategy: " + model_parallel_strategy.to_string();
};


@@ -1,121 +0,0 @@
#ifndef RPC_HPP
#define RPC_HPP
#include <string>
#include <vector>
#include <unordered_map>
#include <device_mesh.hpp>
class CommStats {
public:
uint64_t local_send, local_recv, remote_send, remote_recv, offload_store, offload_load;
CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send, uint64_t remote_recv,
uint64_t offload_store, uint64_t offload_load);
};
class RPC {
public:
std::string model_name;
std::string rpc_name;
// interface_type is the stringified ModelInterfaceType, e.g. "ModelInterfaceType.GENERATE",
// "ModelInterfaceType.TRAIN_STEP", or "ModelInterfaceType.INFERENCE"
std::string interface_type;
RPC(std::string model_name, std::string rpc_name, std::string interface_type);
};
class RPCExecution {
public:
RPC *rpc_ptr;
DeviceMesh &device_mesh;
ModelParallelStrategy &model_parallel_strategy;
uint64_t time_cost, mem, static_mem;
RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh,
ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost, uint64_t mem,
uint64_t static_mem);
std::string to_string();
};
class OverlapGroup {
public:
std::unordered_set<RPCExecution *> rpc_executions;
std::unordered_set<DeviceMesh *> device_meshes;
uint64_t mem_static;
uint64_t mem_active;
bool maybe_add(RPCExecution *rpc_exe);
};
class DeviceMeshGroup {
public:
// std::string device_mesh_name;
std::vector<OverlapGroup *> overlap_groups;
void add_to_groups(RPCExecution *rpc_exe);
};
class GroupedRPCExecutions {
public:
// std::unordered_map<std::string, DeviceMeshGroup> dn_to_group;
DeviceMeshGroup group;
void add(RPCExecution *rpc_exe);
void resolve(RPCExecution *rpc_exe);
void offload(std::string model_name);
uint64_t total_mem_cost();
};
class RPCInstance {
public:
RPC *rpc_ptr;
int id;
std::string name;
std::vector<RPCInstance *> children;
std::vector<RPCInstance *> parents;
std::vector<RPCInstance *> tmp_children;
std::vector<RPCInstance *> tmp_parents;
std::vector<RPCInstance *> tmp_ris; // pointers to tmp rpc instances
std::vector<RPCExecution *> tmp_exes; // pointers to tmp rpc executions
RPCExecution *rpc_exe_ptr = nullptr;
RPCExecution *param_sync_rpc_exe_ptr = nullptr;
bool param_sync = false;
uint64_t param_sync_size = 0;
bool offload = false;
uint64_t offload_size = 0;
RPCInstance(RPC *rpc_ptr, int id, std::string name);
uint64_t ready_time = 0, start_time = 0, end_time = 0;
void remove_parent(RPCInstance *parent);
void remove_child(RPCInstance *child);
void add_parent(RPCInstance *parent);
void add_child(RPCInstance *child);
void add_tmp_parent(RPCInstance *parent);
void add_tmp_child(RPCInstance *child);
void remove_tmp_parent(RPCInstance *parent);
void remove_tmp_child(RPCInstance *child);
void resolve_parameter_sync(std::vector<RPCInstance *> tmp_graph,
std::unordered_map<std::string, uint64_t> &cost_table);
// void resolve_offload(std::vector<RPCInstance*> tmp_graph,
// CommStats& comm_stats);
};
uint64_t parameter_sync_cost(uint64_t param_size_bytes, RPCExecution *src, RPCExecution *dst,
std::unordered_map<std::string, uint64_t> &cost_table);
uint64_t remote_param_sync_size(uint64_t size, RPCExecution *src, RPCExecution *dst);
// class ModelConfig {
// std::string model_name;
// uint64_t param_size_bytes;
// ModelConfig(std::string model_name, uint64_t param_size_bytes);
// };
#endif


@@ -1,827 +0,0 @@
#include <iostream>
#include <iomanip>
#include <algorithm>
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <simulate.hpp>
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <pybind11/stl.h>
#include <chrono>
#include <limits>
#include <fstream>
#include <cmath>
#include <random>
#include <functional>
#include <thread>
namespace py = pybind11;
uint64_t VALID_COUNT_CAP = 25000000; // 25000000
size_t MAX_EXE_PER_RPC = 1000;
// std::unordered_map<std::string, DeviceMesh*> device_mesh_map;
void print_int_vector(std::vector<int> &vec) {
std::cout << "[";
for (int i = 0; i < static_cast<int>(vec.size()); i++) { std::cout << vec[i] << ", "; }
std::cout << "] ";
};
std::size_t vector_hash(const std::vector<int> &vec) {
std::size_t seed = vec.size();
for (const auto &i : vec) {
seed ^= std::hash<int>{}(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
return seed;
}
long check_memory_bytes() {
std::ifstream statm_file("/proc/self/statm");
long memory_usage_bytes = -1;
if (!statm_file) {
std::cerr << "Failed to open /proc/self/statm\n";
} else {
long size, resident, share, text, lib, data, dt;
statm_file >> size >> resident >> share >> text >> lib >> data >> dt;
// size is in pages, to convert to bytes, multiply by the page size
long page_size = sysconf(_SC_PAGESIZE);
memory_usage_bytes = size * page_size;
}
return memory_usage_bytes;
};
void make_rpc_exe_table(std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::vector<RPCExecution *> &rpc_exes) {
for (auto &rpc_exe : rpc_exes) {
if (rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].size() < MAX_EXE_PER_RPC)
rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].push_back(rpc_exe);
}
for (auto &x : rpc_exe_table) {
std::string rpc_name = x.first;
std::vector<RPCExecution *> &rpc_exe_list = x.second;
// sort first
std::sort(rpc_exe_list.begin(), rpc_exe_list.end(),
[](const RPCExecution *a, const RPCExecution *b) {
if (a->time_cost == b->time_cost)
return a->device_mesh.name < b->device_mesh.name;
else
return a->time_cost < b->time_cost;
});
}
}
void make_sorted_rpc_names(
std::vector<std::string> &sorted_rpc_names,
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table) {
std::vector<std::pair<std::string, uint64_t>> average_time_cost;
for (auto &x : rpc_exe_table) {
std::string rpc_name = x.first;
std::vector<RPCExecution *> &rpc_exe_list = x.second;
uint64_t total_time_cost = 0;
int c = 0;
// average over the (up to) 10 cheapest executions; rpc_exe_list is already sorted by time cost
for (auto &rpc_exe : rpc_exe_list) {
total_time_cost += rpc_exe->time_cost;
c += 1;
if (c >= 10) break;
}
average_time_cost.push_back(std::make_pair(rpc_name, c > 0 ? total_time_cost / c : 0));
}
std::sort(average_time_cost.begin(), average_time_cost.end(),
[](const std::pair<std::string, uint64_t> &a, const std::pair<std::string, uint64_t> &b) {
return a.second > b.second;
});
for (auto &x : average_time_cost) { sorted_rpc_names.push_back(x.first); }
}
void prepare(std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::unordered_map<std::string, RPC *> &rpc_table,
std::vector<std::string> &sorted_rpc_names,
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph) {
std::vector<std::pair<std::string, uint64_t>> average_time_cost;
for (auto &rpc : rpcs) { rpc_table[rpc->rpc_name] = rpc; }
make_rpc_exe_table(rpc_exe_table, rpc_exes);
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
for (auto &rpc_instance : graph) {
ri_table[rpc_instance->rpc_ptr->rpc_name].push_back(rpc_instance);
}
for (auto &rpc_instance : graph) {
model_name_ri_table[rpc_instance->rpc_ptr->model_name].push_back(rpc_instance);
}
};
std::vector<SimulateResult> mcmc_search(std::vector<RPC *> rpcs,
std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph,
std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> model_sizes,
double beta, double time_limit,
MinEndTimeQueue &top_k_queue) {
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
std::unordered_map<std::string, RPC *> rpc_table;
std::vector<std::string> sorted_rpc_names;
std::unordered_map<std::string, std::vector<RPCInstance *>> ri_table;
std::unordered_map<std::string, std::vector<RPCInstance *>> model_name_ri_table;
std::chrono::duration<double> time_limit_duration(time_limit);
std::vector<SimulateResult> time_cost_cache;
prepare(rpc_exe_table, rpc_table, sorted_rpc_names, ri_table, model_name_ri_table, rpcs, rpc_exes,
graph);
std::vector<int> index;
std::vector<int> min_index;
std::vector<int> max_index;
uint64_t min_index_mem = 0;
uint64_t valid_count = 0;
uint64_t oom_count = 0;
int num_rpcs = static_cast<int>(sorted_rpc_names.size());
uint64_t min_time_cost = std::numeric_limits<uint64_t>::max();
uint64_t max_time_cost = 0;
double avg = 0;
// [6, 2, 3, 2, 12, 15, ]
// [0, 0, 0, 7, 8, 13, ]
// index = { 23, 7, 30, 154, 173, 265 };
// SimulateResult sr1 = simulate(graph, cost_table, model_sizes,
// rpc_table, rpc_exe_table, ri_table,
// model_name_ri_table, sorted_rpc_names,
// index);
// std::cout << "index 1 end time " << sr1.end_time << " mem cost " << sr1.mem_cost << std::endl;
// std::cout << "************************" << std::endl;
// index = { 1, 0, 0, 11, 20, 3 };
// SimulateResult sr2 = simulate(graph, cost_table, model_sizes,
// rpc_table, rpc_exe_table, ri_table,
// model_name_ri_table, sorted_rpc_names,
// index);
// std::cout << "index 2 end time " << sr2.end_time << " mem cost " << sr2.mem_cost << std::endl;
// exit(0);
// initial value
index.resize(sorted_rpc_names.size(), 0);
// index 1
SimulateResult first_sr = simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table,
ri_table, model_name_ri_table, sorted_rpc_names, index);
SimulateResult final_sr = first_sr;
uint64_t current_cost = first_sr.end_time;
time_cost_cache.push_back(first_sr);
if (first_sr.oom)
oom_count += 1;
else
top_k_queue.insert(first_sr);
// std::cout << "initial cost " << current_cost << " oom "
// << first_sr.oom << std::endl;
// index = {0, 1, 3, 34, 4, 10};
// std::unordered_map<std::size_t, uint64_t> time_cost_cache;
auto start = std::chrono::high_resolution_clock::now();
// bool outer_loop_break_flag = false;
while (valid_count < VALID_COUNT_CAP) {
// only change one model execution in each iteration
// std::vector<SimulateResult> sr_vector;
std::unordered_map<int, std::pair<int, int>> flatten_to_pair;
std::vector<double> weight;
// double beta = 0.0075;
int max_step_range = 10000;
int current = 0;
std::vector<int> new_index(index);
for (int i = 0; i < num_rpcs; i++) {
std::string rpc_name = sorted_rpc_names[i];
int c_i = index[i];
int min_i = std::max(0, c_i - max_step_range);
int max_i =
std::min(static_cast<int>(rpc_exe_table[rpc_name].size()), c_i + max_step_range + 1);
for (int j = min_i; j < max_i; j++) {
if (j == c_i) continue;
// int tmp = new_index[i];
// new_index[i] = j;
// SimulateResult sr = simulate(graph, cost_table, model_sizes,
// rpc_table, rpc_exe_table, ri_table,
// model_name_ri_table, sorted_rpc_names,
// new_index);
// sr_vector.push_back(sr);
// new_index[i] = tmp;
flatten_to_pair[current] = std::make_pair(i, j);
current++;
}
}
// if (time_cost_cache.size() > 10000000) {
// time_cost_cache.clear();
// }
// std::cout << "sr vector size " << sr_vector.size() << std::endl;
// for (int i = 0; i < static_cast<int>(sr_vector.size()); i++) {
// weight.push_back(std::exp(-beta * (sr_vector[i].end_time/100000)));
// }
// exit(0);
std::random_device rd;
std::mt19937 gen(rd());
// std::discrete_distribution<int> d(weight.begin(), weight.end());
std::uniform_int_distribution<int> d(0, static_cast<int>(flatten_to_pair.size() - 1));
int selected = d(gen);
// assign new index
int selected_i = flatten_to_pair[selected].first;
int selected_j = flatten_to_pair[selected].second;
new_index[selected_i] = selected_j;
SimulateResult selected_sr =
simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table, ri_table,
model_name_ri_table, sorted_rpc_names, new_index);
uint64_t selected_cost = selected_sr.end_time;
// if (selected_sr.oom) {
// std::cout << "oom max end time " << selected_cost << std::endl;
// } else {
// std::cout << "max end time " << selected_cost << std::endl;
// }
if (selected_cost < std::numeric_limits<uint64_t>::max()) {
bool accepted = true;
if (current_cost < selected_cost) {
double accept_prob = std::exp(-beta * ((selected_cost - current_cost) / 100000));
// double accept_prob = ap * ap;
std::bernoulli_distribution accept_dist(accept_prob);
accepted = accept_dist(gen);
}
if (accepted) {
// std::cout << "accepted" << std::endl;
index = new_index;
current_cost = selected_cost;
valid_count++;
if (!selected_sr.oom) {
// min_time_cost = std::min(min_time_cost, selected_cost);
// max_time_cost = std::max(max_time_cost, selected_cost);
avg = (selected_cost + avg * (valid_count - 1)) / valid_count;
if (min_time_cost > selected_cost) {
min_time_cost = selected_cost;
min_index = index;
min_index_mem = selected_sr.mem_cost;
// final_sr = selected_sr;
auto now = std::chrono::high_resolution_clock::now();
double diff =
std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
selected_sr.used_time = diff;
time_cost_cache.push_back(selected_sr);
}
if (min_time_cost == selected_cost) {
if (min_index_mem > selected_sr.mem_cost) {
min_index = index;
min_index_mem = selected_sr.mem_cost;
// final_sr = selected_sr;
}
}
top_k_queue.insert(selected_sr);
if (max_time_cost < selected_cost) {
max_time_cost = selected_cost;
max_index = index;
}
} else {
oom_count += 1;
}
// if (min_time_cost <= 100000000) break; // DEBUG
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
if (valid_count % 1000 == 0) {
std::cout << " valid_count " << valid_count << " oom count " << oom_count << " time "
<< diff.count() << " min time cost " << min_time_cost << " min index mem cost "
<< min_index_mem << " max time cost " << max_time_cost << " current cost "
<< current_cost << " avg " << avg << " mem usage " << check_memory_bytes()
<< std::endl;
std::cout << "min index : ";
print_int_vector(min_index);
std::cout << std::endl;
std::cout << "max index : ";
print_int_vector(max_index);
std::cout << std::endl;
std::cout << "current index : ";
print_int_vector(index);
std::cout << std::endl;
}
if (diff > time_limit_duration) break;
}
}
}
// std::cout << "break" << std::endl;
// int rpc_index = 0;
// for (int index : final_sr.index) {
// // std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
// // std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
// RPCExecution* re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
// final_sr.rpc_exe_list.push_back(re_ptr);
// rpc_index ++;
// }
// std::cout << "final_sr.rpc_exe_list size " << final_sr.rpc_exe_list.size() << std::endl;
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = end - start;
std::cout << "MCMC Search finished, beta " << beta << " time limit " << time_limit << " seconds"
<< " valid_count " << valid_count << " oom_count " << oom_count << " time "
<< diff.count() << " min time cost " << min_time_cost << " max time cost "
<< max_time_cost << " avg " << avg << std::endl;
std::cout << "RPCExecutions: " << std::endl;
for (auto &re_ptr : final_sr.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
return time_cost_cache;
// return final_sr;
}
void multi_mcmc_search(std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph,
std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> model_sizes, double beta_min,
double beta_max, double beta_step, MinEndTimeQueue &res_queue,
int top_k = 10, double time_limit = 60.0,
int repeat = 1) {
SimulateResult sr;
std::vector<MinEndTimeQueue *> queues;
// std::vector<std::thread> ts;
for (int i = 0; i < repeat; i++) {
for (double beta = beta_min; beta < beta_max; beta += beta_step) {
MinEndTimeQueue *q = new MinEndTimeQueue(10);
queues.push_back(q);
// Run one MCMC chain per beta (sequentially; the threaded version was removed)
std::vector<SimulateResult> r =
mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta, time_limit, *q);
}
}
for (auto &q : queues) {
mergeMinEndTimeQueues(res_queue, *q);
// delete q;
}
// std::cout << "Best result: " << sr.end_time << std::endl;
// for (auto& re_ptr : sr.rpc_exe_list) {
// std::cout << re_ptr -> to_string() << std::endl;
// }
// std::cout << "Index: ";
// for (int i : sr.index) {
// std::cout << i << " ";
// }
// std::cout << std::endl;
// std::cout << "Time cost: " << sr.end_time
// << ", mem cost: " << sr.mem_cost << std::endl;
// std::cout << "sr.rpc_exe_list size " << sr.rpc_exe_list.size() << std::endl;
// return sr;
}
void input_check(std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
std::vector<RPCInstance *> graph, CommStats &comm_stats) {
std::cout << "rpcs: " << rpcs.size() << std::endl;
std::cout << "rpc_exes: " << rpc_exes.size() << std::endl;
std::cout << "graph: " << graph.size() << std::endl;
for (auto &rpc_instance : graph) {
std::cout << "==================" << std::endl;
std::cout << "rpc instance: " << rpc_instance->name << std::endl;
std::cout << "parents" << std::endl;
for (auto &parent : rpc_instance->parents) { std::cout << parent->name << " "; }
std::cout << std::endl;
std::cout << "children" << std::endl;
for (auto &child : rpc_instance->children) { std::cout << child->name << " "; }
std::cout << std::endl;
// std::cout << "parents: " << rpc_instance -> parents.size()
// << " children: " << rpc_instance -> children.size() << std::endl;
}
std::cout << "comm_stats: " << std::endl;
std::cout << "local_send: " << comm_stats.local_send << std::endl;
std::cout << "local_recv: " << comm_stats.local_recv << std::endl;
std::cout << "remote_send: " << comm_stats.remote_send << std::endl;
std::cout << "remote_recv: " << comm_stats.remote_recv << std::endl;
std::cout << "offload_store: " << comm_stats.offload_store << std::endl;
std::cout << "offload_load: " << comm_stats.offload_load << std::endl;
}
RPC *cast_rpc(py::handle rpc_py) {
return new RPC(py::str(rpc_py.attr("model_name")).cast<std::string>(),
rpc_py.attr("name").cast<std::string>(),
py::str(rpc_py.attr("interface_type")).cast<std::string>());
}
DeviceMesh *cast_device_mesh(py::handle device_mesh_py,
std::unordered_map<std::string, DeviceMesh *> &device_mesh_map) {
std::string name = device_mesh_py.attr("name").cast<std::string>();
if (device_mesh_map.find(name) == device_mesh_map.end()) {
py::array_t<int32_t> mapping_array =
device_mesh_py.attr("mapping").cast<pybind11::array_t<int32_t>>();
py::buffer_info buf_info = mapping_array.request();
auto rows = buf_info.shape[0];
auto cols = buf_info.shape[1];
std::vector<std::vector<int>> mapping(rows, std::vector<int>(cols));
// Get a pointer to the data
int32_t *data = static_cast<int32_t *>(buf_info.ptr);
// Fill the 2D vector with data from the numpy array
for (size_t i = 0; i < static_cast<size_t>(rows); ++i) {
for (size_t j = 0; j < static_cast<size_t>(cols); ++j) { mapping[i][j] = data[i * cols + j]; }
}
DeviceMesh *device_mesh =
new DeviceMesh(device_mesh_py.attr("n_nodes").cast<int>(),
device_mesh_py.attr("n_gpus_per_node").cast<int>(), mapping,
device_mesh_py.attr("global_mesh_name").cast<std::string>(),
device_mesh_py.attr("name").cast<std::string>());
device_mesh_map[name] = device_mesh;
return device_mesh;
} else {
return device_mesh_map[name];
}
}
ModelParallelStrategy *cast_model_parallel_strategy(py::handle model_parallel_strategy_py) {
return new ModelParallelStrategy(
model_parallel_strategy_py.attr("pipeline_parallel_size").cast<int>(),
model_parallel_strategy_py.attr("data_parallel_size").cast<int>(),
model_parallel_strategy_py.attr("model_parallel_size").cast<int>());
}
RPCExecution *cast_rpc_execution(py::handle rpc_exe_py, std::unordered_map<std::string, RPC *> &tmp,
std::unordered_map<std::string, DeviceMesh *> &device_mesh_map) {
DeviceMesh *device_mesh = cast_device_mesh(rpc_exe_py.attr("device_mesh"), device_mesh_map);
ModelParallelStrategy *model_parallel_strategy =
cast_model_parallel_strategy(rpc_exe_py.attr("parallel_strategy"));
// RPC* rpc = cast_rpc(rpc_exe_py.attr("rpc"));
return new RPCExecution(
tmp[rpc_exe_py.attr("rpc").attr("name").cast<std::string>()], *device_mesh,
*model_parallel_strategy, rpc_exe_py.attr("time_cost").cast<uint64_t>(),
rpc_exe_py.attr("mem").cast<uint64_t>(), rpc_exe_py.attr("static_mem").cast<uint64_t>());
}
RPCInstance *cast_rpc_instance_wo_dependency(py::handle rpc_instance_py,
std::unordered_map<std::string, RPC *> &tmp) {
return new RPCInstance(tmp[rpc_instance_py.attr("rpc").attr("name").cast<std::string>()],
rpc_instance_py.attr("iteration_id").cast<int>(),
rpc_instance_py.attr("name").cast<std::string>());
}
void cast_rpc_instance_dependency(py::handle rpc_instance_py, RPCInstance *ri_ptr,
std::unordered_map<std::string, RPCInstance *> &tmp_graph) {
for (py::handle parent_py : rpc_instance_py.attr("parents"))
ri_ptr->parents.push_back(tmp_graph[parent_py.attr("name").cast<std::string>()]);
for (py::handle child_py : rpc_instance_py.attr("children"))
ri_ptr->children.push_back(tmp_graph[child_py.attr("name").cast<std::string>()]);
}
py::list py_single_mcmc_search_time_profile(py::list rpcs_py, py::list rpc_exes_py,
py::list graph_py, py::dict cost_table_py,
py::dict model_sizes_py, py::object beta,
py::object time_limit) {
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::vector<RPCExecution *> rpc_exes;
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
for (py::handle rpc_exe_py : rpc_exes_py) {
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
rpc_exes.push_back(rpc_exe_ptr);
}
std::vector<RPCInstance *> graph;
std::unordered_map<std::string, RPCInstance *> tmp_graph;
for (py::handle ri_py : graph_py) {
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
// std::cout << "cast " << ri_ptr -> name << std::endl;
tmp_graph[ri_ptr->name] = ri_ptr;
}
// build dependency edges
for (py::handle ri_py : graph_py) {
std::string ri_name = ri_py.attr("name").cast<std::string>();
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
graph.push_back(tmp_graph[ri_name]);
}
std::unordered_map<std::string, uint64_t> cost_table =
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
std::unordered_map<std::string, uint64_t> model_sizes =
model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
MinEndTimeQueue res_queue(10);
std::vector<SimulateResult> rlist =
mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta.cast<double>(),
time_limit.cast<double>(), res_queue);
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
std::vector<std::string> sorted_rpc_names;
make_rpc_exe_table(rpc_exe_table, rpc_exes);
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
py::list result;
std::cout << "rlist.size " << rlist.size() << std::endl;
for (auto &r : rlist) {
// SimulateResult r = res_queue.getQueue().top();
// res_queue.getQueue().pop();
std::cout << "End time: " << r.end_time << std::endl;
for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
std::cout << "Index: ";
for (int i : r.index) { std::cout << i << " "; }
std::cout << std::endl;
std::cout << "Time cost: " << r.end_time << ", mem cost: " << r.mem_cost << std::endl;
int rpc_index = 0;
for (int index : r.index) {
// std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
r.rpc_exe_list.push_back(re_ptr);
rpc_index++;
}
py::dict rdict;
for (auto &re_ptr : r.rpc_exe_list) {
py::dict rpc_exe_info;
std::string rpc_name = re_ptr->rpc_ptr->rpc_name;
py::object rpc_name_obj = py::str(rpc_name);
// rpc_exe_info.append(re_ptr -> device_mesh.device_mesh_name);
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_dp);
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_mp);
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_pp);
rpc_exe_info["device_mesh"] = re_ptr->device_mesh.name;
rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp;
rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp;
rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp;
rdict[rpc_name_obj] = rpc_exe_info;
// std::cout << "append key " << rpc_name_obj << std::endl;
}
rdict["end_time"] = r.end_time;
rdict["mem_cost"] = r.mem_cost;
rdict["used_time"] = r.used_time;
result.append(rdict);
}
return result;
};
py::list py_multi_mcmc_search(py::list rpcs_py, py::list rpc_exes_py, py::list graph_py,
py::dict cost_table_py, py::dict model_sizes_py,
py::object beta_min_py, py::object beta_max_py,
py::object beta_step_py, py::object time_limit_py,
py::object repeat) {
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::vector<RPCExecution *> rpc_exes;
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
for (py::handle rpc_exe_py : rpc_exes_py) {
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
rpc_exes.push_back(rpc_exe_ptr);
}
std::vector<RPCInstance *> graph;
std::unordered_map<std::string, RPCInstance *> tmp_graph;
for (py::handle ri_py : graph_py) {
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
// std::cout << "cast " << ri_ptr -> name << std::endl;
tmp_graph[ri_ptr->name] = ri_ptr;
}
// build dependency edges
for (py::handle ri_py : graph_py) {
std::string ri_name = ri_py.attr("name").cast<std::string>();
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
graph.push_back(tmp_graph[ri_name]);
}
std::unordered_map<std::string, uint64_t> cost_table =
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
std::unordered_map<std::string, uint64_t> model_sizes =
model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
double beta_min = beta_min_py.cast<double>();
double beta_max = beta_max_py.cast<double>();
double beta_step = beta_step_py.cast<double>();
double time_limit = time_limit_py.cast<double>();
int rp = repeat.cast<int>();
MinEndTimeQueue res_queue(10);
multi_mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta_min, beta_max, beta_step,
res_queue, 10, time_limit, rp);
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
std::vector<std::string> sorted_rpc_names;
make_rpc_exe_table(rpc_exe_table, rpc_exes);
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
// std::cout << "r.rpc_exe_list size " << r.rpc_exe_list.size() << std::endl;
// for (int rpc_index = 0; rpc_index < 6; rpc_index++) {
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
// int index = 0;
// for (auto& re_ptr : rpc_exe_table[sorted_rpc_names[rpc_index]]) {
// std::cout << "index " << index << " " << re_ptr -> to_string() << std::endl;
// index ++;
// }
// }
py::list result;
std::cout << "res_queue.getQueue().size " << res_queue.getQueue().size() << std::endl;
while (!res_queue.getQueue().empty()) {
SimulateResult r = res_queue.getQueue().top();
res_queue.getQueue().pop();
std::cout << "End time: " << r.end_time << std::endl;
for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
std::cout << "Index: ";
for (int i : r.index) { std::cout << i << " "; }
std::cout << std::endl;
std::cout << "Time cost: " << r.end_time << ", mem cost: " << r.mem_cost << std::endl;
int rpc_index = 0;
for (int index : r.index) {
// std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
r.rpc_exe_list.push_back(re_ptr);
rpc_index++;
}
py::dict rdict;
for (auto &re_ptr : r.rpc_exe_list) {
py::dict rpc_exe_info;
std::string rpc_name = re_ptr->rpc_ptr->rpc_name;
py::object rpc_name_obj = py::str(rpc_name);
// convert device mesh mapping into py::array_t<int>
std::vector<std::vector<int>> mapping = re_ptr->device_mesh.mapping;
int rows = mapping.size();
int cols = mapping[0].size();
py::array_t<int> numpy_array({rows, cols});
for (int i = 0; i < rows; ++i) {
for (int j = 0; j < cols; ++j) {
*numpy_array.mutable_data(i, j) = mapping[i][j];
// std::cout << i << j << mapping[i][j] << std::endl;
}
}
// store in py::dict
rpc_exe_info["device_mesh_mapping"] = numpy_array;
rpc_exe_info["device_mesh_name"] = re_ptr->device_mesh.name;
rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp;
rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp;
rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp;
rdict[rpc_name_obj] = rpc_exe_info;
// std::cout << "append key " << rpc_name_obj << std::endl;
}
rdict["end_time"] = r.end_time;
rdict["mem_cost"] = r.mem_cost;
result.append(rdict);
}
return result;
};
PYBIND11_MODULE(mdm_search, m) {
m.doc() = "model device mapping search module";
// for debug
// m.def("mcmc_search", [](py::list rpcs_py, py::list rpc_exes_py,
// py::list graph_py, py::object comm_stats_py,
// py::dict model_sizes_py, py::object beta_py,
// py::object time_limit_py) {
// std::vector<RPC*> rpcs;
// std::unordered_map<std::string, RPC*> tmp;
// for (py::handle rpc_py : rpcs_py) {
// RPC* rpc_ptr = cast_rpc(rpc_py);
// rpcs.push_back(rpc_ptr);
// tmp[rpc_ptr -> rpc_name] = rpc_ptr;
// }
// std::vector<RPCExecution*> rpc_exes;
// for (py::handle rpc_exe_py : rpc_exes_py) {
// RPCExecution* rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp);
// rpc_exes.push_back(rpc_exe_ptr);
// }
// std::vector<RPCInstance*> graph;
// std::unordered_map<std::string, RPCInstance*> tmp_graph;
// for (py::handle ri_py : graph_py) {
// RPCInstance* ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
// std::cout << "cast " << ri_ptr -> name << std::endl;
// tmp_graph[ri_ptr -> name] = ri_ptr;
// }
// // build dependecny
// for (py::handle ri_py : graph_py) {
// std::string ri_name = ri_py.attr("name").cast<std::string>();
// cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
// graph.push_back(tmp_graph[ri_name]);
// }
// CommStats comm_stats(
// comm_stats_py.attr("local_send").cast<uint64_t>(),
// comm_stats_py.attr("local_recv").cast<uint64_t>(),
// comm_stats_py.attr("remote_send").cast<uint64_t>(),
// comm_stats_py.attr("remote_recv").cast<uint64_t>(),
// comm_stats_py.attr("offload_load").cast<uint64_t>(),
// comm_stats_py.attr("offload_store").cast<uint64_t>()
// );
// std::unordered_map<std::string, uint64_t> model_sizes
// = model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
// double beta = beta_py.cast<double>();
// double time_limit = time_limit_py.cast<double>();
// mcmc_search(rpcs, rpc_exes, graph,
// comm_stats, model_sizes, beta,
// time_limit);
// });
m.def("input_check",
[](py::list rpcs_py, py::list rpc_exes_py, py::list graph_py, py::object comm_stats_py) {
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::vector<RPCExecution *> rpc_exes;
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
for (py::handle rpc_exe_py : rpc_exes_py) {
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
rpc_exes.push_back(rpc_exe_ptr);
}
std::vector<RPCInstance *> graph;
std::unordered_map<std::string, RPCInstance *> tmp_graph;
for (py::handle ri_py : graph_py) {
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
std::cout << "cast " << ri_ptr->name << std::endl;
tmp_graph[ri_ptr->name] = ri_ptr;
}
// build dependency edges
for (py::handle ri_py : graph_py) {
std::string ri_name = ri_py.attr("name").cast<std::string>();
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
graph.push_back(tmp_graph[ri_name]);
}
CommStats comm_stats(comm_stats_py.attr("local_send").cast<uint64_t>(),
comm_stats_py.attr("local_recv").cast<uint64_t>(),
comm_stats_py.attr("remote_send").cast<uint64_t>(),
comm_stats_py.attr("remote_recv").cast<uint64_t>(),
comm_stats_py.attr("offload_load").cast<uint64_t>(),
comm_stats_py.attr("offload_store").cast<uint64_t>());
input_check(rpcs, rpc_exes, graph, comm_stats);
});
// mcmc search to py result
m.def("multi_mcmc_search", &py_multi_mcmc_search);
m.def("parameter_sync_cost", [](py::object rpcs_py, py::object param_size_bytes_py,
py::dict cost_table_py, py::object src_py, py::object dst_py) {
uint64_t param_size_bytes = param_size_bytes_py.cast<uint64_t>();
std::unordered_map<std::string, uint64_t> cost_table =
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
std::vector<RPC *> rpcs;
std::unordered_map<std::string, RPC *> tmp;
for (py::handle rpc_py : rpcs_py) {
RPC *rpc_ptr = cast_rpc(rpc_py);
rpcs.push_back(rpc_ptr);
tmp[rpc_ptr->rpc_name] = rpc_ptr;
}
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
RPCExecution *src = cast_rpc_execution(src_py, tmp, tmp_device_mesh);
RPCExecution *dst = cast_rpc_execution(dst_py, tmp, tmp_device_mesh);
return parameter_sync_cost(param_size_bytes, src, dst, cost_table);
});
m.def("mcmc_search_time_profile", &py_single_mcmc_search_time_profile);
};
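
For reference, the removed extension was driven from Python roughly as follows (a sketch based only on the bindings above; the rpcs / rpc_exes / graph objects were built by the deprecated Python-side search module):

def search_device_mapping(rpcs, rpc_exes, graph, cost_table, model_sizes):
    import mdm_search  # the (now removed) C++ extension bound above
    results = mdm_search.multi_mcmc_search(
        rpcs, rpc_exes, graph,    # RPC definitions, candidate executions, dependency graph
        cost_table, model_sizes,  # {str: int} lookup tables
        0.001, 0.01, 0.001,       # beta_min, beta_max, beta_step for the MCMC chains
        60.0,                     # per-chain time limit in seconds
        1,                        # repeat
    )
    # Each entry maps rpc_name -> {device_mesh_name, device_mesh_mapping, num_dp, num_mp, num_pp}
    # plus "end_time" and "mem_cost"; pick the allocation with the smallest makespan.
    return min(results, key=lambda r: r["end_time"])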


@@ -1,244 +0,0 @@
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <simulate.hpp>
#include <iostream>
#include <queue>
#include <algorithm>
#include <iomanip>
#include <limits>
#include <chrono>
#define MAX(a, b) ((a) > (b) ? (a) : (b))
uint64_t SOFT_GPU_MEM_CAP = 85899345920; // 80G
SimulateResult::SimulateResult()
: end_time(std::numeric_limits<uint64_t>::max()), oom(true), mem_cost(0) {}
SimulateResult::SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost,
std::vector<int> &index)
: end_time(end_time), oom(oom), mem_cost(mem_cost), index(index) {}
SimulateResult simulate(
std::vector<RPCInstance *> &graph, std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> &model_sizes,
std::unordered_map<std::string, RPC *> &rpc_table,
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
std::vector<std::string> &sorted_rpc_names, std::vector<int> &index) {
auto start = std::chrono::high_resolution_clock::now();
GroupedRPCExecutions grouped_rpc_exe;
// std::unordered_map<std::string, RPCExecution*> param_dst; // model_name -> rpc_exe_ptr
std::unordered_set<std::string> offloaded;
uint64_t oom_penalty = 3;
int num_rpcs = static_cast<int>(sorted_rpc_names.size());
for (int i = 0; i < num_rpcs; i++) {
std::string rpc_name = sorted_rpc_names[i];
RPC *rpc = rpc_table[rpc_name];
RPCExecution *rpc_exe = rpc_exe_table[rpc_name][index[i]];
for (auto &ri : ri_table[rpc_name]) {
ri->rpc_exe_ptr = rpc_exe; // assign rpc_exe to rpc_instance
}
// dirty implementation, check whether instance is hooked with param syncs
for (auto &rpc_instance : model_name_ri_table[rpc->model_name]) {
if (rpc_instance->rpc_ptr->interface_type == "ModelInterfaceType.TRAIN_STEP"
&& rpc->interface_type != "ModelInterfaceType.TRAIN_STEP") {
// param_dst[rpc -> model_name] = rpc_exe;
rpc_instance->param_sync = true;
rpc_instance->param_sync_size = model_sizes[rpc_instance->rpc_ptr->model_name];
rpc_instance->param_sync_rpc_exe_ptr = rpc_exe;
}
}
grouped_rpc_exe.add(rpc_exe);
std::string model_name = rpc->model_name;
}
// rpc_instances: list of rpc instances, graph
std::priority_queue<RPCInstance *, std::vector<RPCInstance *>, CompareReadyTime> ready_queue;
// std::vector<RPCInstance*> executed; // for debug, remove later
std::unordered_map<std::string, size_t> parent_executed;
std::unordered_set<DeviceMesh *> device_meshes;
// for offload and parameter sync RPC instances
std::vector<RPCInstance *> tmp_graph;
// resolve parameter sync
for (RPCInstance *node : graph) {
tmp_graph.push_back(node);
node->tmp_children = node->children;
node->tmp_parents = node->parents;
// std::cout << "Resolve parameter sync: " << node -> name
// << " " << node -> param_sync << std::endl;
node->resolve_parameter_sync(tmp_graph, cost_table);
// node -> resolve_offload(tmp_graph, comm_stats);
}
uint64_t max_end_time = 0;
for (RPCInstance *node : tmp_graph) {
// std::cout << "Node: " << node -> name << " parents: "
// << node -> parents.size() << std::endl;
if (node->parents.size() == 0) ready_queue.push(node);
// init device meshes
RPCExecution *rpc_exe = node->rpc_exe_ptr;
device_meshes.insert(&rpc_exe->device_mesh);
}
std::vector<RPCInstance *> executed;
// simulate
while (!ready_queue.empty()) {
RPCInstance *t = ready_queue.top();
RPCExecution *rpc_exe = t->rpc_exe_ptr;
uint64_t exec_time = rpc_exe->time_cost;
DeviceMesh *device_mesh = &rpc_exe->device_mesh;
ready_queue.pop();
if (device_mesh->pre_task == nullptr) {
t->start_time = t->ready_time;
} else {
t->start_time = MAX(t->ready_time, device_mesh->pre_task->end_time);
}
t->end_time = t->start_time + exec_time;
max_end_time = MAX(t->end_time, max_end_time);
for (DeviceMesh *mesh : device_meshes) {
if (device_mesh->overlap(*mesh)) {
if (mesh->pre_task == nullptr || mesh->pre_task->end_time <= t->end_time) {
mesh->pre_task = t;
}
// mesh -> pre_task = t;
}
}
executed.push_back(t);
for (RPCInstance *child : t->tmp_children) {
child->ready_time = MAX(t->end_time, child->ready_time);
// std::cout << "parent: " << t -> name
// << " child: " << child -> name << std::endl;
parent_executed[child->name] += 1;
// child -> remove_parent(t);
if (child->tmp_parents.size() == parent_executed[child->name]) {
ready_queue.push(child);
// std::cout << "Ready: " << child -> name
// << " ready time " << child -> ready_time << std::endl;
}
}
// std::cout << "ready_queue size " << ready_queue.size()
// << " executed size " << executed.size() << std::endl;
}
// 110045999
// if (max_end_time < 100000000) { // DEBUG
// std::cout << "INDEX: [";
// for (int i : index) {
// std::cout << i << ", ";
// }
// std::cout << "]" << std::endl;
// for (auto& x : rpc_exe_table){
// std::string rpc_name = x.first;
// std::vector<RPCExecution*>& rpc_exe_list = x.second;
// int count = 0;
// for (RPCExecution* rpc_exe : rpc_exe_list) {
// std::cout << "RPC: " << rpc_name
// << " device mesh " << rpc_exe -> device_mesh.device_mesh_name
// << " time cost " << rpc_exe -> time_cost << std::endl;
// count ++;
// if (count > 10) break;
// }
// }
// for (RPCInstance* ri : executed) {
// for (RPCInstance* parent : ri -> tmp_parents) {
// std::cout << "Parent: " << parent -> name << " of " << ri -> name
// << " start time " << parent -> start_time
// << " end time " << parent -> end_time << std::endl;
// }
// std::cout << "Executed: " << ri -> name
// << " start time " << ri -> start_time
// << " end time " << ri -> end_time
// << " rpc name " << ri -> rpc_ptr -> rpc_name
// << " device mesh "
// << ri -> rpc_exe_ptr -> device_mesh.device_mesh_name
// << " rpc exe time cost " << ri -> rpc_exe_ptr -> time_cost
// << std::endl;
// }
// }
// clear device mesh pre tasks
for (DeviceMesh *mesh : device_meshes) { mesh->pre_task = nullptr; }
// clear rpc instance times
for (RPCInstance *node : graph) {
node->tmp_children.clear();
node->tmp_parents.clear();
node->ready_time = 0;
node->start_time = 0;
node->end_time = 0;
tmp_graph.clear();
for (RPCInstance *ptr : node->tmp_ris) { delete ptr; }
node->tmp_ris.clear();
for (RPCExecution *ptr : node->tmp_exes) { delete ptr; }
node->tmp_exes.clear();
}
uint64_t current_mem = grouped_rpc_exe.total_mem_cost();
if (current_mem > SOFT_GPU_MEM_CAP) { max_end_time *= oom_penalty; }
// std::cout << "Max end time: " << max_end_time
// << " executed size " << executed.size() << std::endl;
std::chrono::duration<double> elapsed = std::chrono::high_resolution_clock::now() - start;
// std::cout << "Elapsed time (micro seconds): "
// << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count()
// << std::endl;
for (OverlapGroup *ptr : grouped_rpc_exe.group.overlap_groups) { delete ptr; }
grouped_rpc_exe.group.overlap_groups.clear();
return SimulateResult(max_end_time, current_mem > SOFT_GPU_MEM_CAP, current_mem, index);
};
SimulateResult &SimulateResult::operator=(const SimulateResult &other) {
if (this != &other) {
end_time = other.end_time;
oom = other.oom;
mem_cost = other.mem_cost;
index = other.index;
rpc_exe_list = other.rpc_exe_list;
}
return *this;
};
bool isPresent(MinEndTimeQueue &q, SimulateResult element) {
// Scan a copy of the underlying heap; the caller's queue is untouched,
// so the popped elements do not need to be pushed back.
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> pq =
q.getQueue();
while (!pq.empty()) {
if (pq.top().end_time == element.end_time) { return true; }
pq.pop();
}
return false;
}
void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1) {
// Get the underlying priority queues
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> pq1 =
q1.getQueue();
// Insert all elements from q1 into the merged queue
while (!pq1.empty()) {
if (!isPresent(target, pq1.top())) target.insert(pq1.top());
pq1.pop();
}
}
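For readers who want the gist of the simulate() routine above without tracing the C++: it performs a list-scheduling pass over the RPC-instance dependency graph, where each instance starts once all of its parents have finished and the device meshes it overlaps with are free, and the final makespan is multiplied by an OOM penalty when the grouped memory cost exceeds the soft GPU cap. Below is a rough Python analogue for illustration only; the names and data layout are hypothetical and not part of the codebase.

import heapq

def simulate_makespan(tasks, overlaps, over_mem_cap=False, oom_penalty=3):
    """Illustrative analogue of the C++ simulate() loop above.

    tasks: name -> {"cost": int, "parents": [names], "children": [names], "mesh": str}
    overlaps: mesh -> set of meshes sharing devices with it (including itself)
    """
    ready_time = {name: 0 for name in tasks}
    parents_done = {name: 0 for name in tasks}
    mesh_busy_until = {}
    heap = [(0, name) for name, t in tasks.items() if not t["parents"]]
    heapq.heapify(heap)
    makespan = 0
    while heap:
        _, name = heapq.heappop(heap)
        task = tasks[name]
        # Start when dependencies are met and the overlapping device meshes are free.
        busy = max((mesh_busy_until.get(m, 0) for m in overlaps[task["mesh"]]), default=0)
        start = max(ready_time[name], busy)
        end = start + task["cost"]
        makespan = max(makespan, end)
        for m in overlaps[task["mesh"]]:
            mesh_busy_until[m] = max(mesh_busy_until.get(m, 0), end)
        for child in task["children"]:
            ready_time[child] = max(ready_time[child], end)
            parents_done[child] += 1
            if parents_done[child] == len(tasks[child]["parents"]):
                heapq.heappush(heap, (ready_time[child], child))
    return makespan * oom_penalty if over_mem_cap else makespan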

View File

@ -1,75 +0,0 @@
#ifndef SIMULATE_HPP
#define SIMULATE_HPP
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <queue>
#include <iostream>
class SimulateResult {
public:
uint64_t end_time;
bool oom;
uint64_t mem_cost;
std::vector<int> index;
std::vector<RPCExecution *> rpc_exe_list;
double used_time = 0;
SimulateResult();
SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost, std::vector<int> &index);
SimulateResult &operator=(const SimulateResult &other);
};
SimulateResult simulate(
std::vector<RPCInstance *> &graph, std::unordered_map<std::string, uint64_t> &cost_table,
std::unordered_map<std::string, uint64_t> &model_sizes,
std::unordered_map<std::string, RPC *> &rpc_table,
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
std::vector<std::string> &sorted_rpc_names, std::vector<int> &index);
// Comparator for priority queue
struct CompareEndTime {
bool operator()(SimulateResult const &r1, SimulateResult const &r2) {
// std::less ordering keeps the largest end_time at the top (a max-heap),
// so MinEndTimeQueue can evict its worst result when it is full
return r1.end_time < r2.end_time;
}
};
class MinEndTimeQueue {
public:
MinEndTimeQueue(int capacity) : k(capacity) {}
void insert(SimulateResult r) {
if (queue.size() < k) {
// std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() <<
// std::endl;
queue.push(r);
} else if (r.end_time < queue.top().end_time) {
// std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() <<
// std::endl;
queue.pop();
queue.push(r);
}
}
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> &getQueue() {
return queue;
}
private:
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> queue;
int k;
};
void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1);
class CompareReadyTime {
public:
bool operator()(RPCInstance *r1, RPCInstance *r2) { return r1->ready_time > r2->ready_time; }
};
#endif // SIMULATE_HPP

View File

@ -83,7 +83,7 @@ async def async_invoke_function(
url: str,
timeout: aiohttp.ClientTimeout,
payload: Dict[str, Any] = None,
max_retries: int = 100,
max_retries: int = 2,
initial_retry_interval: float = 0.5,
max_retry_interval: float = 10.0,
):
@ -137,7 +137,7 @@ async def async_invoke_function(
)
retries += 1
if retries > max_retries:
if retries >= max_retries:
return {
"uid": payload.get("uid", ""),
"success": False,
@ -189,6 +189,7 @@ async def batch_function_call_async(payload_list, url, timeout, concurrency=1500
data_list.append(data)
elapsed_times.append(elapsed)
if len(elapsed_times) > 0:
p50 = median(elapsed_times)
p90 = calculate_percentile(elapsed_times, 90)
p99 = calculate_percentile(elapsed_times, 99)

View File

@ -1,6 +1,7 @@
import ast
import faulthandler
import json
import os
import platform
# to run the solution files we're using a timing based approach
@ -38,6 +39,14 @@ def truncatefn(s, length=300):
return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]
def load_from_path(l):
outputs = []
for x in l:
with open(x) as f:
outputs.append(f.read())
return outputs
class CODE_TYPE(Enum):
call_based = 0
standard_input = 1
@ -105,6 +114,13 @@ def run_test(sample, test=None, debug=False, timeout=6):
which_type = CODE_TYPE.call_based
if in_outs:
assert "inputs" in in_outs
assert "outputs" in in_outs
if os.path.isfile(in_outs["inputs"][0]):
assert os.path.isfile(in_outs["outputs"][0])
in_outs["inputs"] = load_from_path(in_outs["inputs"])
in_outs["outputs"] = load_from_path(in_outs["outputs"])
if in_outs.get("fn_name", "") == "":
which_type = CODE_TYPE.standard_input # Standard input
method_name = None
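For context, the branch added above means run_test also accepts test cases whose "inputs"/"outputs" entries are file paths rather than inline strings; in that case they are read from disk via load_from_path before judging. A hypothetical payload (paths are placeholders):

in_outs = {
    "inputs": ["/data/testcases/problem0/input0.txt"],
    "outputs": ["/data/testcases/problem0/output0.txt"],
    "fn_name": "",  # empty string -> standard-input style judging
}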

View File

@ -59,9 +59,10 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
shell=True,
preexec_fn=os.setsid,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
try:
pro.wait(600)
pro.wait(200)
except Exception as e:
pass
try:

View File

@ -9,7 +9,7 @@ from functioncall.base.utils import construct_uid, load_jsonl, logger
SINGLE_CASE_EXEC_TIMEOUT = 6
TEST_CASE_BATCH_SIZE = 1
FUNCTIONCALL_TIMEOUT = 1000
FUNCTIONCALL_TIMEOUT = 100
def round_up_memory(memory):

View File

@ -0,0 +1,144 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 174656b2..33fe0a5f 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 8891115c..843a8a82 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -419,6 +421,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1938,6 +1941,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 82709b09..bfab3ce7 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -265,6 +267,9 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -294,6 +299,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -767,6 +776,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Briefly yield so the interrupt takes effect before the weight update proceeds
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -776,7 +792,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -849,6 +870,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):

View File

@ -0,0 +1,144 @@
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index 5390668c..db370d19 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
success: bool
+@dataclass
+class InterruptAllReqInput:
+ pass
+
+
+@dataclass
+class InterruptAllReqOutput:
+ num_interrupted_requests: int
+
+
@dataclass
class UpdateWeightFromDiskReqInput:
# The model path with the new weights
model_path: str
+ allow_interrupt: bool = False
# The format to load the weights
load_format: Optional[str] = None
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1178eec5..318dee33 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -427,6 +429,7 @@ class Scheduler(
# Init request dispatcher
self._request_dispatcher = TypeBasedDispatcher(
[
+ (InterruptAllReqInput, self.interrupt_all_requests),
(TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped),
@@ -1971,6 +1974,15 @@ class Scheduler(
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
+ for req in self.waiting_queue:
+ req.sampling_params.max_new_tokens = 0
+ for req in self.running_batch.reqs:
+ req.sampling_params.max_new_tokens = len(req.output_ids)
+ logger.info(f"Interrupt {num} requests.")
+ return InterruptAllReqOutput(num)
+
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
"""In-place update of the weights from disk."""
success, message = self.tp_worker.update_weights_from_disk(recv_req)
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index b646fae1..c668728b 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
+ InterruptAllReqInput,
+ InterruptAllReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -279,6 +281,9 @@ class TokenizerManager:
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
+ self.interrupt_requests_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -309,6 +314,10 @@ class TokenizerManager:
UpdateWeightFromDiskReqOutput,
self._handle_update_weights_from_disk_req_output,
),
+ (
+ InterruptAllReqOutput,
+ self.interrupt_requests_communicator.handle_recv,
+ ),
(
InitWeightsUpdateGroupReqOutput,
self.init_weights_update_group_communicator.handle_recv,
@@ -799,6 +808,13 @@ class TokenizerManager:
) -> Tuple[bool, str]:
self.auto_create_handle_loop()
+ if obj.allow_interrupt:
+ num_interrupted_requests = await self.interrupt_all_requests(
+ InterruptAllReqInput()
+ )
+ # Briefly yield so the interrupt takes effect before the weight update proceeds
+ await asyncio.sleep(0.1)
+
# default the load format to the server_args
if obj.load_format is None:
obj.load_format = self.server_args.load_format
@@ -808,7 +824,12 @@ class TokenizerManager:
# Hold the lock if it is not async. This means that weight sync
# cannot run while requests are in progress.
async with self.model_update_lock.writer_lock:
- return await self._wait_for_model_update_from_disk(obj)
+ success, message, n_paused = (
+ await self._wait_for_model_update_from_disk(obj)
+ )
+ if obj.allow_interrupt:
+ return success, message, num_interrupted_requests
+ return success, message, n_paused
async def _wait_for_model_update_from_disk(
self, obj: UpdateWeightFromDiskReqInput
@@ -881,6 +902,18 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
+ async def interrupt_all_requests(
+ self,
+ obj: InterruptAllReqInput,
+ request: Optional[fastapi.Request] = None,
+ ) -> Tuple[bool, str]:
+ self.auto_create_handle_loop()
+ result = await self.interrupt_requests_communicator(obj)
+ if self.server_args.dp_size == 1:
+ return result[0].num_interrupted_requests
+ else:
+ return [r.num_interrupted_requests for r in result]
+
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
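Both patches above add the same capability to two SGLang versions: setting allow_interrupt on UpdateWeightFromDiskReqInput interrupts all queued and running requests before the weights are reloaded. A hedged sketch of client-side usage, assuming the server exposes SGLang's standard /update_weights_from_disk endpoint and forwards extra JSON fields into that dataclass (the address and checkpoint path below are placeholders):

import requests

resp = requests.post(
    "http://localhost:30000/update_weights_from_disk",  # placeholder server address
    json={
        "model_path": "/path/to/new/checkpoint",  # placeholder checkpoint path
        "allow_interrupt": True,  # interrupt in-flight generation before reloading
    },
    timeout=300,
)
print(resp.json())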

View File

@ -97,8 +97,7 @@ class PromptOnlyDatasetConfig:
class ModelFamily:
"""Identifier for HuggingFace model types (e.g., llama, gpt2).
Used for model registration and allocation. The size parameter is specifically
relevant for the 'search' allocation mode.
Used for model registration and allocation.
"""
_class: str = field(
@ -107,12 +106,6 @@ class ModelFamily:
"`realhf/api/from_hf` for supported models.",
}
)
size: int = field(
default=0,
metadata={
"help": "Model size parameter. Only used by 'search' allocation mode, ignored otherwise",
},
)
is_critic: bool = field(
default=False,
metadata={
@ -121,8 +114,8 @@ class ModelFamily:
)
def __repr__(self):
"""Returns formatted string representation: '{class}-{size}[-critic]'."""
s = f"{self._class}-{self.size}"
"""Returns formatted string representation: '{class}[-critic]'."""
s = f"{self._class}"
if self.is_critic:
s += "-critic"
return s
@ -136,7 +129,7 @@ class ParallelismConfig:
Sequence parallelism is only used in combination with tensor-model parallelism.
"""
model_parallel_size: int = field(
tensor_parallel_size: int = field(
default=1, metadata={"help": "Size of tensor-model parallelism"}
)
pipeline_parallel_size: int = field(
@ -155,7 +148,7 @@ class ParallelismConfig:
def __str__(self):
"""Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'."""
return (
f"Parallel(mp={self.model_parallel_size},"
f"Parallel(mp={self.tensor_parallel_size},"
f"pp={self.pipeline_parallel_size},"
f"dp={self.data_parallel_size})"
)
@ -168,7 +161,7 @@ class ParallelismConfig:
Implemented as static method to avoid OmegaConf compatibility issues.
"""
return (
(this.model_parallel_size == other.model_parallel_size)
(this.tensor_parallel_size == other.tensor_parallel_size)
and (this.pipeline_parallel_size == other.pipeline_parallel_size)
and (this.data_parallel_size == other.data_parallel_size)
)
@ -186,7 +179,7 @@ class OptimizerConfig:
default="adam",
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
)
lr: float = field(default=1e-5, metadata={"help": "Learning rate"})
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
@ -198,14 +191,14 @@ class OptimizerConfig:
},
)
lr_scheduler_type: str = field(
default="cosine",
default="constant",
metadata={
"help": "Learning rate scheduler type",
"choices": ["linear", "cosine", "constant"],
},
)
warmup_steps_proportion: float = field(
default=0.02,
default=0.001,
metadata={
"help": "Proportion of training steps for warmup",
},
@ -237,6 +230,7 @@ class vLLMConfig:
"""
max_num_seqs: int = 256
dtype: str = "float16"
kv_cache_type: str = "auto"
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = True
@ -278,7 +272,6 @@ class SGLangConfig:
enable_nccl_nvls: bool = False
disable_outlines_disk_cache: bool = False
disable_custom_all_reduce: bool = False
disable_mla: bool = False
disable_overlap_schedule: bool = False
enable_mixed_chunk: bool = False
enable_dp_attention: bool = False
@ -296,7 +289,7 @@ class SGLangConfig:
enable_memory_saver: bool = False
allow_auto_truncate: bool = False
# NOTE: to avoid the illegal memory access error
attention_backend: Optional[str] = "triton"
attention_backend: Optional[str] = "flashinfer"
sampling_backend: Optional[str] = None
context_length: Optional[int] = 32768
mem_fraction_static: Optional[float] = 0.9
@ -309,15 +302,19 @@ class SGLangConfig:
schedule_conservativeness: float = 1.0
cpu_offload_gb: int = 0
hybrid_train: bool = False
dtype: str = "float16"
kv_cache_dtype: str = "auto"
# logging
log_level: str = "info"
log_level: str = "warning"
log_level_http: Optional[str] = "warning"
log_requests: bool = False
log_requests_level: int = 0
show_time_cost: bool = False
enable_metrics: bool = True # Exports Prometheus-like metrics
decode_log_interval: int = 1000 # How often (in tokens) to log decode progress.
# The interval (in decoding iterations) to log throughput
# and update prometheus metrics
decode_log_interval: int = 1
# Use staticmethod to make OmegaConf happy.
@staticmethod
@ -327,6 +324,7 @@ class SGLangConfig:
tp_size,
server_index,
base_gpu_id,
dist_init_addr: Optional[str] = None,
):
from realhf.base import constants, network, pkg_version, seeding
from realhf.experiments.common.utils import asdict as conf_as_dict
@ -345,7 +343,6 @@ class SGLangConfig:
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
kv_cache_dtype="auto",
device="cuda",
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}",
is_embedding=False,
@ -365,6 +362,7 @@ class SGLangConfig:
ep_size=1, # TODO: check
nnodes=1,
node_rank=0,
dist_init_addr=dist_init_addr,
**args,
)
@ -385,6 +383,10 @@ class SGLangConfig:
if v is True:
flags.append(f"--{k.replace('_','-')} ")
continue
if isinstance(v, list):
values = " ".join(map(str, v))
flags.append(f"--{k.replace('_','-')} {values}")
continue
flags.append(f"--{k.replace('_','-')} {v}")
flags = " ".join(flags)
return f"python3 -m sglang.launch_server {flags}"
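As a concrete illustration of the conversion above (argument values are hypothetical): a boolean flag is emitted bare, a list is joined with spaces, and everything else is passed as key and value.

args = {"enable_metrics": True, "cuda_graph_bs": [1, 2, 4], "mem_fraction_static": 0.9}
flags = []
for k, v in args.items():
    if v is True:
        flags.append(f"--{k.replace('_', '-')}")
    elif isinstance(v, list):
        flags.append(f"--{k.replace('_', '-')} " + " ".join(map(str, v)))
    else:
        flags.append(f"--{k.replace('_', '-')} {v}")
print("python3 -m sglang.launch_server " + " ".join(flags))
# -> python3 -m sglang.launch_server --enable-metrics --cuda-graph-bs 1 2 4 --mem-fraction-static 0.9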
@ -444,7 +446,7 @@ class ModelTrainEvalConfig:
# Model Architecture Configuration
type: ModelFamily = field(
default=ModelFamily("llama", 7, False),
default=ModelFamily("llama", False),
metadata={"help": "Model family specification"},
)
path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
@ -679,13 +681,13 @@ class PPOHyperparameters:
value_norm_eps: float = field(
default=1e-5, metadata={"help": "Epsilon term for numerical stability"}
)
# Experimental Features
recompute_logprob: bool = field(
default=False,
metadata={
"help": "Recompute log probabilities after generation. Used mainly for debugging purposes"
},
metadata={"help": "Recompute logp and replace the logp returned by inference."},
)
use_decoupled_loss: bool = field(
default=False,
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
)
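The two fields above are related: use_decoupled_loss needs recompute_logprob because the decoupled objective clips around a recomputed "proximal" log-probability instead of the possibly stale behavior log-probability returned by the inference server. The following is only a rough, generic sketch of such a loss, meant to show why the proximal logp is required; the exact formulation implemented in this repository may differ.

import torch

def decoupled_ppo_loss_sketch(new_logp, prox_logp, behav_logp, advantages, eps=0.2):
    # Clip the update around the proximal policy, not the stale behavior policy.
    ratio = torch.exp(new_logp - prox_logp)
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
    # Importance weight correcting the behavior/proximal mismatch.
    behav_weight = torch.exp(prox_logp - behav_logp).detach()
    loss = -behav_weight * torch.min(ratio * advantages, clipped * advantages)
    return loss.mean()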
@ -772,6 +774,13 @@ class ExperimentSaveEvalControl:
"For benchmarking purposes only. None indicates normal training."
},
)
benchmark_n_seqs: Optional[int] = field(
default=None,
metadata={
"help": "Terminate training after consuming this number of samples. "
"For benchmarking purposes only. None indicates normal training."
},
)
@dataclass
@ -847,7 +856,7 @@ class BaseExperimentConfig:
Note:
- Recovery modes: auto, fault, resume, disabled
- Allocation modes: manual, search, heuristic, or pattern-based
- Allocation modes: manual, heuristic, or pattern-based
"""
experiment_name: str = field(
@ -919,13 +928,9 @@ class BaseExperimentConfig:
default="",
metadata={
"help": "GPU parallel strategy allocation mode. "
"Options: manual/search/heuristic or pattern-based."
"Options: manual/heuristic or pattern-based."
},
)
allocation_use_cache: bool = field(
default=False,
metadata={"help": "Use allocation search cache (search mode only)."},
)
n_nodes: int = field(
default=1, metadata={"help": "Number of nodes for experiment."}
)
@ -998,9 +1003,17 @@ class BaseExperimentConfig:
@dataclass
class AsyncRLOptions:
schedule_policy: str = field(
default="round_robin",
metadata={
"help": "The request schedule policy during generation. Available options: [round_robin]."
},
)
new_tokens_per_chunk: int = field(
default=1024,
metadata={"help": "The lenght of chunked generation."},
default=int(1e10),
metadata={
"help": "The length of chunked generation. Only valid if inference can't be interrupted."
},
)
max_head_offpolicyness: int = field(
default=0,
@ -1013,9 +1026,11 @@ class AsyncRLOptions:
"help": "Number of rollout workers. None defaults to train world size."
},
)
max_concurrent_rollouts: int = field(
default=1024,
metadata={"help": "Max concurrent rollout jobs in each worker."},
max_concurrent_rollouts: Optional[int] = field(
default=None,
metadata={
"help": "Max concurrent rollouts globally. Defaults to train batch size."
},
)
flush_request_timeout: int = field(
default=120,
@ -1225,6 +1240,12 @@ class PPOMATHExperimentOptions:
},
)
# testing only
no_training: bool = field(
default=False,
metadata={"help": "Run without training. Test-only."},
)
@dataclass
class MathCodeEvalOptions:

View File

@ -100,8 +100,8 @@ class ModelShardID:
:type model_name: ModelName
:param dp_rank: The data parallel rank.
:type dp_rank: int
:param mp_rank: The tensor-model parallel rank.
:type mp_rank: int
:param tp_rank: The tensor-model parallel rank.
:type tp_rank: int
:param pp_rank: The pipeline-model parallel rank.
:type pp_rank: int
:param topo: The 3D parallelism topology of this model.
@ -110,22 +110,22 @@ class ModelShardID:
model_name: ModelName
dp_rank: int
mp_rank: int
tp_rank: int
pp_rank: int
topo: topology.ProcessTopology
def __post_init__(self):
assert self.dp_rank >= 0 and self.mp_rank >= 0 and self.pp_rank >= 0
assert self.dp_rank >= 0 and self.tp_rank >= 0 and self.pp_rank >= 0
if "@" in self.model_name.role:
raise ValueError("model_name cannot contain @")
assert self.dp_rank < self.topo.get_dim("data")
assert self.mp_rank < self.topo.get_dim("model")
assert self.tp_rank < self.topo.get_dim("tensor")
assert self.pp_rank < self.topo.get_dim("pipe")
@property
def parallelism_rank(self):
return self.topo.get_rank(
data=self.dp_rank, model=self.mp_rank, pipe=self.pp_rank
data=self.dp_rank, tensor=self.tp_rank, pipe=self.pp_rank
)
@classmethod
@ -134,14 +134,14 @@ class ModelShardID:
return cls(
model_name=model_name,
dp_rank=c.data,
mp_rank=c.model,
tp_rank=c.tensor,
pp_rank=c.pipe,
topo=topo,
)
def __repr__(self):
n = cluster.spec.suffix_n_digits
return f"{self.model_name}@pp{self.pp_rank:0{n}d}@mp{self.mp_rank:0{n}d}@dp{self.dp_rank:0{n}d}"
return f"{self.model_name}@pp{self.pp_rank:0{n}d}@tp{self.tp_rank:0{n}d}@dp{self.dp_rank:0{n}d}"
def __hash__(self):
return hash(str(self))
@ -152,7 +152,7 @@ class ModelShardID:
return (
self.model_name == other.model_name
and self.dp_rank == other.dp_rank
and self.mp_rank == other.mp_rank
and self.tp_rank == other.tp_rank
and self.pp_rank == other.pp_rank
)
return False

View File

@ -547,6 +547,7 @@ class SequenceSample:
return [[seqlen] for seqlen in seqlens]
elif key in [
"packed_logprobs",
"prox_logp",
"logprobs",
"packed_ref_logprobs",
"ref_logprobs",

View File

@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
import networkx as nx
import realhf.base.logging as logging
from realhf.api.cli_args import ModelFamily
from realhf.api.core.config import (
ModelInterfaceAbstraction,
ModelInterfaceType,
@ -94,13 +93,6 @@ class MFCDef:
:type min_n_seqs_per_pass: int
:param log_return_value: Whether to log the return value of the interface implementation.
:type log_return_value: bool
:param model_type: The specification of the LLM, e.g., LLaMA-7B. Used by the profiler and
search engine to produce an optimal execution plan. Can be omitted if the search engine
is not used.
:type model_type: Optional[ModelFamily]
:param model_path: The path to the model file. Used to get the config for the search engine.
Can be omitted if the search engine is not used.
:type model_path: Optional[str]
"""
# The unique identifier of this model function call.
@ -126,10 +118,6 @@ class MFCDef:
min_n_seqs_per_pass: int | float = 1
log_return_value: bool = False
# Only used by search.
model_type: Optional[Any | ModelFamily] = None
model_path: Optional[str] = None
# Reserved dataclasses.fields. Should not be set by the user.
_G: nx.DiGraph = None
_pre_hooks: List[RPCHook] = dataclasses.field(default_factory=lambda: [])

View File

@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple
import aiohttp
import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data
import transformers
@ -24,6 +25,7 @@ from realhf.api.core.config import (
ModelWrapperAbstraction,
)
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample, load_hf_tokenizer
from realhf.base.datapack import flat2d
from realhf.base.recover import StepInfo
logger = logging.getLogger("model_api")
@ -37,15 +39,19 @@ class ZeroTotalLossWeightException(Exception):
class GenRespMeta:
qid: str
accepted: bool
n_tokens: int
@dataclasses.dataclass
class GenReqMeta:
## Meta info used to schedule the request. ##
qid: Hashable
prompt_len: int
group_size: int
new_token_budget: int
predicted_new_tokens: int | None
previous_server_url: str = ""
previous_version: int = -1
@dataclasses.dataclass
@ -120,6 +126,7 @@ class APIGenerateOutput:
@staticmethod
def concat(outputs: List["APIGenerateOutput"]):
assert len(set([o.qid for o in outputs])) == 1
return APIGenerateOutput(
qid=outputs[0].qid,
prompt_ids=outputs[0].prompt_ids,
@ -436,6 +443,8 @@ class ReaLModelConfig:
rotary_special_impl: Optional[str] = None
# for gemma
normalize_embed: bool = False
# for qwen3
qk_layernorm: bool = False
# for opt, it's 2
abs_position_embedding_offset: int = 0
do_layernorm_before: bool = True
@ -798,7 +807,7 @@ class ModelInterface(abc.ABC):
model: Model,
data: SequenceSample,
mb_spec: MicroBatchSpec,
) -> Dict:
) -> Dict | List[Dict]:
raise NotImplementedError()
# Mock methods for creating data and profiling an individual MFC.
@ -860,7 +869,17 @@ class NullInterface(ModelInterface):
def train_step(
self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec
) -> Dict:
) -> Dict | List[Dict]:
from realhf.base import constants
n_tokens = sum(flat2d(data.seqlens[data._get_split_key()]))
n_tokens = torch.tensor(
n_tokens, dtype=torch.long, device=constants.current_device()
)
dist.all_reduce(n_tokens, group=constants.data_parallel_group())
if constants.parallelism_rank() == 0:
logger.info(f"Number of tokens in NullInterface training: {int(n_tokens)}")
model.inc_version()
return {}
def save(self, model: Model, save_dir: str):

View File

@ -462,8 +462,8 @@ class ExperimentConfig:
)
self_topo = model_topos[rpc.model_name]
if (
self_topo.get_dim("model") % other_topo.get_dim("model") != 0
and other_topo.get_dim("model") % self_topo.get_dim("model") != 0
self_topo.get_dim("tensor") % other_topo.get_dim("tensor") != 0
and other_topo.get_dim("tensor") % self_topo.get_dim("tensor") != 0
):
raise ValueError(
"To synchronize parameters between two models, "

realhf/api/from_hf/qwen3.py (new file, 252 lines)
View File

@ -0,0 +1,252 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
from typing import *
from transformers.configuration_utils import PretrainedConfig
from realhf.api.core.model_api import ReaLModelConfig, register_hf_family
from realhf.base.testing import (
TESTING_MODEL_HEAD_DIM,
TESTING_MODEL_HIDDEN_SIZE,
TESTING_MODEL_INTERMEDIATE_SIZE,
TESTING_MODEL_N_HEADS,
TESTING_MODEL_N_LAYERS,
TESTING_MODEL_N_POSITIONS,
TESTING_MODEL_VOCAB_SIZE,
)
from .llama import (
convert_state_dict_llama,
llama_embedding_layer_names,
llama_output_head_param_name,
to_llama_state_dict,
)
class Qwen3Config(PretrainedConfig):
model_type = "qwen3"
keys_to_ignore_at_inference = ["past_key_values"]
# Default tensor parallel plan for base model `Qwen3`
base_model_tp_plan = {
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
head_dim=128,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
**kwargs,
):
from transformers.modeling_rope_utils import rope_config_validation
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.use_sliding_window = use_sliding_window
self.sliding_window = (
sliding_window # we check `use_sliding_window` in the modeling code
)
self.max_window_layers = max_window_layers
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
# Validate the correctness of rotary position embeddings parameters
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def convert_config_qwen3(
hf_config: Qwen3Config,
) -> ReaLModelConfig:
return ReaLModelConfig(
n_layers=hf_config.num_hidden_layers,
n_kv_heads=hf_config.num_key_value_heads,
hidden_dim=hf_config.hidden_size,
n_q_heads=hf_config.num_attention_heads,
head_dim=getattr(
hf_config,
"head_dim",
hf_config.hidden_size // hf_config.num_attention_heads,
),
intermediate_dim=hf_config.intermediate_size,
vocab_size=hf_config.vocab_size,
n_positions=hf_config.max_position_embeddings,
embd_pdrop=0.0,
attn_pdrop=(
hf_config.attention_dropout
if hasattr(hf_config, "attention_dropout")
else 0.1
),
layer_norm_epsilon=hf_config.rms_norm_eps,
activation_function=hf_config.hidden_act,
use_attention_bias=False,
use_attn_proj_bias=False,
scale_attn_by_inverse_layer_idx=False,
layer_norm_type="rms",
qk_layernorm=True,
mlp_type="llama",
apply_rotary=True,
rotary_base=hf_config.rope_theta,
rotary_interleaved=False,
tied_embedding=hf_config.tie_word_embeddings,
)
def convert_config_back_qwen3(
config: ReaLModelConfig,
) -> Qwen3Config:
return Qwen3Config(
vocab_size=config.vocab_size,
hidden_size=config.hidden_dim,
intermediate_size=config.intermediate_dim,
num_hidden_layers=config.n_layers,
num_key_value_heads=config.n_kv_heads,
num_attention_heads=config.n_q_heads,
head_dim=config.head_dim,
max_position_embeddings=config.n_positions,
rms_norm_eps=config.layer_norm_epsilon,
hidden_act=config.activation_function,
attention_dropout=config.attn_pdrop,
rope_theta=config.rotary_base,
architectures=["Qwen3ForCausalLM"], # ["Qwen3ForCausalLM"],
tie_word_embeddings=config.tied_embedding,
)
def qwen3_config_maker():
hf_config = Qwen3Config(
vocab_size=TESTING_MODEL_VOCAB_SIZE,
max_position_embeddings=TESTING_MODEL_N_POSITIONS,
hidden_size=TESTING_MODEL_HIDDEN_SIZE,
intermediate_size=TESTING_MODEL_INTERMEDIATE_SIZE,
num_hidden_layers=TESTING_MODEL_N_LAYERS,
num_attention_heads=TESTING_MODEL_N_HEADS,
head_dim=TESTING_MODEL_HEAD_DIM,
num_key_value_heads=8,
hidden_act="silu",
rms_norm_eps=1e-5,
)
return convert_config_qwen3(hf_config)
def convert_state_dict_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict:
llama_state_dict = convert_state_dict_llama(state_dict, config)
# model.layers.0.self_attn.k_norm.weight -> 1.attn.k_ln.weight
new_state_dict = {}
for k, v in llama_state_dict.items():
if "k_norm" in k:
k = k.replace("k_norm", "k_ln")
if "q_norm" in k:
k = k.replace("q_norm", "q_ln")
new_state_dict[k] = v
return new_state_dict
def convert_state_dict_back_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict:
new_sd = to_llama_state_dict(state_dict, config)
layer_indices = list(set([int(k.split(".")[0]) for k in state_dict.keys()]))
for i in layer_indices:
if i == 0 or i == config.n_layers + 1:
continue
new_sd[f"model.layers.{i - 1}.self_attn.k_norm.weight"] = state_dict[
f"{i}.attn.k_ln.weight"
]
new_sd[f"model.layers.{i - 1}.self_attn.q_norm.weight"] = state_dict[
f"{i}.attn.q_ln.weight"
]
return new_sd
def qwen3_transformer_block_param_name(config: ReaLModelConfig, idx: int) -> List[str]:
names = []
for k in ["weight", "bias"]:
names += [
f"model.layers.{idx}.input_layernorm.{k}",
f"model.layers.{idx}.mlp.down_proj.{k}",
f"model.layers.{idx}.mlp.gate_proj.{k}",
f"model.layers.{idx}.mlp.up_proj.{k}",
f"model.layers.{idx}.post_attention_layernorm.{k}",
f"model.layers.{idx}.self_attn.k_proj.{k}",
f"model.layers.{idx}.self_attn.o_proj.{k}",
f"model.layers.{idx}.self_attn.q_proj.{k}",
# f"model.layers.{idx}.self_attn.rotary_emb.inv_freq",
f"model.layers.{idx}.self_attn.v_proj.{k}",
]
if idx == config.n_layers - 1:
names += [f"model.norm.{k}"]
# Qwen3
if config.qk_layernorm:
names += [
f"model.layers.{idx}.self_attn.q_norm.weight",
f"model.layers.{idx}.self_attn.k_norm.weight",
]
return names
register_hf_family(
name="qwen3",
hf_cls_name="Qwen3ForCausalLM", # "Qwen3ForCausalLM"
config_from_hf_converter=convert_config_qwen3,
config_to_hf_converter=convert_config_back_qwen3,
sd_from_hf_converter=convert_state_dict_qwen3,
sd_to_hf_converter=convert_state_dict_back_qwen3,
embedding_param_names=llama_embedding_layer_names,
tblock_param_names=qwen3_transformer_block_param_name,
head_param_names=llama_output_head_param_name,
real_config_maker=qwen3_config_maker,
)
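With the registration above, a Qwen3 checkpoint can presumably be selected through the usual config fields. A hypothetical snippet (the checkpoint path is a placeholder, and note that ModelFamily no longer takes a size argument):

from realhf.api.cli_args import ModelFamily, ModelTrainEvalConfig

actor = ModelTrainEvalConfig(
    type=ModelFamily("qwen3", is_critic=False),
    path="/path/to/Qwen3-8B",  # placeholder checkpoint path
)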

View File

@ -224,18 +224,18 @@ def find_parallel_strategies(
) -> List[ParallelismConfig]:
n_gpus = np.sum(device_mesh.mapping)
res = []
for num_mp in [1, 2, 4, 8]:
if n_gpus >= num_mp:
assert n_gpus % num_mp == 0
num_dp_pp = n_gpus // num_mp
for num_tp in [1, 2, 4, 8]:
if n_gpus >= num_tp:
assert n_gpus % num_tp == 0
num_dp_pp = n_gpus // num_tp
num_pp = 1
while num_pp <= num_dp_pp:
num_dp_mp = n_gpus // num_pp
num_dp_tp = n_gpus // num_pp
valid = (
num_dp_mp in [1, 2, 4, 8] or num_dp_mp % 8 == 0
num_dp_tp in [1, 2, 4, 8] or num_dp_tp % 8 == 0
) and num_dp_pp % num_pp == 0
if valid:
res.append(ParallelismConfig(num_pp, num_mp, num_dp_pp // num_pp))
res.append(ParallelismConfig(num_pp, num_tp, num_dp_pp // num_pp))
num_pp += 1
return res
@ -248,7 +248,7 @@ class RPCAllocation:
def __post_init__(self):
world_size = (
self.parallel.model_parallel_size
self.parallel.tensor_parallel_size
* self.parallel.pipeline_parallel_size
* self.parallel.data_parallel_size
)

View File

@ -8,12 +8,10 @@ import functools
import inspect
import json
import os
import pickle
import subprocess
from typing import Callable, Optional
from typing import Callable
import hydra
import omegaconf
import yaml
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, OmegaConf
@ -29,6 +27,9 @@ def kind_reminder(config_name, logger, args):
logger.info(
f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}"
)
logger.info(
f"Experiment configs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name, 'config.yaml')}"
)
logger.info(
f"Model checkpoints will be saved to {os.path.join(MODEL_SAVE_ROOT, args.experiment_name, args.trial_name)}"
)
@ -69,7 +70,7 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
logger = logging.getLogger("quickstart", "colored")
print_runtime_helper(OmegaConf.to_object(args))
# print_runtime_helper(OmegaConf.to_object(args))
exp_name = args.experiment_name
if args.trial_name == MISSING:
@ -80,6 +81,17 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
trial_name = args.trial_name
from realhf.apps.main import main_start, main_stop
config_save_path = os.path.join(
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
)
with open(config_save_path, "w") as f:
yaml.dump(
dataclasses.asdict(OmegaConf.to_object(args)),
f,
default_flow_style=False,
sort_keys=False,
)
kind_reminder(config_name, logger, args)
exp_fn = functools.partial(exp_cls, **args)

View File

@ -87,7 +87,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
job_group_id = str(uuid.uuid4())
logger.info(f"AReaL Version: {get_full_version_with_dirty_description()}")
logger.info(f"AReaL Job Group ID: {job_group_id}")
logger.info(f"AReaL Job Group Index: {recover_count}")
logger.info(f"AReaL Job Group Index (recover count): {recover_count}")
if recover_count == 0:
constants.set_experiment_trial_names(args.experiment_name, args.trial_name)
experiment = config_package.make_experiment(args.experiment_name)
@ -110,10 +110,6 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
assert (
args.recover_mode == "disabled"
), "Recover mode is not supported for local runs!"
# Use search cache for recover runs
force_allocation_use_cache = (
recover_count > 1 or args.recover_mode == "resume"
) and args.allocation_mode == "search"
# handle args
args.ignore_worker_error = (
args.ignore_worker_error and args.recover_mode == "disabled"
@ -174,12 +170,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
)
for k, v in BASE_ENVIRONS.items():
os.environ[k] = v
os.environ["REAL_IS_REMOTE"] = "0" if not force_allocation_use_cache else "1"
# setup experiments
if args.allocation_mode == "search":
experiment._search()
sched = sched_client.make(
mode=scheduler_mode(args.mode),
expr_name=expr_name,
@ -324,80 +316,6 @@ def main_find_config(args):
print(exp_name)
def main_profile_layers(args):
from realhf.api.cli_args import ModelFamily
_main_profile_layers(
ModelFamily(args.model_class, args.model_size, args.is_critic),
args.model_path,
)
def _main_profile_layers(model_family, model_path):
from realhf.api.cli_args import ModelFamily
from realhf.base.slurm_utils import check_slurm_availability
from realhf.base.testing import clear_name_resolve
expr_name = trial_name = "profile"
cmd = (
f"python3 -m realhf.apps.profile_layers --expr_name {expr_name} --trial_name {trial_name} "
f"--model_path {model_path} --model_name {model_family} "
)
if check_slurm_availability():
if not os.environ.get("CLUSTER_SPEC_PATH", ""):
raise ValueError(
"Environment variable CLUSTER_SPEC_PATH must be set for slurm mode! "
"See example/cluster_config.json for a template."
)
BASE_ENVIRONS = constants.get_env_vars(
REAL_MODE="slurm",
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
)
clear_name_resolve(expr_name, trial_name)
sched = sched_client.make(
mode="slurm", expr_name=expr_name, trial_name=trial_name
)
print(
f"Profiling {model_family} layers, model path {model_path}, " f"cmd {cmd}"
)
sched.submit_array(
worker_type="profile_layer",
cmd=cmd,
count=1,
cpu=64,
gpu=8,
gpu_type="tesla",
mem=500000,
env_vars=BASE_ENVIRONS,
container_image=config_package._LLM_GPU_IMAGE,
)
try:
sched.wait(timeout=None)
except (
KeyboardInterrupt,
sched_client.JobException,
TimeoutError,
) as e:
sched.stop_all()
raise e
else:
try:
print(
f"Profiling {model_family} layers, model path {model_path}, "
f"cmd {cmd}"
)
clear_name_resolve(expr_name, trial_name)
os.system(cmd)
except (
KeyboardInterrupt,
sched_client.JobException,
TimeoutError,
) as e:
raise e
def main():
parser = argparse.ArgumentParser(prog="ReaLHF")
subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
@ -482,7 +400,7 @@ def main():
type=str,
required=False,
default="pipe_model",
choices=["manual", "search", "heuristic", "pipe_model", "pipe_data"],
choices=["manual", "heuristic", "pipe_model", "pipe_data"],
help="Mode of GPU resource/model parallel strategy allocation.",
)
subparser.set_defaults(ignore_worker_error=False)
@ -514,15 +432,6 @@ def main():
subparser.add_argument("--regex", "-r", type=str, required=True)
subparser.set_defaults(func=main_find_config)
subparser = subparsers.add_parser(
"profile_layers", help="profile layers of a model."
)
subparser.add_argument("--model_class", type=str, required=True)
subparser.add_argument("--model_size", type=int, required=True)
subparser.add_argument("--is_critic", action="store_true")
subparser.add_argument("--model_path", type=str, required=True)
subparser.set_defaults(func=main_profile_layers)
args = parser.parse_args()
args.func(args)

View File

@ -1,91 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import argparse
import itertools
import time
import realhf.base.testing as testing
BATCH_SIZE_RANGE = [1, 2, 4, 8, 16, 32, 64, 128]
SEQ_LEN_RANGE = [128, 256, 512]
def profile_layer_func(
world_size,
model_path,
model_name,
warm_up_rounds,
profile_rounds,
batch_size_range,
seq_len_range,
use_sequence_parallel=False,
use_gradient_checkpointing=False,
):
# FIXME: use_sequence_parallel=True and use_gradient_checkpointing=True will cause bugs
import torch
import realhf.base.constants as constants
testing.init_global_constants(
1, world_size, 1, sequence_parallel=False, gradient_checkpointing=False
)
device = torch.device("cuda")
with constants.model_scope(testing.MODEL_NAME):
from realhf.search_engine.layers import make_profile_layers
profile_layers = make_profile_layers(device, model_path, model_name)
st = time.monotonic_ns()
for i in range(warm_up_rounds + profile_rounds):
for bs, seq_len in itertools.product(batch_size_range, seq_len_range):
profile_layers.fwd_gen(bs, seq_len)
profile_layers.fwd_bwd_opt(bs, seq_len)
if i < warm_up_rounds:
profile_layers.reset_stats()
profile_layers.make_dataframe_and_print()
profile_layers.dump_stats(world_size)
t = (time.monotonic_ns() - st) / int(1e9)
print(f"profile world size {world_size} cost {t:4f} seconds")
if __name__ == "__main__":
st = time.monotonic_ns()
parser = argparse.ArgumentParser(prog="profile_layers")
parser.add_argument(
"--model_path",
type=str,
required=True,
)
parser.add_argument("--expr_name", type=str, default="profile")
parser.add_argument("--trial_name", type=str, default="profile")
parser.add_argument("--model_name", type=str, default="Llama-2-70b")
parser.add_argument("--warm_up_rounds", type=int, default=1)
parser.add_argument("--profile_rounds", type=int, default=3)
# parser.add_argument("--use_sequence_parallel", action="store_true")
# parser.add_argument("--use_gradient_checkpointing", action="store_true")
args = parser.parse_args()
world_sizes = [1, 2, 4, 8]
for world_size in world_sizes:
testing.clear_name_resolve(args.expr_name, args.trial_name)
mp = testing.LocalMultiProcessTest(
world_size,
profile_layer_func,
world_size,
args.model_path,
args.model_name,
args.warm_up_rounds,
args.profile_rounds,
BATCH_SIZE_RANGE,
SEQ_LEN_RANGE,
expr_name=args.expr_name,
trial_name=args.trial_name,
)
mp.launch()
t = (time.monotonic_ns() - st) / int(1e9)
print(f"profile model {args.model_name} time cost {t:4f} seconds")

View File

@ -72,6 +72,7 @@ MODEL_SAVE_ROOT = f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}"
LOG_ROOT = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}"
RECOVER_ROOT = f"{cluster_spec.fileroot}/recover/{getpass.getuser()}"
SLURM_LOCK_FILE_NAME = f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock"
PORT_LOCK_FILE_ROOT = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports"
PYTORCH_KERNEL_CACHE_PATH = (
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
)
@ -120,6 +121,9 @@ BASE_ENVIRONS = {
"REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv(
"REAL_GPU_MEMORY_KILL_THRESHOLD", "0.95"
),
"LC_ALL": "C",
"LANG": "C",
"NCCL_DEBUG": "WARN",
}
# Set PPU-specific environment variables for stable training.
@ -146,7 +150,6 @@ elif cluster_spec.name == "na132":
"NCCL_IB_SL": "5",
"NCCL_IB_TC": "136",
"NCCL_IB_HCA": "mlx5_bond",
"NCCL_DEBUG": "WARN",
"NCCL_IB_QPS_PER_CONNECTION": "8",
"NCCL_SET_THREAD_NAME": "1",
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
@ -165,6 +168,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
os.makedirs(PORT_LOCK_FILE_ROOT, exist_ok=True)
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
# _model_name will be changed in the model_scope context manager
@ -186,16 +190,12 @@ _self_group = None
_rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
_global_memory_buffer: GlobalMemoryBuffer = GlobalMemoryBuffer()
# used only in scripts and tests
_fake_mp_world_size = None
_fake_mp_rank = None
# TODO: As in Megatron, we can set NCCL group options. Is it necessary?
def reset_run():
global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer, _fake_mp_world_size, _fake_mp_rank
global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer
_model_name = None
_grids = {}
_pgroups = {}
@ -203,8 +203,6 @@ def reset_run():
_self_group = None
_rank_mapping = {}
_global_memory_buffer = GlobalMemoryBuffer()
_fake_mp_world_size = None
_fake_mp_rank = None
@contextlib.contextmanager
@ -284,7 +282,7 @@ def set_rank_mapping(
else:
msid2mwid = {k: v for k, v in msid2mwid.items() if k.model_name == model_name}
_rank_mapping[model_name] = {
topo.get_rank(data=s.dp_rank, model=s.mp_rank, pipe=s.pp_rank): mw_id
topo.get_rank(data=s.dp_rank, tensor=s.tp_rank, pipe=s.pp_rank): mw_id
for s, mw_id in msid2mwid.items()
}
@ -408,7 +406,7 @@ def parallelism_group_ranks():
def parallelism_group_size() -> int:
"""The 3D parallelism group size of a specific model, normally dp_size *
pp_size * mp_size."""
pp_size * tp_size."""
import torch.distributed as dist
return dist.get_world_size(group=parallelism_group())
@ -470,37 +468,25 @@ def prev_pipe_stage():
def is_dp_head():
return is_last_pipe_stage() and model_parallel_rank() == 0
return is_last_pipe_stage() and tensor_parallel_rank() == 0
def model_parallel_rank() -> int:
def tensor_parallel_rank() -> int:
"""Return the rank inside the tensor parallelism group."""
try:
return grid().get_tensor_model_parallel_rank()
except RuntimeError as e: # used only in scripts and tests
if _fake_mp_rank is not None:
return _fake_mp_rank
else:
raise e
def model_parallel_world_size() -> int:
def tensor_parallel_world_size() -> int:
"""Return the world size of the tensor parallelism group."""
try:
return grid().get_tensor_model_parallel_world_size()
except RuntimeError as e: # used only in scripts and tests
if _fake_mp_world_size is not None:
return _fake_mp_world_size
else:
raise e
def model_parallel_group():
def tensor_parallel_group():
"""Return the NCCL tensor parallelism process group."""
return grid().get_tensor_model_parallel_group()
def model_parallel_cpu_group():
def tensor_parallel_cpu_group():
"""Return the GLOO tensor parallelism process group."""
return grid().get_tensor_model_parallel_cpu_group()
@ -536,26 +522,6 @@ def data_parallel_group():
return grid().get_data_parallel_group()
def set_fake_mp_world_size(world_size):
# used only in scripts and tests
global _fake_mp_world_size
_fake_mp_world_size = world_size
def set_fake_mp_rank(rank):
# used only in scripts and tests
global _fake_mp_rank
_fake_mp_rank = rank
def set_fake_grid(model_name, rank, topo):
# used only in scripts and tests
from realhf.base.topology import FakeGrid
global _grids
_grids[model_name] = FakeGrid(rank=rank, topo=topo)
def get_global_memory_buffer():
global _global_memory_buffer
assert _global_memory_buffer is not None, "global memory buffer is not set"

View File

@ -62,7 +62,7 @@ def reveal_pg_identity(expr_name, trial_name, worker_index):
master_group_name = names.distributed_peer(
expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME
)
name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=300)
name_resolve.add_subentry(master_group_name, str(worker_index))
def isolate_cuda_device(
@ -100,12 +100,10 @@ def isolate_cuda_device(
name_resolve_identifier,
),
rank,
keepalive_ttl=60,
)
name_resolve.add_subentry(
names.distributed_peer(experiment_name, trial_name, name_resolve_identifier),
rank,
keepalive_ttl=30,
)
logger.debug(
f"Worker type {worker_type} rank {rank} waiting for peers, world size {world_size}..."

View File

@ -141,9 +141,18 @@ def getLogger(
return logging.getLogger(name)
_LATEST_WANDB_STEP = 0
def log_wandb_tensorboard(data, step=None, summary_writer=None):
import wandb
global _LATEST_WANDB_STEP
if step is None:
step = _LATEST_WANDB_STEP
else:
_LATEST_WANDB_STEP = max(_LATEST_WANDB_STEP, step)
wandb.log(data, step=step)
if summary_writer is not None:
for key, val in data.items():

View File

@ -618,7 +618,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
self._to_delete = set()
logger.info(f"Connected to etcd3 at {self._host}:{self._port}")
logger.debug(f"Connected to etcd3 at {self._host}:{self._port}")
def __del__(self):
"""Clean up resources when the object is deleted."""
@ -945,12 +945,13 @@ def make_repository(type_="nfs", **kwargs):
# DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs"
DEFAULT_REPOSITORY_TYPE = "nfs"
if (
etcd3 is not None
and cluster.spec.name in ["wa180", "na132", "su18"]
and os.getenv("REAL_ETCD_ADDR", "")
):
if etcd3 is not None and os.getenv("REAL_ETCD_ADDR", ""):
DEFAULT_REPOSITORY_TYPE = "etcd3"
if os.getenv("REAL_ETCD_ADDR", "") and etcd3 is None:
logger.warning(
f"Detected REAL_ETCD_ADDR but etcd3 client is not available. "
"Please run `pip install -r requirements.txt` if you want to use etcd name resolve."
)
DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE)
add = DEFAULT_REPOSITORY.add
add_subentry = DEFAULT_REPOSITORY.add_subentry

View File

@ -93,5 +93,9 @@ def gen_servers(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers"
def used_ports(experiment_name, trial_name, host_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/{host_name}/"
def gen_server_manager(experiment_name, trial_name):
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager"

View File

@ -2,31 +2,16 @@
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import fcntl
import os
import socket
import time
from contextlib import closing
from functools import wraps
from realhf.base import constants, logging, name_resolve, names
def find_free_port(low=1, high=65536, exclude_ports=None):
"""Find a free port within the specified range, excluding certain ports."""
if exclude_ports is None:
exclude_ports = set()
while True:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
if low <= port <= high and port not in exclude_ports:
return port
def find_multiple_free_ports(count, low=1, high=65536):
"""Find multiple mutually exclusive free ports."""
free_ports = set()
for _ in range(count):
port = find_free_port(low, high, exclude_ports=free_ports)
free_ports.add(port)
return list(free_ports)
logger = logging.getLogger(__name__)
def gethostname():
@ -35,3 +20,54 @@ def gethostname():
def gethostip():
return socket.gethostbyname(socket.gethostname())
def find_free_port(
low=1, high=65536, exclude_ports=None, experiment_name="port", trial_name="port"
):
"""Find a free port within the specified range, excluding certain ports."""
ports_name = names.used_ports(experiment_name, trial_name, gethostip())
used_ports = list(map(int, name_resolve.get_subtree(ports_name)))
if exclude_ports is None:
exclude_ports = set(used_ports)
else:
exclude_ports = exclude_ports.union(set(used_ports))
free_port = None
lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip())
while True:
with open(lockfile, "w") as fd:
# This will block until lock is acquired
fcntl.flock(fd, fcntl.LOCK_EX)
try:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
port = s.getsockname()[1]
if low <= port <= high and port not in exclude_ports:
name_resolve.add_subentry(ports_name, str(port))
logger.info(f"Found free port {port}")
free_port = port
break
finally:
fcntl.flock(fd, fcntl.LOCK_UN)
time.sleep(0.05)
return free_port
def find_multiple_free_ports(
count, low=1, high=65536, experiment_name="port", trial_name="port"
):
"""Find multiple mutually exclusive free ports."""
free_ports = set()
for _ in range(count):
port = find_free_port(
low=low,
high=high,
exclude_ports=free_ports,
experiment_name=experiment_name,
trial_name=trial_name,
)
free_ports.add(port)
return list(free_ports)

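find_free_port now serializes allocation across processes on the same host with an exclusive file lock and records chosen ports under name_resolve, so concurrent workers do not hand out the same port. A minimal Linux-only sketch of the lock-then-bind loop, without the name_resolve bookkeeping (the lock-file path and port range are illustrative):

    import fcntl
    import socket
    import time
    from contextlib import closing

    def reserve_port(lockfile="/tmp/port.lock", low=10000, high=60000, taken=()):
        taken = set(taken)
        while True:
            with open(lockfile, "w") as fd:
                fcntl.flock(fd, fcntl.LOCK_EX)  # one allocator per host at a time
                try:
                    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                        s.bind(("", 0))  # let the OS pick an unused port
                        port = s.getsockname()[1]
                        if low <= port <= high and port not in taken:
                            return port  # callers should record it before reusing
                finally:
                    fcntl.flock(fd, fcntl.LOCK_UN)
            time.sleep(0.05)  # back off before retrying an out-of-range port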

@ -171,6 +171,9 @@ class DistributedStatsTracker:
else:
raise ValueError(f"Unknown reduce type: {reduce_type}")
keys_to_pop = [k for k, v in result.items() if v is None]
for k in keys_to_pop:
result.pop(k)
return result
def _sum_of(self, key, reduce_group):
@ -209,7 +212,7 @@ class DistributedStatsTracker:
dist.all_reduce(x, group=reduce_group)
dist.all_reduce(d, group=reduce_group)
if d == 0:
return 0
return None
return x / d
def _min_of(self, key, reduce_group):
@ -224,7 +227,7 @@ class DistributedStatsTracker:
if reduce_group is not None:
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN)
if torch.isinf(x):
return float("nan")
return None
return float(x)
def _max_of(self, key, reduce_group):
@ -239,7 +242,7 @@ class DistributedStatsTracker:
if reduce_group is not None:
dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX)
if torch.isinf(x):
return float("nan")
return None
return float(x)

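In the stats tracker, reductions that have no valid data now return None (a zero denominator for averages, an infinite min/max) and those keys are dropped from the exported dict instead of being logged as 0 or NaN. The filtering step amounts to:

    def drop_empty_reductions(result: dict) -> dict:
        """Remove keys whose reduction produced no valid data."""
        return {k: v for k, v in result.items() if v is not None}

    assert drop_empty_reductions({"loss": 0.5, "reward": None}) == {"loss": 0.5}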

@ -22,9 +22,9 @@ import torch.utils.data
from realhf.api.core.data_api import SequenceSample
from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology
from realhf.base.topology import (
DataPipeModelParallelTopology,
DataPipeTensorParallelTopology,
ParallelGrid,
PipeDataModelParallelTopology,
PipeDataTensorParallelTopology,
)
logger = logging.getLogger("testing")
@ -106,9 +106,6 @@ class StandaloneTestingProcess(mp.Process):
self.expr_name, self.trial_name, self.rank, backend=self.dist_backend
)
# setup some useful constants
constants.set_experiment_trial_names(self.expr_name, self.trial_name)
# misc setup
if constants.use_cuda():
pynvml.nvmlInit()
@ -207,7 +204,7 @@ class LocalMultiProcessTest:
def init_global_constants(
num_dp=1,
num_mp=1,
num_tp=1,
num_pp=1,
topo=None,
model_name=None,
@ -227,9 +224,9 @@ def init_global_constants(
if topo is None:
if is_train:
topo = PipeDataModelParallelTopology(
topo = PipeDataTensorParallelTopology(
num_dp=num_dp,
num_mp=num_mp,
num_tp=num_tp,
num_pp=num_pp,
sequence_parallel=sequence_parallel,
gradient_checkpointing=gradient_checkpointing,
@ -237,13 +234,13 @@ def init_global_constants(
max_prompt_len=max_prompt_len,
)
else:
topo = DataPipeModelParallelTopology(
topo = DataPipeTensorParallelTopology(
num_dp=num_dp,
num_mp=num_mp,
num_tp=num_tp,
num_pp=num_pp,
sequence_parallel=sequence_parallel,
)
ws = num_dp * num_mp * num_pp
ws = num_dp * num_tp * num_pp
else:
ws = topo.world_size()


@ -65,22 +65,22 @@ def decompose_to_three_factors(n: int) -> List[Tuple[int, int, int]]:
return factors
class PipeDataModelrocessCoord(NamedTuple):
class PipeDataTensorProcessCoord(NamedTuple):
pipe: int
data: int
model: int
tensor: int
class DataPipeModelrocessCoord(NamedTuple):
class DataPipeTensorProcessCoord(NamedTuple):
data: int
pipe: int
model: int
tensor: int
# Explicitly define these class to allow pickling.
PROCESS_COORD_REGISTRY = {
"pipe#data#model": PipeDataModelrocessCoord,
"data#pipe#model": DataPipeModelrocessCoord,
"pipe#data#tensor": PipeDataTensorProcessCoord,
"data#pipe#tensor": DataPipeTensorProcessCoord,
}
@ -327,20 +327,20 @@ def _prime_factors(N):
return primes
class PipeDataModelParallelTopology(ProcessTopology):
class PipeDataTensorParallelTopology(ProcessTopology):
"""A topology for hybrid pipeline, model, and data parallelism."""
def __init__(
self,
num_pp: int,
num_mp: int,
num_tp: int,
num_dp: int,
sequence_parallel: bool,
gradient_checkpointing: bool,
gradient_accumulation_fusion: bool,
max_prompt_len: Optional[int] = None,
):
super().__init__(axes=["pipe", "data", "model"], dims=[num_pp, num_dp, num_mp])
super().__init__(axes=["pipe", "data", "tensor"], dims=[num_pp, num_dp, num_tp])
self.sequence_parallel = sequence_parallel
self.gradient_checkpointing = gradient_checkpointing
@ -348,7 +348,7 @@ class PipeDataModelParallelTopology(ProcessTopology):
self.gradient_accumulation_fusion = gradient_accumulation_fusion
class DataPipeModelParallelTopology(ProcessTopology):
class DataPipeTensorParallelTopology(ProcessTopology):
"""A topology for hybrid data, pipeline, and tensor parallelism.
Note that DP is the outermost dimension. Used for inference only.
@ -357,12 +357,12 @@ class DataPipeModelParallelTopology(ProcessTopology):
def __init__(
self,
num_pp: int,
num_mp: int,
num_tp: int,
num_dp: int,
sequence_parallel: bool,
max_prompt_len: Optional[int] = None,
):
super().__init__(axes=["data", "pipe", "model"], dims=[num_dp, num_pp, num_mp])
super().__init__(axes=["data", "pipe", "tensor"], dims=[num_dp, num_pp, num_tp])
self.sequence_parallel = sequence_parallel
self.max_prompt_len = max_prompt_len
@ -414,7 +414,7 @@ class ParallelGrid:
self.data_parallel_size = max(self._topo.get_dim("data"), 1)
self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1)
self.model_parallel_size = max(self._topo.get_dim("model"), 1)
self.model_parallel_size = max(self._topo.get_dim("tensor"), 1)
self.slice_parallel_size = self.model_parallel_size
assert self._is_grid_valid(), (
"Invalid Grid",
@ -520,7 +520,7 @@ class ParallelGrid:
self.slice_group = None
self.slice_proc_group = self.slice_proc_group_gloo = None
self.mp_group = []
self.model_groups = self._topo.get_axis_comm_lists("model")
self.model_groups = self._topo.get_axis_comm_lists("tensor")
for g in self.model_groups:
proc_group = new_or_get_group(ranks=[rank_mapping[x] for x in g])
# NOTE: We must create the GLOO group for vLLM's usage.
@ -634,8 +634,8 @@ class ParallelGrid:
def get_tensor_model_parallel_rank(self):
if self.global_rank == -1:
return -1
if "model" in self._topo.get_axis_names():
return self._topo.get_coord(rank=self.global_rank).model
if "tensor" in self._topo.get_axis_names():
return self._topo.get_coord(rank=self.global_rank).tensor
else:
return 0
@ -662,12 +662,12 @@ class FakeGrid:
self.data_parallel_size = max(self._topo.get_dim("data"), 1)
self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1)
self.model_parallel_size = max(self._topo.get_dim("model"), 1)
self.model_parallel_size = max(self._topo.get_dim("tensor"), 1)
self.coord: ProcessCoord = self._topo.get_coord(self.rank)
self.coord = self._topo.get_coord(self.rank)
self.dp_id = self.coord.data
self.pp_id = self.coord.pipe
self.mp_id = self.coord.model
self.mp_id = self.coord.tensor
self.world_size = (
self.data_parallel_size * self.pipe_parallel_size * self.model_parallel_size

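The renames above replace the "model" axis with "tensor" throughout the topology helpers and grids. A usage sketch, under the assumption that realhf is importable and the constructor signatures are exactly as shown in the diff:

    from realhf.base.topology import (
        DataPipeTensorParallelTopology,
        PipeDataTensorParallelTopology,
    )

    train_topo = PipeDataTensorParallelTopology(
        num_pp=1,
        num_tp=2,
        num_dp=4,
        sequence_parallel=True,
        gradient_checkpointing=True,
        gradient_accumulation_fusion=False,
    )
    gen_topo = DataPipeTensorParallelTopology(
        num_pp=1, num_tp=2, num_dp=4, sequence_parallel=False
    )
    assert train_topo.get_dim("tensor") == 2  # the axis is now "tensor", not "model"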

@ -6,7 +6,11 @@ from typing import Any, Dict, List, Tuple
import realhf.base.logging as logging
from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions
from realhf.api.core.config import AgentAbstraction, EnvServiceAbstraction
from realhf.api.core.config import (
AgentAbstraction,
EnvServiceAbstraction,
ModelInterfaceAbstraction,
)
from realhf.api.core.model_api import GenerationHyperparameters
from realhf.api.quickstart.entrypoint import register_quickstart_exp
from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig
@ -36,7 +40,7 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
@property
def env(self) -> EnvServiceAbstraction:
return EnvServiceAbstraction(
"math-single-step", args=dict(dataset_path=self.dataset.path)
"math-code-single-step", args=dict(dataset_path=self.dataset.path)
)
@property
@ -71,6 +75,11 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
rpcs["ref_inf"].output_keys = ("packed_ref_logprobs",)
if "rew_inf" in rpcs:
rpcs.pop("rew_inf")
if self.no_training:
rpcs["actor_train"].interface_impl = ModelInterfaceAbstraction("null")
rpcs["actor_gen"].interface_impl = ModelInterfaceAbstraction("null")
if "actor_inf" in rpcs:
rpcs["actor_inf"].interface_impl = ModelInterfaceAbstraction("null")
return rpcs
@property

realhf/experiments/async_exp/async_rl_exp.py (file mode changed: normal → executable)

@ -60,7 +60,6 @@ GEN_WORKER_DEFAULT_CAPACITY = 512
@dataclasses.dataclass
class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
@property
def generation_config(self) -> GenerationHyperparameters:
raise NotImplementedError()
@ -203,16 +202,17 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
"config_from_hf_converter"
](hf_config)
if (
model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size
model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0
) or (
model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0
model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0
):
raise ValueError(
f"The number of KV heads {model_config.n_kv_heads} or "
f"Q heads {model_config.n_q_heads} is not"
f" divisible by the configured TP size "
f"({rpc_alloc.parallel.model_parallel_size}). "
f"({rpc_alloc.parallel.tensor_parallel_size}). "
f"Please decrease TP size."
)
mapping = rpc_alloc.device_mesh.mapping
@ -250,7 +250,7 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
topo=topo,
dp_rank=topo.get_coord(shard_idx).data,
pp_rank=topo.get_coord(shard_idx).pipe,
mp_rank=topo.get_coord(shard_idx).model,
tp_rank=topo.get_coord(shard_idx).tensor,
),
model=model,
backend=backend,
@ -308,15 +308,18 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
model_name = gen_rpc_alloc.rpc.model_name
train_rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.is_train()]
assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
max_concurrent_rollouts = self.max_concurrent_rollouts
if max_concurrent_rollouts is None:
max_concurrent_rollouts = train_rpcs[0].n_seqs
return [
GserverManager(
model_name=model_name,
flush_request_timeout=self.flush_request_timeout,
n_servers=gen_world_size // gen_tp_size,
schedule_policy="round_robin",
schedule_policy=self.schedule_policy,
max_head_offpolicyness=self.max_head_offpolicyness,
train_batch_size=train_rpcs[0].n_seqs,
max_concurrent_rollouts=self.max_concurrent_rollouts,
max_concurrent_rollouts=max_concurrent_rollouts,
)
]

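The GserverManager setup now takes its scheduling policy from the experiment config and defaults max_concurrent_rollouts to the training batch size when it is left unset, so at most one training batch worth of rollouts is in flight by default. The defaulting rule, in isolation:

    def resolve_max_concurrent_rollouts(configured, train_batch_size):
        return train_batch_size if configured is None else configured

    assert resolve_max_concurrent_rollouts(None, 512) == 512
    assert resolve_max_concurrent_rollouts(64, 512) == 64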

@ -37,21 +37,21 @@ def default_parallel_config(n_gpus: int) -> List[Dict[str, Any]]:
x = [
{
"data_parallel_size": dp,
"model_parallel_size": mp,
"tensor_parallel_size": tp,
"pipeline_parallel_size": pp,
"use_sequence_parallel": mp > 1,
"use_sequence_parallel": tp > 1,
}
for dp, mp, pp in factors
for dp, tp, pp in factors
]
x += [
{
"data_parallel_size": dp,
"model_parallel_size": mp,
"tensor_parallel_size": tp,
"pipeline_parallel_size": pp,
"use_sequence_parallel": False,
}
for dp, mp, pp in factors
if mp > 1
for dp, tp, pp in factors
if tp > 1
]
return x
@ -122,7 +122,7 @@ class ProfileConfig(CommonExperimentConfig):
k
in [
"data_parallel_size",
"model_parallel_size",
"tensor_parallel_size",
"pipeline_parallel_size",
"use_sequence_parallel",
]
@ -130,7 +130,7 @@ class ProfileConfig(CommonExperimentConfig):
), pcfg.keys()
assert (self.n_nodes * self.n_gpus_per_node) == (
pcfg.get("data_parallel_size", 1)
* pcfg.get("model_parallel_size", 1)
* pcfg.get("tensor_parallel_size", 1)
* pcfg.get("pipeline_parallel_size", 1)
)
@ -246,8 +246,6 @@ class ProfileConfig(CommonExperimentConfig):
model_name="default",
input_keys=["packed_prompts"],
log_return_value=False,
model_type=self._tmp_model.type,
model_path=self._tmp_model.path,
balanced_dp=True,
)


@ -70,7 +70,7 @@ def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation):
mb_spec = rpc.mb_spec
dp_size = rpc_alloc.parallel.data_parallel_size
tp_size = rpc_alloc.parallel.model_parallel_size
tp_size = rpc_alloc.parallel.tensor_parallel_size
pp_size = rpc_alloc.parallel.pipeline_parallel_size
factor = 1


@ -44,7 +44,6 @@ from realhf.api.quickstart.device_mesh import (
)
from realhf.base.cluster import spec as cluster_spec
from realhf.experiments.common.check import (
check_is_realhf_native_model_interface,
check_valid_model_and_path,
check_valid_optimizer,
check_valid_parallel_batch_size,
@ -61,7 +60,6 @@ from realhf.experiments.common.utils import (
resolve_replica_ids,
resolve_rpc_hooks,
)
from realhf.search_engine.search import search_rpc_allocations
# Register all HF models
import realhf.api.from_hf # isort:skip
@ -144,10 +142,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
"""
return None
@property
def search_kwargs(self) -> Dict[str, Any]:
return {}
@property
def global_device_mesh(self) -> DeviceMesh:
return DeviceMesh(
@ -161,20 +155,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
f"_heuristic_rpc_allocation is not implemented in {self.__class__}"
)
def _search(self):
# called in both api.main and controller
gradient_checkpointing = any(
model.gradient_checkpointing for model in self.models.values()
)
rpc_allocs: List[RPCAllocation] = search_rpc_allocations(
device_mesh=self.global_device_mesh,
rpcs=list(self.rpcs.values()),
gradient_checkpointing=gradient_checkpointing,
use_cache=self.allocation_use_cache,
**self.search_kwargs,
)
return rpc_allocs
def scheduling_setup(self) -> ExperimentScheduling:
"""The resourced occupied by each worker.
@ -221,24 +201,11 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
self._check_legal_allocation_options()
rpcs = self.rpcs
if self.allocation_mode == "search":
# assert self.mode == "slurm"
# assumes gradient checkpointing for all training RPCs if one is enabled
# for the simplicity of search configurations
rpc_allocs = self._search()
for rpc_alloc in rpc_allocs:
assert isinstance(rpc_alloc.rpc, str)
for rpc in rpcs.values():
if rpc.name == rpc_alloc.rpc:
rpc_alloc.rpc = rpc
break
else:
raise ValueError(f"RPC {rpc_alloc.rpc} not found in rpcs.")
elif self._allocation_mode.is_decoupled():
if self._allocation_mode.is_decoupled():
paras = self._allocation_mode.parallel_strat
gdp, gpp, gmp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
gen_world_size = gdp * gpp * gmp
gdp, gpp, gtp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
gen_world_size = gdp * gpp * gtp
assert (
gen_world_size < self.n_gpus_per_node
or gen_world_size % self.n_gpus_per_node == 0
@ -268,7 +235,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
parallel=ParallelismConfig(
data_parallel_size=gdp,
pipeline_parallel_size=gpp,
model_parallel_size=gmp,
tensor_parallel_size=gtp,
use_sequence_parallel=False,
),
)
@ -276,7 +243,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
else:
rpc_name = rpc.name
if rpc_name in paras:
dp, pp, mp = (
dp, pp, tp = (
paras[rpc_name]["d"],
paras[rpc_name]["p"],
paras[rpc_name]["m"],
@ -287,9 +254,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
f"RPC {rpc_name} parallel strategy not given, "
"expect a `*` to specify the default parallel strategy."
)
dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
if (
dp * pp * mp + gdp * gpp * gmp
dp * pp * tp + gdp * gpp * gtp
!= self.n_nodes * self.n_gpus_per_node
):
raise ValueError(
@ -297,7 +264,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
"does not equal to the number of gpus. "
"Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, "
"so their summation should be equal to the total number of gpus. "
f"dp={dp}, pp={pp}, mp={mp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gmp}, "
f"dp={dp}, pp={pp}, mp={tp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gtp}, "
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
)
alloc = RPCAllocation(
@ -306,10 +273,10 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
parallel=ParallelismConfig(
data_parallel_size=dp,
pipeline_parallel_size=pp,
model_parallel_size=mp,
tensor_parallel_size=tp,
use_sequence_parallel=(
rpc.interface_type == ModelInterfaceType.TRAIN_STEP
and mp > 1
and tp > 1
),
),
)
@ -323,7 +290,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
rpc_allocs = []
for rpc_name, rpc in self.rpcs.items():
if rpc_name in paras:
dp, pp, mp = (
dp, pp, tp = (
paras[rpc_name]["d"],
paras[rpc_name]["p"],
paras[rpc_name]["m"],
@ -334,18 +301,18 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
f"RPC {rpc_name} parallel strategy not given, "
"expect a `*` to specify the default parallel strategy."
)
dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
assert dp * pp * mp == self.n_nodes * self.n_gpus_per_node
dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
assert dp * pp * tp == self.n_nodes * self.n_gpus_per_node
alloc = RPCAllocation(
rpc=rpc,
device_mesh=self.global_device_mesh,
parallel=ParallelismConfig(
data_parallel_size=dp,
pipeline_parallel_size=pp,
model_parallel_size=mp,
tensor_parallel_size=tp,
use_sequence_parallel=(
rpc.interface_type == ModelInterfaceType.TRAIN_STEP
and mp > 1
and tp > 1
),
),
)
@ -455,7 +422,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
topo=topo,
dp_rank=topo.get_coord(shard_idx).data,
pp_rank=topo.get_coord(shard_idx).pipe,
mp_rank=topo.get_coord(shard_idx).model,
tp_rank=topo.get_coord(shard_idx).tensor,
),
model=ModelAbstraction(
"tokenizer", args=dict(tokenizer_path=model_cfg.path)
@ -464,7 +431,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
gen_backend_name,
args=dict(
model_path=model_cfg.path,
dtype="bfloat16" if model_cfg.bf16 else "float16",
**dict_args,
),
),
@ -503,16 +469,17 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
"config_from_hf_converter"
](hf_config)
if (
model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size
model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0
) or (
model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0
model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size
!= 0
):
raise ValueError(
f"The number of KV heads {model_config.n_kv_heads} or "
f"Q heads {model_config.n_q_heads} is not"
f" divisible by the configured TP size "
f"({rpc_alloc.parallel.model_parallel_size}). "
f"({rpc_alloc.parallel.tensor_parallel_size}). "
f"Please decrease TP size."
)
mapping = rpc_alloc.device_mesh.mapping
@ -572,7 +539,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
topo=topo,
dp_rank=topo.get_coord(shard_idx).data,
pp_rank=topo.get_coord(shard_idx).pipe,
mp_rank=topo.get_coord(shard_idx).model,
tp_rank=topo.get_coord(shard_idx).tensor,
),
model=model,
backend=backend,
@ -612,12 +579,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
"please setup slurm for distributed runs."
)
if self.n_gpus_per_node != 8 and self.allocation_mode in [
"search",
"heuristic",
]:
if self.n_gpus_per_node != 8 and self.allocation_mode == "heuristic":
raise ValueError(
f"Cannot run search or heuristic allocation with "
f"Cannot run heuristic allocation with "
f"n_gpus_per_node {self.n_gpus_per_node}, "
"please set n_gpus_per_node to 8."
)
@ -627,13 +591,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
raise KeyError(
f"RPC name {rpc_name} does not match the name in the MFCDef object {rpc.name}."
)
if not check_is_realhf_native_model_interface(
rpc.interface_impl.type_
) and self.allocation_mode in ["search"]:
raise ValueError(
f"RPC {rpc.name} interface is not a realhf native implementation. "
f"The search allocation mode are not available."
)
if self.allocation_mode == "manual" and rpc_name not in self.allocations:
if rpc_name not in self.allocations:
raise ValueError(

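With the search allocation mode removed, the decoupled path is the first branch handled, and it enforces that the generation mesh (SGLang/vLLM) and the training mesh are disjoint: their parallel degrees must sum to the total GPU count. A small sketch of that invariant with illustrative numbers:

    def check_decoupled_allocation(train, gen, n_nodes, n_gpus_per_node):
        """train/gen are (dp, pp, tp) triples; the two meshes must not overlap."""
        dp, pp, tp = train
        gdp, gpp, gtp = gen
        total = n_nodes * n_gpus_per_node
        if dp * pp * tp + gdp * gpp * gtp != total:
            raise ValueError(
                f"train uses {dp * pp * tp} GPUs, generation uses {gdp * gpp * gtp}, "
                f"but {total} GPUs are available"
            )

    check_decoupled_allocation(train=(4, 1, 2), gen=(4, 1, 2), n_nodes=2, n_gpus_per_node=8)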

@ -65,8 +65,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
model_name="actor",
mb_spec=self.actor_gen.mb_spec,
interface_type=ModelInterfaceType.GENERATE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=("packed_prompts", "task_ids"),
output_keys=("packed_input_ids",),
@ -79,8 +77,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
mb_spec=self.rew_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=rw_interface,
model_type=self.rew.type,
model_path=self.rew.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=("packed_input_ids", "packed_prompts", "task_ids"),
output_keys=("rewards",),


@ -39,8 +39,6 @@ class NullSFTConfig(CommonExperimentConfig, SFTExperimentOptions):
model_name="default",
input_keys=("packed_input_ids", "prompt_mask"),
log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
)
return {"trainDefault": rpc}
@ -88,8 +86,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
model_name="default",
input_keys=("packed_prompts",),
output_keys=("rewards",),
model_type=self.model.type,
model_path=self.model.path,
)
rpc = MFCDef(
n_seqs=self.dataset.train_bs_n_seqs,
@ -100,8 +96,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
model_name="default",
input_keys=("packed_prompts", "rewards"),
log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
)
return {"trainDefault": rpc, "reward": rw}


@ -148,32 +148,31 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
"packed_logprobs",
"prompt_mask",
]
if self.ppo.recompute_logprob:
if self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
rollout_output_keys.remove("packed_logprobs")
rollout = MFCDef(
name="actor_gen",
model_name="actor",
mb_spec=self.actor_gen.mb_spec,
interface_type=ModelInterfaceType.GENERATE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=("packed_prompts", "task_ids"),
output_keys=tuple(rollout_output_keys),
n_seqs=self.dataset.train_bs_n_seqs,
)
actor_inf_outputs = ("packed_logprobs",)
if self.ppo.use_decoupled_loss:
actor_inf_outputs = ("proximal_logprobs",)
actor_inf = MFCDef(
name="actor_inf",
model_name="actor",
mb_spec=self.actor_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=("packed_input_ids",),
output_keys=("packed_logprobs",),
output_key_remap=dict(logprobs="packed_logprobs"),
output_keys=actor_inf_outputs,
output_key_remap=dict(logprobs=actor_inf_outputs[0]),
n_seqs=self.dataset.train_bs_n_seqs,
)
@ -200,8 +199,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
model_name="ref",
mb_spec=self.ref_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
model_type=self.ref.type,
model_path=self.ref.path,
interface_impl=ref_interface,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=tuple(inf_ref_inputs),
@ -216,8 +213,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
mb_spec=self.critic_inf.mb_spec,
interface_type=ModelInterfaceType.INFERENCE,
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
min_n_seqs_per_pass=1 / self.group_size,
input_keys=("packed_input_ids", "seq_no_eos_mask"),
output_keys=("values",),
@ -238,13 +233,13 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
train_actor_inputs.remove("values")
if self.ppo.kl_ctl == 0:
train_actor_inputs.remove("packed_ref_logprobs")
if self.ppo.use_decoupled_loss:
train_actor_inputs.append("proximal_logprobs")
train_actor = MFCDef(
name="actor_train",
model_name="actor",
mb_spec=self.actor_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP,
model_type=self.actor.type,
model_path=self.actor.path,
interface_impl=actor_interface,
input_keys=tuple(train_actor_inputs),
log_return_value=True,
@ -269,8 +264,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
mb_spec=self.critic_train.mb_spec,
interface_type=ModelInterfaceType.TRAIN_STEP,
interface_impl=critic_interface,
model_type=self.critic.type,
model_path=self.critic.path,
input_keys=tuple(train_critic_inputs),
log_return_value=True,
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
@ -289,7 +282,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
if self.ppo.disable_value:
rpcs.pop("critic_inf")
rpcs.pop("critic_train")
if not self.ppo.recompute_logprob:
if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
rpcs.pop("actor_inf")
if self.ppo.kl_ctl == 0:
rpcs.pop("ref_inf")
@ -311,7 +304,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
if self.ppo.disable_value:
allocs.pop("critic_inf")
allocs.pop("critic_train")
if not self.ppo.recompute_logprob:
if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
allocs.pop("actor_inf")
if self.ppo.kl_ctl == 0:
allocs.pop("ref_inf")
@ -337,14 +330,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
def tokenizer_name_or_path(self) -> str:
return self.actor.path
@property
def search_kwargs(self):
return {
"num_gen_tokens": self.ppo.gen.max_new_tokens,
"n_ppo_minibatches": self.ppo.ppo_n_minibatches,
"seq_len": self.dataset.max_prompt_len,
}
@property
def max_prompt_len(self):
return self.dataset.max_prompt_len

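The PPOMATHConfig changes wire the decoupled loss through the dataflow graph: actor_inf recomputes logprobs and emits them as proximal_logprobs, which actor_train consumes alongside the behavior logprobs from rollout. As a hedged illustration only (this is a common decoupled-PPO formulation and an assumption about the exact loss implemented elsewhere in the repo), the two sets of logprobs are typically combined like this:

    import torch

    def decoupled_ppo_loss(new_logp, prox_logp, behav_logp, adv, eps_clip=0.2):
        # Clip against the proximal (recomputed) policy...
        ratio = torch.exp(new_logp - prox_logp)
        clipped = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip)
        surrogate = torch.minimum(ratio * adv, clipped * adv)
        # ...and reweight by the detached behavior importance ratio.
        behave_w = torch.exp(prox_logp - behav_logp).detach()
        return -(behave_w * surrogate).mean()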

@ -36,8 +36,6 @@ class SFTConfig(CommonExperimentConfig, SFTExperimentOptions):
model_name="default",
input_keys=("packed_input_ids", "prompt_mask"),
log_return_value=True,
model_type=self.model.type,
model_path=self.model.path,
)
return {"trainDefault": rpc}


@ -34,8 +34,8 @@ from realhf.api.core.dfg import OffloadHook, ParamReallocHook
from realhf.api.quickstart.device_mesh import RPCAllocation
from realhf.base import logging
from realhf.base.topology import (
DataPipeModelParallelTopology,
PipeDataModelParallelTopology,
DataPipeTensorParallelTopology,
PipeDataTensorParallelTopology,
ProcessTopology,
)
@ -73,8 +73,8 @@ def get_topo(
max_prompt_len: Optional[int] = None,
) -> ProcessTopology:
if is_train:
return PipeDataModelParallelTopology(
num_mp=parallel.model_parallel_size,
return PipeDataTensorParallelTopology(
num_tp=parallel.tensor_parallel_size,
num_pp=parallel.pipeline_parallel_size,
num_dp=parallel.data_parallel_size,
sequence_parallel=parallel.use_sequence_parallel,
@ -82,8 +82,8 @@ def get_topo(
max_prompt_len=max_prompt_len,
gradient_accumulation_fusion=gradient_accumulation_fusion,
)
return DataPipeModelParallelTopology(
num_mp=parallel.model_parallel_size,
return DataPipeTensorParallelTopology(
num_tp=parallel.tensor_parallel_size,
num_pp=parallel.pipeline_parallel_size,
num_dp=parallel.data_parallel_size,
sequence_parallel=parallel.use_sequence_parallel,
@ -93,7 +93,7 @@ def get_topo(
def get_world_size(parallel: ParallelismConfig) -> int:
return (
parallel.model_parallel_size
parallel.tensor_parallel_size
* parallel.pipeline_parallel_size
* parallel.data_parallel_size
)
@ -247,9 +247,8 @@ class AllocationType(enum.Enum):
GLOBAL_HYBRID = 2
MANUAL = 3
HEURISTIC = 4
SEARCH = 5
DECOUPLED_SGLANG = 6
DECOUPLED_MOCK = 7
DECOUPLED_SGLANG = 5
DECOUPLED_MOCK = 6
@dataclasses.dataclass
@ -293,8 +292,6 @@ class AllocationMode:
return cls(AllocationType.MANUAL, None)
if allocation_mode == "heuristic":
return cls(AllocationType.HEURISTIC, None)
if allocation_mode == "search":
return cls(AllocationType.SEARCH, None)
alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)


@ -1 +1 @@
import realhf.impl.environment.math_single_step_env
import realhf.impl.environment.math_code_single_step_env


@ -0,0 +1,75 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import os
import re
from typing import List, Tuple
from functioncall.code.local_verify import code_verify as local_code_verify
from functioncall.code.verify import code_verify
from functioncall.math.verify import math_verify
from realhf.api.core.env_api import EnvironmentService, register_environment
from realhf.base import logging
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
logger = logging.getLogger("Math Single Step Environment")
def extract_code(text, min_length=20):
code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
code_blocks = re.findall(code_pattern, text, re.DOTALL)
valid_blocks = []
for block in code_blocks:
clean_block = block.strip()
if len(clean_block) < min_length:
continue
valid_blocks.append(clean_block)
if not valid_blocks:
# logger.warning(f"failed to extract python code from {text}")
return None
# return the last code block
return valid_blocks[-1]
class MathCodeSingleStepEnv(EnvironmentService):
def __init__(self, dataset_path: str):
self.id2info, _ = load_metadata(dataset_path)
async def reset(self, seed=None, options=None):
return None, {}
async def step(self, action: Tuple[str, List[str]]):
qid, answers = action
group_size = len(answers)
qid = qid.split("@")[0]
cur_task = self.id2info[qid]["task"]
if cur_task == "math":
format_rewards = await asyncio.to_thread(
math_verify_call,
self.id2info,
answers,
[qid for _ in range(group_size)],
)
elif cur_task == "code":
answers = [extract_code(x) for x in answers]
format_rewards = await asyncio.to_thread(
code_verify_call,
self.id2info,
answers,
[qid for _ in range(group_size)],
)
else:
raise NotImplementedError()
return None, format_rewards, True, False, {}
register_environment("math-code-single-step", MathCodeSingleStepEnv)

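The new math-code environment dispatches on the per-question task field and, for code questions, grades only the last sufficiently long fenced block extracted from the completion. A usage sketch for extract_code, assuming the function from the new file above is in scope:

    sample = (
        "Here is my solution:\n"
        "```python\n"
        "import sys\n"
        "print(sum(map(int, sys.stdin.read().split())))\n"
        "```"
    )
    assert "print(sum" in extract_code(sample)
    assert extract_code("```py\nx=1\n```") is None  # shorter than min_length=20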

@ -1,38 +0,0 @@
# Copyright 2025 Ant Group Inc.
import asyncio
import os
from typing import List, Tuple
from functioncall.math.verify import math_verify
from realhf.api.core.env_api import EnvironmentService, register_environment
from realhf.base import logging
from realhf.impl.dataset.math_code_dataset import load_metadata
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
logger = logging.getLogger("Math Single Step Environment")
class MathSingleStepEnv(EnvironmentService):
def __init__(self, dataset_path: str):
self.id2info, _ = load_metadata(dataset_path)
async def reset(self, seed=None, options=None):
return None, {}
async def step(self, action: Tuple[str, List[str]]):
qid, answers = action
group_size = len(answers)
format_rewards = await asyncio.to_thread(
math_verify_call,
self.id2info,
answers,
[qid for _ in range(group_size)],
)
return None, format_rewards, True, False, {}
register_environment("math-single-step", MathSingleStepEnv)


@ -13,7 +13,7 @@ import realhf.api.from_hf
import realhf.base.logging as logging
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
from realhf.base.importing import import_module
from realhf.base.pkg_version import is_version_less
from realhf.base.pkg_version import is_available, is_version_less
from realhf.impl.model.conversion.hf_registry import HFModelRegistry
from realhf.impl.model.nn.real_llm_api import ReaLModel
@ -27,8 +27,9 @@ import_module(os.path.join(_filepath, "nn"), _p)
# NOTE: skip importing vLLM for now to avoid an
# "invalid device context" issue for the 25.01 image
if is_version_less("vllm", "0.6.4"):
if is_available("vllm") and is_version_less("vllm", "0.6.4"):
import realhf.impl.model.backend.vllm
import realhf.impl.model.backend.inference
import realhf.impl.model.backend.megatron
import realhf.impl.model.backend.mock_train


@ -62,7 +62,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
f"num_layers(this stage)={self.module.num_layers} "
f"pp_size={constants.pipe_parallel_world_size()} "
f"dp_size={constants.data_parallel_world_size()} "
f"mp_size={constants.model_parallel_world_size()} "
f"tp_size={constants.tensor_parallel_world_size()} "
)
if constants.data_parallel_rank() == 0:
logger.info(


@ -126,7 +126,7 @@ def megatron_ctx():
# Build the tensor model-parallel groups.
parallel_state._TENSOR_MODEL_PARALLEL_GROUP = g
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
g = constants.model_parallel_group()
g = constants.tensor_parallel_group()
parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = (
dist.get_process_group_ranks(g)
)
@ -155,7 +155,7 @@ def megatron_ctx():
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
# Build the tensor + context parallel groups
parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = (
constants.model_parallel_group()
constants.tensor_parallel_group()
)
# Build expert parallel groups.
@ -173,7 +173,7 @@ def megatron_ctx():
)
else:
parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = (
constants.model_parallel_group()
constants.tensor_parallel_group()
)
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = (
constants.data_parallel_group()
@ -227,7 +227,7 @@ class MegatronEngine:
def _all_reduce_layernorm_grads(self):
if not (
constants.sequence_parallel() and constants.model_parallel_world_size() > 1
constants.sequence_parallel() and constants.tensor_parallel_world_size() > 1
):
return
real_model: ReaLModel = self.ddp.module
@ -255,7 +255,7 @@ class MegatronEngine:
assert all(x is not None for x in grads)
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, group=constants.model_parallel_group())
dist.all_reduce(coalesced, group=constants.tensor_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
@ -362,7 +362,10 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
)
dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
grad_norm /= constants.tp_and_pp_world_size()
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
if (
constants.data_parallel_rank() == 0
and constants.tensor_parallel_rank() == 0
):
logger.info(
f"Model name {constants.model_name()}, "
f"Pipeline rank {constants.pipe_parallel_rank()}. "
@ -539,7 +542,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
)
dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
grad_norm /= constants.tp_and_pp_world_size()
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
if (
constants.data_parallel_rank() == 0
and constants.tensor_parallel_rank() == 0
):
logger.info(
f"Megatron backend update success? {update_successful}. "
f"Grad Norm: {grad_norm}. "
@ -700,6 +706,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
# Deleting models directly will not release the memory.
# We must disable hooks at first.
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
model.module.engine.ddp.disable_forward_pre_hook()
else:
optimizer = model.module.engine.optim
@ -726,7 +733,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
sd = optimizer.state_dict()
dp = constants.data_parallel_rank()
pp = constants.pipe_parallel_rank()
tp = constants.model_parallel_rank()
tp = constants.tensor_parallel_rank()
# HACK: (bowei) I'm not sure whether there's duplicated information.
torch.save(
sd, pathlib.Path(save_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt"
@ -742,7 +749,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
dp = constants.data_parallel_rank()
pp = constants.pipe_parallel_rank()
tp = constants.model_parallel_rank()
tp = constants.tensor_parallel_rank()
sd = torch.load(
pathlib.Path(load_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt"


@ -82,7 +82,7 @@ def _split_and_prefill_pipe_input(
raise PipelineError(
"Partitioned seqlens are not equal across pipeline parallel ranks. "
f"Current rank (dp={constants.data_parallel_rank()},"
f"tp={constants.model_parallel_rank()},pp={constants.pipe_parallel_rank()}), "
f"tp={constants.tensor_parallel_rank()},pp={constants.pipe_parallel_rank()}), "
f"gathered batch seqlens={_batch_seqlen_all_gathered}, "
f"Have you ensured that the order of dataset across ranks is the same?",
)
@ -118,7 +118,7 @@ def _split_and_prefill_pipe_input(
total_len = (
packed_input_ids.shape[0]
if not constants.sequence_parallel()
else packed_input_ids.shape[0] // constants.model_parallel_world_size()
else packed_input_ids.shape[0] // constants.tensor_parallel_world_size()
)
mb_seq_lens.append(total_len)
return (x, ys)
@ -569,7 +569,7 @@ class PipeGenInstrSet:
"batch_lengths", micro_batch_id, remove=False
)
batch_length = (
batch_length // constants.model_parallel_world_size()
batch_length // constants.tensor_parallel_world_size()
if constants.sequence_parallel()
else batch_length
)


@ -198,8 +198,8 @@ class SGLangGenerationEngine(PipelinableEngine):
hybrid_train: bool,
request_timeout: int = 1800,
):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
if constants.tensor_parallel_rank() != 0:
dist.barrier(group=constants.tensor_parallel_cpu_group())
return
# Start the serving process
self.server_proc = mp.Process(
@ -224,8 +224,8 @@ class SGLangGenerationEngine(PipelinableEngine):
if server_args_dict["enable_metrics"]:
dp_rank = constants.data_parallel_rank()
pp_rank = constants.pipe_parallel_rank()
mp_rank = constants.model_parallel_rank()
metric_server_name = f"d{dp_rank}p{pp_rank}m{mp_rank}"
tp_rank = constants.tensor_parallel_rank()
metric_server_name = f"d{dp_rank}p{pp_rank}t{tp_rank}"
key = names.metric_server(
constants.experiment_name(),
constants.trial_name(),
@ -243,7 +243,7 @@ class SGLangGenerationEngine(PipelinableEngine):
# offload weights/cache
self.hybrid_train = hybrid_train
dist.barrier(group=constants.model_parallel_cpu_group())
dist.barrier(group=constants.tensor_parallel_cpu_group())
def __del__(self):
if hasattr(self, "server_proc"):
@ -381,8 +381,8 @@ class SGLangGenerationEngine(PipelinableEngine):
"NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 "
"because we force to skip_tokenizer_init."
)
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
if constants.tensor_parallel_rank() != 0:
dist.barrier(group=constants.tensor_parallel_cpu_group())
return None, None, None
results = asyncio.run(
@ -393,12 +393,12 @@ class SGLangGenerationEngine(PipelinableEngine):
gconfig=gconfig,
)
)
dist.barrier(group=constants.model_parallel_cpu_group())
dist.barrier(group=constants.tensor_parallel_cpu_group())
return results
def update_weights_from_disk(self, path):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
if constants.tensor_parallel_rank() != 0:
dist.barrier(group=constants.tensor_parallel_cpu_group())
return
async def _fn():
@ -409,18 +409,17 @@ class SGLangGenerationEngine(PipelinableEngine):
await client.async_update_weights_from_disk(path)
asyncio.run(_fn())
dist.barrier(group=constants.model_parallel_cpu_group())
dist.barrier(group=constants.tensor_parallel_cpu_group())
@dataclasses.dataclass
class SGLangGenerationBackend(ModelBackend, SGLangConfig):
model_path: str = ""
dtype: str = "float16"
def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
if constants.pipe_parallel_world_size() != 1:
raise RuntimeError("SGLang does not support pipe parallel size > 1.")
if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node:
if constants.tensor_parallel_world_size() > cluster.spec.n_gpus_per_node:
raise RuntimeError(
"AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
)
@ -436,7 +435,13 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
) != len(datapack.flat2d(ports)):
dist.all_gather_object(
ports,
network.find_multiple_free_ports(2, low=20000, high=40000),
network.find_multiple_free_ports(
2,
low=10000,
high=60000,
experiment_name=constants.experiment_name(),
trial_name=constants.trial_name(),
),
group=constants.data_parallel_group(),
)
api_server_port, dist_port = ports[constants.data_parallel_rank()]
@ -450,13 +455,12 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
kv_cache_dtype="auto",
device="cuda",
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}",
is_embedding=False,
skip_tokenizer_init=True,
# Other runtime options
tp_size=constants.model_parallel_world_size(),
tp_size=constants.tensor_parallel_world_size(),
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
file_storage_path=os.path.join(

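Throughout the SGLang engine, only tensor-parallel rank 0 talks to the server process; every other TP rank simply waits on the shared CPU-group barrier so all ranks leave the call together. A generic sketch of that gating pattern (assumes torch.distributed is initialized and a CPU group handle is available; names are illustrative):

    import torch.distributed as dist

    def run_on_tp_rank0(fn, tp_rank, tp_cpu_group):
        """Execute fn on TP rank 0 only; all ranks synchronize before returning."""
        if tp_rank != 0:
            dist.barrier(group=tp_cpu_group)
            return None
        result = fn()
        dist.barrier(group=tp_cpu_group)
        return result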

@ -36,7 +36,7 @@ def _vllm_group_rank(group_type: _vLLMGroupType):
if group_type == _vLLMGroupType.WORLD:
return constants.tp_and_pp_rank()
elif group_type == _vLLMGroupType.TP:
return constants.model_parallel_rank()
return constants.tensor_parallel_rank()
elif group_type == _vLLMGroupType.PP:
return constants.pipe_parallel_rank()
@ -45,7 +45,7 @@ def _vllm_group_size(group_type: _vLLMGroupType):
if group_type == _vLLMGroupType.WORLD:
return constants.tp_and_pp_world_size()
elif group_type == _vLLMGroupType.TP:
return constants.model_parallel_world_size()
return constants.tensor_parallel_world_size()
elif group_type == _vLLMGroupType.PP:
return constants.pipe_parallel_world_size()
@ -54,7 +54,7 @@ def _vllm_parallel_group(group_type: _vLLMGroupType):
if group_type == _vLLMGroupType.WORLD:
return constants.tp_and_pp_group()
elif group_type == _vLLMGroupType.TP:
return constants.model_parallel_group()
return constants.tensor_parallel_group()
elif group_type == _vLLMGroupType.PP:
return constants.pipe_parallel_group()


@ -213,7 +213,7 @@ class GPUExecutor_(GPUExecutor):
tok = time.perf_counter()
after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used)
is_dp_head = (
constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0
constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0
)
if is_dp_head:
logger.info(
@ -241,7 +241,7 @@ class GPUExecutor_(GPUExecutor):
tok = time.perf_counter()
after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used)
is_dp_head = (
constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0
constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0
)
if is_dp_head:
logger.info(


@ -166,7 +166,6 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM):
@dataclasses.dataclass
class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
model_path: str = ""
dtype: str = "bfloat16"
def _initialize(
self, model: model_api.Model, spec: model_api.FinetuneSpec
@ -192,7 +191,7 @@ class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
kv_cache_dtype=self.kv_cache_type,
device=constants.current_device(),
# Parallelism.
tensor_parallel_size=constants.model_parallel_world_size(),
tensor_parallel_size=constants.tensor_parallel_world_size(),
pipeline_parallel_size=constants.pipe_parallel_world_size(),
# KV cache and scheduling.
num_scheduler_steps=self.num_scheduler_steps,


@ -100,7 +100,7 @@ def setup_global_comm(
if worker_index == 0:
host_ip = socket.gethostbyname(socket.gethostname())
port = network.find_free_port()
port = network.find_free_port(experiment_name=expr_name, trial_name=trial_name)
pg_init_addr = f"tcp://{host_ip}:{port}"
name_resolve.add(pg_master_name, pg_init_addr, keepalive_ttl=300)
else:


@ -43,10 +43,10 @@ def is_trainable(model_name: ModelName) -> bool:
class ParamReallocPair:
src: ModelName
src_dp_rank: int
src_mp_rank: int
src_tp_rank: int
src_pp_rank: int
dst: ModelName
dst_mp_rank: int
dst_tp_rank: int
dst_pp_rank: int
@ -171,32 +171,34 @@ def _create_param_realloc_groups(
range(from_topo.get_dim("pipe")), range(to_topo.get_dim("pipe"))
):
# create tensor reshard groups
src_mp_size = from_topo.get_dim("model")
dst_mp_size = to_topo.get_dim("model")
src_tp_size = from_topo.get_dim("tensor")
dst_tp_size = to_topo.get_dim("tensor")
for mp_j in range(dst_mp_size):
for tp_j in range(dst_tp_size):
_all_dst_ranks = filter_match_mwids(
dst, to_topo, msid2mwid, pipe=pp_j, model=mp_j
dst, to_topo, msid2mwid, pipe=pp_j, tensor=tp_j
)
if src_mp_size > dst_mp_size:
factor = src_mp_size // dst_mp_size
mp_is = list(range(factor * mp_j, factor * (mp_j + 1)))
if src_tp_size > dst_tp_size:
factor = src_tp_size // dst_tp_size
tp_is = list(range(factor * tp_j, factor * (tp_j + 1)))
_all_src_ranks = [
filter_match_mwids(src, from_topo, msid2mwid, model=mp_i, pipe=pp_i)
for mp_i in mp_is
filter_match_mwids(
src, from_topo, msid2mwid, tensor=tp_i, pipe=pp_i
)
for tp_i in tp_is
]
else:
factor = dst_mp_size // src_mp_size
factor = dst_tp_size // src_tp_size
_all_src_ranks = [
filter_match_mwids(
src,
from_topo,
msid2mwid,
model=mp_j // factor,
tensor=tp_j // factor,
pipe=pp_i,
)
]
# All GPUs in _src_ranks have the data required by (pp_j, mp_j)
# All GPUs in _src_ranks have the data required by (pp_j, tp_j)
for _src_ranks in _all_src_ranks:
# NOTE: inter-node communication cost is significantly larger than intra-node communication cost.
# We only select one sender per host/node to prevent multiple senders occupying the same network bandwidth.
@ -209,42 +211,42 @@ def _create_param_realloc_groups(
)
_idle_src_ranks = [r for r in _src_ranks if r not in assignment]
for _src_rank in _idle_src_ranks:
dp_i, mp_i = (
dp_i, tp_i = (
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).data,
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).model,
).tensor,
)
key = ParamReallocPair(
src=src,
src_dp_rank=dp_i,
src_mp_rank=mp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=dst,
dst_mp_rank=mp_j,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
param_realloc_dst_ranks[key] = []
param_realloc_groups[key] = None
param_realloc_src_ranks[key] = _src_rank
for _src_rank, _dst_ranks in assignment.items():
dp_i, mp_i = (
dp_i, tp_i = (
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).data,
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).model,
).tensor,
)
key = ParamReallocPair(
src=src,
src_dp_rank=dp_i,
src_mp_rank=mp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=dst,
dst_mp_rank=mp_j,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
param_realloc_dst_ranks[key] = _dst_ranks
@ -315,8 +317,8 @@ def setup_param_realloc(
@dataclasses.dataclass
class ReparallelizeSenderStep:
rank: int
sender_mp_portion_id: int
receiver_mp_portion_id: int
sender_tp_portion_id: int
receiver_tp_portion_id: int
param_keys: List[str]
param_intervals_cpu: List[Tuple[int, int]]
param_intervals_cuda: torch.Tensor
@ -330,8 +332,8 @@ class ReparallelizeSenderStep:
@dataclasses.dataclass
class ReparallelizeReceiverStep:
rank: int
sender_mp_portion_id: int
receiver_mp_portion_id: int
sender_tp_portion_id: int
receiver_tp_portion_id: int
sender_param_intervals_cpu: List[Tuple[int, int]]
sender_param_intervals_cuda: torch.Tensor
sender_max_interval_size: int
@ -356,9 +358,9 @@ def _derive_reparallelize_comm_plan(
pg_info: ParamReallocInfo,
dtype: Optional[torch.dtype] = torch.float16,
) -> List[ReparallelizeReceiverStep | ReparallelizeSenderStep]:
src_mp_size = from_topo.get_dim("model")
dst_mp_size = to_topo.get_dim("model")
assert src_mp_size % dst_mp_size == 0 or dst_mp_size % src_mp_size == 0
src_tp_size = from_topo.get_dim("tensor")
dst_tp_size = to_topo.get_dim("tensor")
assert src_tp_size % dst_tp_size == 0 or dst_tp_size % src_tp_size == 0
for k, v in dataclasses.asdict(to_model_config).items():
if k not in ["is_critic"] and v != getattr(from_model_config, k):
raise ValueError(
@ -366,8 +368,8 @@ def _derive_reparallelize_comm_plan(
f"value in checkpoint is `{v}`, current value is `{getattr(from_model_config, k)}`)."
)
if from_model_config.n_kv_heads > 1 and (
from_model_config.n_kv_heads % src_mp_size == 0
) != (from_model_config.n_kv_heads % dst_mp_size == 0):
from_model_config.n_kv_heads % src_tp_size == 0
) != (from_model_config.n_kv_heads % dst_tp_size == 0):
raise ValueError("Whether to partition kv heads should remain the same.")
from_layer_mapping = partition_pipeline_layers(
@ -400,7 +402,7 @@ def _derive_reparallelize_comm_plan(
from_model_param_specs, _ = build_param_spec(
from_layer_indices,
from_model_config,
mp_size=from_topo.get_dim("model"),
tp_size=from_topo.get_dim("tensor"),
dp_size=from_topo.get_dim("data"),
pp_size=from_topo.get_dim("pipe"),
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
@ -411,7 +413,7 @@ def _derive_reparallelize_comm_plan(
to_model_param_specs, _ = build_param_spec(
to_layer_indices,
to_model_config,
mp_size=to_topo.get_dim("model"),
tp_size=to_topo.get_dim("tensor"),
pp_size=to_topo.get_dim("pipe"),
dp_size=to_topo.get_dim("data"),
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
@ -428,25 +430,25 @@ def _derive_reparallelize_comm_plan(
if len(layer_indices) == 0:
continue
for mp_i in range(src_mp_size):
if dst_mp_size > src_mp_size:
factor = dst_mp_size // src_mp_size
mp_js = [i + factor * mp_i for i in range(factor)]
receiver_mp_portion_id = 0
for tp_i in range(src_tp_size):
if dst_tp_size > src_tp_size:
factor = dst_tp_size // src_tp_size
tp_js = [i + factor * tp_i for i in range(factor)]
receiver_tp_portion_id = 0
else:
factor = src_mp_size // dst_mp_size
mp_js = [mp_i // factor]
receiver_mp_portion_id = mp_i % factor
for sender_mp_portion_id, mp_j in enumerate(mp_js):
factor = src_tp_size // dst_tp_size
tp_js = [tp_i // factor]
receiver_tp_portion_id = tp_i % factor
for sender_tp_portion_id, tp_j in enumerate(tp_js):
for dp_i in range(src_dp_size):
key = ParamReallocPair(
src=from_model_name,
src_dp_rank=dp_i,
src_mp_rank=mp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=to_model_name,
dst_mp_rank=mp_j,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
src = pg_info.param_realloc_src_ranks[key]
@ -462,10 +464,10 @@ def _derive_reparallelize_comm_plan(
)
param_size = param_size_from_keys(
config=from_model_config,
src_mp_size=src_mp_size,
src_tp_size=src_tp_size,
sd_keys=param_keys,
src2dst_tp_size=max(dst_mp_size // src_mp_size, 1),
src2dst_tp_rank=sender_mp_portion_id,
src2dst_tp_size=max(dst_tp_size // src_tp_size, 1),
src2dst_tp_rank=sender_tp_portion_id,
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
)
if torch.distributed.is_initialized():
@ -474,11 +476,11 @@ def _derive_reparallelize_comm_plan(
param_intervals_cpu = param_intervals_from_keys(
model_name=from_model_name,
config=from_model_config,
mp_size=src_mp_size,
tp_size=src_tp_size,
param_spec=from_model_param_specs,
sd_keys=param_keys,
portion_size=max(dst_mp_size // src_mp_size, 1),
portion_rank=sender_mp_portion_id,
portion_size=max(dst_tp_size // src_tp_size, 1),
portion_rank=sender_tp_portion_id,
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
)
param_intervals_cuda = torch.tensor(
@ -493,11 +495,11 @@ def _derive_reparallelize_comm_plan(
receiver_param_intervals_cpu = param_intervals_from_keys(
model_name=to_model_name,
config=to_model_config,
mp_size=dst_mp_size,
tp_size=dst_tp_size,
param_spec=to_model_param_specs,
sd_keys=param_keys,
portion_size=max(src_mp_size // dst_mp_size, 1),
portion_rank=receiver_mp_portion_id,
portion_size=max(src_tp_size // dst_tp_size, 1),
portion_rank=receiver_tp_portion_id,
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
)
receiver_param_intervals_cuda = torch.tensor(
@ -513,8 +515,8 @@ def _derive_reparallelize_comm_plan(
comm_plan.append(
ReparallelizeReceiverStep(
rank=dst_rank,
sender_mp_portion_id=sender_mp_portion_id,
receiver_mp_portion_id=receiver_mp_portion_id,
sender_tp_portion_id=sender_tp_portion_id,
receiver_tp_portion_id=receiver_tp_portion_id,
param_keys=param_keys,
sender_param_intervals_cpu=param_intervals_cpu,
sender_param_intervals_cuda=param_intervals_cuda,
@ -532,8 +534,8 @@ def _derive_reparallelize_comm_plan(
comm_plan.append(
ReparallelizeSenderStep(
rank=src,
sender_mp_portion_id=sender_mp_portion_id,
receiver_mp_portion_id=receiver_mp_portion_id,
sender_tp_portion_id=sender_tp_portion_id,
receiver_tp_portion_id=receiver_tp_portion_id,
param_keys=param_keys,
param_intervals_cpu=param_intervals_cpu,
param_intervals_cuda=param_intervals_cuda,

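The parameter-reallocation plan renames every mp_* field to tp_*; the underlying index math is unchanged: when the destination TP degree is larger, each source TP rank fans out to factor destination ranks, otherwise factor source ranks fold into one destination rank. The mapping in isolation:

    def dst_tp_ranks(tp_i, src_tp_size, dst_tp_size):
        if dst_tp_size > src_tp_size:
            factor = dst_tp_size // src_tp_size
            return [factor * tp_i + j for j in range(factor)]
        factor = src_tp_size // dst_tp_size
        return [tp_i // factor]

    assert dst_tp_ranks(1, src_tp_size=2, dst_tp_size=4) == [2, 3]
    assert dst_tp_ranks(3, src_tp_size=4, dst_tp_size=2) == [1]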

@ -22,8 +22,8 @@ from realhf.base.saveload_utils import (
)
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_parallel import (
mp_merge_key,
mp_partition_real_model_state_dict,
tp_merge_key,
tp_partition_real_model_state_dict,
)
logger = logging.getLogger("HF Registry")
@ -141,11 +141,11 @@ class HFModelRegistry:
partition_tik = time.perf_counter()
sd = {k: v for k, v in sd.items() if k in required_hf_sd_names}
sd = self.sd_from_hf_converter(sd, model.config)
psd = mp_partition_real_model_state_dict(
psd = tp_partition_real_model_state_dict(
sd,
model.config,
constants.model_parallel_world_size(),
constants.model_parallel_rank(),
constants.tensor_parallel_world_size(),
constants.tensor_parallel_rank(),
)
return psd, partition_tik - load_tik, time.perf_counter() - partition_tik
@ -222,8 +222,8 @@ class HFModelRegistry:
dp_rank = constants.data_parallel_rank()
pp_rank = constants.pipe_parallel_rank()
mp_rank = constants.model_parallel_rank()
mp_size = constants.model_parallel_world_size()
tp_rank = constants.tensor_parallel_rank()
tp_size = constants.tensor_parallel_world_size()
pp_size = constants.pipe_parallel_world_size()
dp_size = constants.data_parallel_world_size()
@ -234,7 +234,7 @@ class HFModelRegistry:
# of each pipeline stage into smaller shards.
approx_param_size = (
sum(v.numel() * v.element_size() for v in model.state_dict().values())
* mp_size
* tp_size
)
# By default a shard is at most 1GB. A small size enables parallel saving during training.
@ -274,9 +274,9 @@ class HFModelRegistry:
and k == f"{model.config.n_layers + 1}.weight"
):
continue
gather_list = [torch.zeros_like(v) for _ in range(mp_size)]
dist.all_gather(gather_list, v, group=constants.model_parallel_group())
gathered = mp_merge_key(k, gather_list, model.config)
gather_list = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_list, v, group=constants.tensor_parallel_group())
gathered = tp_merge_key(k, gather_list, model.config)
cpu_sd[k] = gathered.cpu()
t2 = time.perf_counter()
@ -299,7 +299,7 @@ class HFModelRegistry:
param_size = param_size.item()
# Save tokenizer and huggingface model config.
if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
hf_config.save_pretrained(save_dir)
if tokenizer is not None:
tokenizer.save_pretrained(save_dir)
@ -307,7 +307,7 @@ class HFModelRegistry:
# Dump parameters to disk.
if len(pp_stage_n_shards) == 1 and pp_stage_n_shards[0] == 1:
fn = "pytorch_model.bin"
if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
torch.save(hf_sd, os.path.join(save_dir, fn))
else:
output_fn = (
@ -326,8 +326,8 @@ class HFModelRegistry:
bin_index["weight_map"] = {}
weight_map = {}
mesh_size = dp_size * mp_size
mesh_idx = dp_rank * mp_size + mp_rank
mesh_size = dp_size * tp_size
mesh_idx = dp_rank * tp_size + tp_rank
n_shards_per_gpu = (n_shards + mesh_size - 1) // mesh_size
if mesh_idx < len(range(0, n_shards, n_shards_per_gpu)):
s = list(range(0, n_shards, n_shards_per_gpu))[mesh_idx]
@ -357,7 +357,7 @@ class HFModelRegistry:
for wm in weight_map_list:
bin_index["weight_map"].update(wm)
if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
with open(
os.path.join(save_dir, "pytorch_model.bin.index.json"), "w"
) as f:


@ -240,7 +240,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
return data
local_rank = constants.grid().topo.get_rank(
data=constants.data_parallel_rank(),
model=0,
tensor=0,
pipe=constants.pipe_parallel_world_size() - 1,
)
dst = constants.to_global_pg_rank(local_rank)


@ -3,7 +3,7 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
from typing import Dict, Literal, Optional
from typing import Dict, List, Literal, Optional
import torch
import torch.distributed as dist
@ -86,6 +86,7 @@ def _ppo_actor_loss_from_model_outputs(
eps_clip=eps_clip,
loss_mask=ppo_loss_mask,
c_clip=c_clip,
proximal_logprobs=input_.data.get("prox_logp", None),
)
# Log training statistics
@ -106,13 +107,20 @@ def _ppo_actor_loss_from_model_outputs(
dual_clip_ratio=ppo_stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
if "behave_imp_weight" in ppo_stat:
stats_tracker.denominator(unclipped_behave_tokens=ppo_stat["behave_mask"])
stats_tracker.stat(
behave_imp_weight=ppo_stat["behave_imp_weight"],
behave_approx_kl=ppo_stat["behave_approx_kl"],
denominator="unclipped_behave_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
dist.all_reduce(
vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN
)
dist.all_reduce(
vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX
)
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
@ -505,7 +513,7 @@ class PPOActorInterface(model_api.ModelInterface):
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> Dict:
) -> Dict | List[Dict]:
module = model.module
# We call module.eval() because dropout causes the computation of incorrect log probs.
module.eval()
@ -656,15 +664,20 @@ class PPOActorInterface(model_api.ModelInterface):
advantages = torch.cat(adv_list, 0)
# Prepare data to be split into mini-batches.
flat_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
data=dict(
flat_data = dict(
advantages=advantages,
old_logp=old_logp,
ppo_loss_mask=loss_mask,
packed_input_ids=input_.data["packed_input_ids"],
kl_rewards=kl_rewards,
),
)
use_prox_logp = "proximal_logprobs" in input_.data
if use_prox_logp:
flat_data["prox_logp"] = input_.data["proximal_logprobs"].float()
flat_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
data=flat_data,
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
)
@ -672,6 +685,7 @@ class PPOActorInterface(model_api.ModelInterface):
dense_reward_score = dense_reward_score[shift_one_indices]
### Logging code starts. ###
all_stats = []
with stats_tracker.scope("ppo_actor"):
assert (
task_ids.shape == reward_score.shape
@ -682,12 +696,13 @@ class PPOActorInterface(model_api.ModelInterface):
for idx, task in enumerate(RL_TASKS)
}
stats_tracker.denominator(
global_denominators = dict(
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
n_tokens=torch.ones_like(prompt_mask, dtype=torch.bool),
n_valid_tokens=loss_mask.bool(),
**task_denominators,
)
stats_tracker.denominator(**global_denominators)
for task in RL_TASKS:
stats_tracker.stat(
@ -721,6 +736,22 @@ class PPOActorInterface(model_api.ModelInterface):
**seq_stats,
denominator="n_seqs",
)
scalars = dict(
disable_value=self.disable_value,
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
eps_clip=self.eps_clip,
use_prox_logp=use_prox_logp,
)
if self.c_clip is not None:
scalars["c_clip"] = self.c_clip
scalars["use_dual_clip"] = 1
else:
scalars["use_dual_clip"] = 0
stats_tracker.scalar(**scalars)
global_stats = stats_tracker.export()
for k in global_denominators:
global_stats.pop(f"ppo_actor/{k}")
# Run mini-batched PPO training!
def _loss_fn(logits, input_):
@ -736,7 +767,6 @@ class PPOActorInterface(model_api.ModelInterface):
)
for reuse in range(self.sample_reuse):
with stats_tracker.scope(f"reuse{reuse}"):
# NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
flat_input = SequenceSample.shuffled(flat_input)
bs = flat_input.bs
@ -751,7 +781,6 @@ class PPOActorInterface(model_api.ModelInterface):
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
for mb_i, data in enumerate(datas):
with stats_tracker.scope(f"mb{mb_i}"):
train_stat = module.train_batch(
input_=data,
mb_spec=mb_spec,
@ -763,16 +792,12 @@ class PPOActorInterface(model_api.ModelInterface):
token_normalize_scope=self.token_normalize_scope,
)
stats_tracker.scalar(**train_stat)
all_stats.append(stats_tracker.export())
stats_tracker.scalar(
disable_value=self.disable_value,
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
c_clip=self.c_clip if self.c_clip is not None else float("nan"),
eps_clip=self.eps_clip,
)
model.inc_version()
all_stats[0].update(global_stats)
return stats_tracker.export()
return all_stats
# Mock methods for profiling only.
def _mock_inference(
@ -1033,7 +1058,7 @@ class PPOCriticInterface(model_api.ModelInterface):
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> Dict:
) -> Dict | List[Dict]:
assert model.module.module.config.is_critic
if self.disable_value:


@ -3,7 +3,7 @@
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
from typing import Dict, Literal
from typing import Dict, List, Literal
import torch
import torch.distributed as dist
@ -68,10 +68,10 @@ def compute_packed_sft_loss(
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
dist.all_reduce(
vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN
)
dist.all_reduce(
vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX
)
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
@ -88,7 +88,7 @@ class SFTInterface(model_api.ModelInterface):
def train_step(
self, model: model_api.Model, data: SequenceSample, mb_spec: MicroBatchSpec
) -> Dict:
) -> Dict | List[Dict]:
module = model.module
module.train()
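
The vocab dimension of the logits is split across tensor-parallel ranks, so the MIN/MAX all-reduces above recover the global vocab extremes for logging. A single-process sketch (no `torch.distributed`; ranks are simulated by chunking) of why that works:

```python
import torch

# Single-process simulation (no torch.distributed): the vocab dim is split across
# tensor-parallel ranks, so a per-rank min/max over the local vocab slice followed
# by a MIN/MAX all-reduce recovers the global extremes used for logging.
tp_size, seq_len, vocab = 4, 8, 32
full_logits = torch.randn(seq_len, vocab)
local_slices = torch.chunk(full_logits, tp_size, dim=-1)     # one slice per "rank"

local_min = torch.stack([s.min(-1).values for s in local_slices])
local_max = torch.stack([s.max(-1).values for s in local_slices])
assert torch.equal(local_min.min(0).values, full_logits.min(-1).values)  # ReduceOp.MIN
assert torch.equal(local_max.max(0).values, full_logits.max(-1).values)  # ReduceOp.MAX
```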


@ -10,14 +10,14 @@ import torch.utils.checkpoint
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.impl.model.parallelism.model_parallel.modules import RowParallelLinear
from realhf.impl.model.parallelism.tensor_parallel.modules import RowParallelLinear
from realhf.impl.model.utils.functional import (
apply_rotary_varlen,
compute_varlen_position_indices,
torch_attn_func,
)
from .mlp import LayerNormQKVLinear
from .mlp import GemmaRMSNorm, LayerNormQKVLinear, LlamaRMSNorm
from .rotary import RotaryEmbedding
try:
@ -53,6 +53,8 @@ class CausalSelfAttentionLayer(nn.Module):
layer_norm_type: Optional[str] = None,
# opt applies layer norm after attn
do_layernorm_before: bool = True,
# qk layer norm (Qwen3)
qk_layernorm: bool = False,
# rotary embedding
apply_rotary: bool = False,
rotary_base: float = 10000.0,
@ -67,7 +69,7 @@ class CausalSelfAttentionLayer(nn.Module):
super().__init__()
if dtype is None:
dtype = torch.float16
assert hidden_dim % head_dim == 0
assert hidden_dim % head_dim == 0, (hidden_dim, head_dim)
self.c_attn = LayerNormQKVLinear(
input_dim=hidden_dim,
head_dim=head_dim,
@ -82,7 +84,7 @@ class CausalSelfAttentionLayer(nn.Module):
layer_index=layer_index,
)
if constants.model_parallel_world_size() > 1:
if constants.tensor_parallel_world_size() > 1:
self.c_proj = RowParallelLinear(
n_q_heads * head_dim,
hidden_dim,
@ -100,6 +102,21 @@ class CausalSelfAttentionLayer(nn.Module):
device=device,
)
self.qk_layernorm = qk_layernorm
if qk_layernorm:
if layer_norm_type is None:
layer_norm_fn = nn.LayerNorm
elif layer_norm_type == "rms":
layer_norm_fn = LlamaRMSNorm
elif layer_norm_type == "gemma":
layer_norm_fn = GemmaRMSNorm
self.q_ln = layer_norm_fn(
head_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
)
self.k_ln = layer_norm_fn(
head_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
)
self.resid_dropout = nn.Dropout(resid_pdrop)
self.attn_pdrop = attn_pdrop
@ -173,6 +190,10 @@ class CausalSelfAttentionLayer(nn.Module):
q, k, v = self.c_attn(hidden_states)
if self.qk_layernorm:
q = self.q_ln(q)
k = self.k_ln(k)
if self.apply_rotary and (k_cache is None or str(q.device) == "cpu"):
# otherwise, we input rotary cos/sin directly into flash_attn_with_kvcache
rotary_cache_len = max_seqlen
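
A minimal sketch of the Qwen3-style QK layer norm wired in above: q and k are normalized per attention head over `head_dim` before rotary embeddings are applied. The `RMSNorm` here is a generic stand-in for the repo's `LlamaRMSNorm`/`GemmaRMSNorm`; shapes are illustrative.

```python
import torch
import torch.nn as nn

# Minimal sketch of the qk_layernorm path added above: q and k are normalized
# per attention head (over head_dim) before rotary embeddings are applied.
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.eps)).type_as(x) * self.weight

total_tokens, n_q_heads, n_kv_heads, head_dim = 16, 8, 2, 64
q = torch.randn(total_tokens, n_q_heads, head_dim)
k = torch.randn(total_tokens, n_kv_heads, head_dim)

q_ln, k_ln = RMSNorm(head_dim), RMSNorm(head_dim)
q, k = q_ln(q), k_ln(k)   # normalize over the last (head_dim) axis, per head
```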


@ -8,7 +8,7 @@ import torch
import torch.nn as nn
from torch.nn import init
from realhf.impl.model.parallelism.model_parallel.modules import ParallelEmbedding
from realhf.impl.model.parallelism.tensor_parallel.modules import ParallelEmbedding
class OffsetPositionalEmbedding(nn.Embedding):


@ -15,7 +15,7 @@ from transformers.activations import ACT2FN
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.impl.model.parallelism.model_parallel.modules import (
from realhf.impl.model.parallelism.tensor_parallel.modules import (
ColumnParallelLinear,
RowParallelLinear,
merged_linear_with_grad_accumulation_and_async_allreduce,
@ -49,10 +49,10 @@ class LayerNormQKVLinear(nn.Module):
layer_index=None,
):
super().__init__()
model_parallel = constants.model_parallel_world_size() > 1
tensor_parallel = constants.tensor_parallel_world_size() > 1
sequence_parallel = constants.sequence_parallel()
gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
if not model_parallel and (sequence_parallel or gradient_accumulation_fusion):
if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
global SEQUENCE_PARALLEL_WARNED
if not SEQUENCE_PARALLEL_WARNED:
logger.warning(
@ -73,16 +73,16 @@ class LayerNormQKVLinear(nn.Module):
input_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
)
self.model_parallel = model_parallel
self.tensor_parallel = tensor_parallel
self.layer_index = layer_index
self.mp_worldsize = constants.model_parallel_world_size()
assert n_q_heads % self.mp_worldsize == 0, (
self.tp_worldsize = constants.tensor_parallel_world_size()
assert n_q_heads % self.tp_worldsize == 0, (
f"n_q_heads {n_q_heads} must be divisible by "
f"mp_worldsize {self.mp_worldsize}"
f"tp_worldsize {self.tp_worldsize}"
)
assert n_kv_heads % self.mp_worldsize == 0, (
assert n_kv_heads % self.tp_worldsize == 0, (
f"n_kv_heads {n_kv_heads} must be divisible by "
f"mp_worldsize {self.mp_worldsize}"
f"tp_worldsize {self.tp_worldsize}"
)
hidden_dim = input_dim
# TODO: we can fuse the forward of qkv attention
@ -141,9 +141,9 @@ class LayerNormQKVLinear(nn.Module):
self.v_attn.weight,
self.v_attn.bias,
)
q = q.view(*q.shape[:-1], self.nq // self.mp_worldsize, self.d)
k = k.view(*k.shape[:-1], self.nkv // self.mp_worldsize, self.d)
v = v.view(*v.shape[:-1], self.nkv // self.mp_worldsize, self.d)
q = q.view(*q.shape[:-1], self.nq // self.tp_worldsize, self.d)
k = k.view(*k.shape[:-1], self.nkv // self.tp_worldsize, self.d)
v = v.view(*v.shape[:-1], self.nkv // self.tp_worldsize, self.d)
return q, k, v
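
Under tensor parallelism each rank's QKV projection only produces `n_heads // tp_worldsize` heads, which is why the views above divide the head count by `tp_worldsize`. A small shape sketch with illustrative sizes:

```python
import torch

# Shape sketch: under tensor parallelism each rank's QKV projection emits only
# n_heads // tp_worldsize heads, so the flat projection output is reshaped with
# the per-rank head count, as in the views above.
tp_worldsize, n_q_heads, n_kv_heads, head_dim, tokens = 4, 32, 8, 128, 10
assert n_q_heads % tp_worldsize == 0 and n_kv_heads % tp_worldsize == 0

q_local = torch.randn(tokens, n_q_heads * head_dim // tp_worldsize)
k_local = torch.randn(tokens, n_kv_heads * head_dim // tp_worldsize)
q = q_local.view(tokens, n_q_heads // tp_worldsize, head_dim)    # [10, 8, 128]
k = k_local.view(tokens, n_kv_heads // tp_worldsize, head_dim)   # [10, 2, 128]
```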
@ -163,10 +163,10 @@ class LayerNormMLP(nn.Module):
device: Optional[Union[str, torch.device]] = None,
):
super().__init__()
model_parallel = constants.model_parallel_world_size() > 1
tensor_parallel = constants.tensor_parallel_world_size() > 1
sequence_parallel = constants.sequence_parallel()
gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
if not model_parallel and (sequence_parallel or gradient_accumulation_fusion):
if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
global SEQUENCE_PARALLEL_WARNED
if not SEQUENCE_PARALLEL_WARNED:
logger.warning(
@ -228,12 +228,12 @@ class LlamaLayerNormMLP(nn.Module):
device: Optional[Union[str, torch.device]] = None,
):
super().__init__()
self.model_parallel = constants.model_parallel_world_size() > 1
self.tensor_parallel = constants.tensor_parallel_world_size() > 1
gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
self.is_expert = is_expert
# when used as experts the MLP always computes without sequence parallel
sequence_parallel = constants.sequence_parallel() and not is_expert
if not self.model_parallel and (
if not self.tensor_parallel and (
sequence_parallel or gradient_accumulation_fusion
):
global SEQUENCE_PARALLEL_WARNED
@ -418,13 +418,13 @@ if constants.use_te_impl():
eps=layer_norm_epsilon,
sequence_parallel=constants.sequence_parallel(),
return_bias=False,
tp_group=constants.model_parallel_group(),
tp_size=constants.model_parallel_world_size(),
tp_group=constants.tensor_parallel_group(),
tp_size=constants.tensor_parallel_world_size(),
bias=False,
normalization="RMSNorm",
activation="swiglu",
fuse_wgrad_accumulation=constants.gradient_accumulation_fusion(),
params_dtype=dtype,
set_parallel_mode=constants.model_parallel_world_size() > 1,
set_parallel_mode=constants.tensor_parallel_world_size() > 1,
device=device,
)


@ -10,11 +10,11 @@ from torch.nn.parameter import Parameter
import realhf.base.constants as constants
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.modules.mlp import LlamaLayerNormMLP, get_activation_fn
from realhf.impl.model.parallelism.model_parallel.mappings import (
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
copy_to_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
)
from realhf.impl.model.parallelism.model_parallel.utils import divide
from realhf.impl.model.parallelism.tensor_parallel.utils import divide
from realhf.impl.model.utils.random import _initialize_affine_weight_gpu
try:
@ -125,7 +125,7 @@ class GroupedMLP(torch.nn.Module):
self.activation_func = get_activation_fn(self.config.activation_function)
# How many features each rank holds for fc1 and fc2, respectively.
tp_size = constants.model_parallel_world_size()
tp_size = constants.tensor_parallel_world_size()
intermediate_dim_per_partition = divide(self.config.intermediate_dim, tp_size)
# Note: The current kernel implementations of grouped_gemm
@ -186,7 +186,7 @@ class GroupedMLP(torch.nn.Module):
):
tokens_per_expert = tokens_per_expert.cpu()
if permuted_local_hidden_states.nelement() != 0:
if constants.model_parallel_world_size() > 1:
if constants.tensor_parallel_world_size() > 1:
permuted_local_hidden_states = copy_to_tensor_model_parallel_region(
permuted_local_hidden_states
)
@ -208,7 +208,7 @@ class GroupedMLP(torch.nn.Module):
output = grouped_gemm.ops.gmm(
inter, self.grouped_down_proj, tokens_per_expert, trans_b=False
)
if constants.model_parallel_world_size() > 1:
if constants.tensor_parallel_world_size() > 1:
output = reduce_from_tensor_model_parallel_region(output)
else:
# No token is allocated for local experts.
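
The `copy_to` / `reduce_from` pair above follows the usual column-parallel (up projection) and row-parallel (down projection) layout, where every rank computes a partial output and the all-reduce sums them. A single-process sketch, with the activation between the two GEMMs omitted for brevity:

```python
import torch

# Single-process sketch of the tensor-parallel pattern around the grouped GEMMs
# above: the up projection is split along the intermediate dim (column-parallel),
# the down projection along its input dim (row-parallel), so each rank produces a
# partial output and reduce_from_tensor_model_parallel_region sums them.
tp_size, hidden, inter, tokens = 4, 16, 64, 5
x = torch.randn(tokens, hidden, dtype=torch.float64)
w_up = torch.randn(hidden, inter, dtype=torch.float64)
w_down = torch.randn(inter, hidden, dtype=torch.float64)

partials = []
for rank in range(tp_size):
    lo, hi = rank * inter // tp_size, (rank + 1) * inter // tp_size
    partials.append((x @ w_up[:, lo:hi]) @ w_down[lo:hi, :])  # this rank's partial sum

reduced = torch.stack(partials).sum(0)          # what the all-reduce would compute
assert torch.allclose(reduced, (x @ w_up) @ w_down)
```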


@ -8,7 +8,7 @@ import torch.nn.init as init
import realhf.base.constants as constants
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.parallelism.model_parallel.mappings import (
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
)
from realhf.impl.model.utils.moe import (
@ -117,10 +117,10 @@ class TopKRouter(torch.nn.Module):
torch.Tensor: The activation tensor with the attached gradient function.
"""
moe_aux_loss_coeff = self.config.moe.aux_loss_coeff
moe_aux_loss_coeff /= constants.model_parallel_world_size()
moe_aux_loss_coeff /= constants.tensor_parallel_world_size()
scale_for_logging = 1.0
if constants.sequence_parallel():
scale_for_logging *= constants.model_parallel_world_size()
scale_for_logging *= constants.tensor_parallel_world_size()
aux_loss = switch_load_balancing_loss_func(
probs,
@ -128,7 +128,7 @@ class TopKRouter(torch.nn.Module):
self.config.moe.top_k,
moe_aux_loss_coeff,
sequence_partition_group=(
constants.model_parallel_group()
constants.tensor_parallel_group()
if constants.sequence_parallel()
else None
),
@ -155,7 +155,7 @@ class TopKRouter(torch.nn.Module):
"""
if self.config.moe.z_loss_coeff > 0:
moe_z_loss_coeff = (
self.config.moe.z_loss_coeff / constants.model_parallel_world_size()
self.config.moe.z_loss_coeff / constants.tensor_parallel_world_size()
)
z_loss = z_loss_func(logits, moe_z_loss_coeff)
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)


@ -7,7 +7,7 @@ import torch
import realhf.base.constants as constants
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.parallelism.model_parallel.mappings import (
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
scatter_to_sequence_parallel_region,
)


@ -23,8 +23,8 @@ from .real_llm_base import ReaLModelParamKeys
from .real_llm_parallel import (
get_real_model_param_shape,
intervals_partition_fn,
mp_partition_key,
shape_partition_fn,
tp_partition_key,
)
try:
@ -188,7 +188,7 @@ def set_intervals(
def param_size_from_keys(
config: model_api.ReaLModelConfig,
src_mp_size: int,
src_tp_size: int,
sd_keys: List[str],
src2dst_tp_size: int,
src2dst_tp_rank: int,
@ -202,9 +202,9 @@ def param_size_from_keys(
and "0.wte.weight" in sd_keys
):
continue
new_shape = mp_partition_key(
new_shape = tp_partition_key(
k,
get_real_model_param_shape(k, config, src_mp_size),
get_real_model_param_shape(k, config, src_tp_size),
src2dst_tp_rank,
src2dst_tp_size,
config,
@ -218,7 +218,7 @@ def build_param_spec(
layer_indices: List[int],
config: model_api.ReaLModelConfig,
dp_size: int,
mp_size: int,
tp_size: int,
pp_size: int,
head_param_point_to_embedding: bool,
bucket_size: int = 40000000,
@ -273,7 +273,7 @@ def build_param_spec(
if head_param_point_to_embedding and k == f"{config.n_layers + 1}.weight":
continue
shape = get_real_model_param_shape(k, config, mp_size)
shape = get_real_model_param_shape(k, config, tp_size)
numel = int(np.prod(shape))
data_end_index = data_start_index + numel
@ -307,14 +307,14 @@ def param_intervals_from_keys(
config: model_api.ReaLModelConfig,
head_param_point_to_embedding: bool,
param_spec: Dict[str, ContiguousParamSpec],
mp_size: int,
tp_size: int,
sd_keys: List[str],
portion_size: int,
portion_rank: int,
) -> List[int]:
param_size = param_size_from_keys(
config=config,
src_mp_size=mp_size,
src_tp_size=tp_size,
sd_keys=sd_keys,
src2dst_tp_size=portion_size,
src2dst_tp_rank=portion_rank,
@ -333,13 +333,13 @@ def param_intervals_from_keys(
if (
model_name,
k.split(".", 1)[1],
mp_size,
tp_size,
portion_rank,
portion_size,
) not in _FLAT_PARAM_INDICES_CACHE:
zero_start_intervals = mp_partition_key(
zero_start_intervals = tp_partition_key(
k,
get_real_model_param_shape(k, config, mp_size),
get_real_model_param_shape(k, config, tp_size),
portion_rank,
portion_size,
config,
@ -349,7 +349,7 @@ def param_intervals_from_keys(
(
model_name,
k.split(".", 1)[1],
mp_size,
tp_size,
portion_rank,
portion_size,
)
@ -359,7 +359,7 @@ def param_intervals_from_keys(
(
model_name,
k.split(".", 1)[1],
mp_size,
tp_size,
portion_rank,
portion_size,
)


@ -167,7 +167,7 @@ class ReaLModel(nn.Module):
self._param_spec, self._param_size = build_param_spec(
list(range(self.layer_idx_start, self.layer_idx_end)),
self.config,
mp_size=constants.model_parallel_world_size(),
tp_size=constants.tensor_parallel_world_size(),
pp_size=constants.pipe_parallel_world_size(),
dp_size=constants.data_parallel_world_size(),
head_param_point_to_embedding=self.head_param_point_to_embedding,
@ -282,7 +282,7 @@ class ReaLModel(nn.Module):
device=device,
dtype=dtype,
)
elif not config.is_critic and constants.model_parallel_world_size() > 1:
elif not config.is_critic and constants.tensor_parallel_world_size() > 1:
l = ParallelActorHead(
config.hidden_dim,
config.vocab_size,
@ -428,14 +428,14 @@ class ReaLModel(nn.Module):
x.cu_seqlens = x.cu_seqlens.int()
# Copy input tensor to a pinned buffer.
mp_size = constants.model_parallel_world_size()
tp_size = constants.tensor_parallel_world_size()
batch_length = None
if ys[0].packed_input_ids is not None:
batch_length = ys[0].packed_input_ids.shape[0]
if x.pp_input is not None:
batch_length = x.pp_input.shape[0]
assert batch_length is not None
padded_batch_length = (batch_length + mp_size - 1) // mp_size * mp_size
padded_batch_length = (batch_length + tp_size - 1) // tp_size * tp_size
pad_size = padded_batch_length - batch_length
if (
@ -609,7 +609,7 @@ class ReaLModel(nn.Module):
to_param_spec, to_param_size = build_param_spec(
to_layer_indices,
to_model_config,
mp_size=to_topo.get_dim("model"),
tp_size=to_topo.get_dim("tensor"),
dp_size=to_topo.get_dim("data"),
pp_size=to_topo.get_dim("pipe"),
head_param_point_to_embedding=to_model_head_param_point_to_embedding,


@ -17,7 +17,7 @@ import transformers
import realhf.base.constants as constants
import realhf.base.logging as logging
import realhf.impl.model.parallelism.model_parallel.mappings as tensor_parallel
import realhf.impl.model.parallelism.tensor_parallel.mappings as tensor_parallel
from realhf.api.core import model_api
from realhf.impl.model.modules import (
CausalSelfAttentionLayer,
@ -28,9 +28,8 @@ from realhf.impl.model.modules import (
LlamaRMSNorm,
OffsetParallelPositionalEmbedding,
OffsetPositionalEmbedding,
scatter_to_sequence_parallel_region,
)
from realhf.impl.model.parallelism.model_parallel.modules import (
from realhf.impl.model.parallelism.tensor_parallel.modules import (
ColumnParallelLinear,
ParallelEmbedding,
gather_from_sequence_parallel_region,
@ -139,6 +138,7 @@ class ReaLModelBlock(nn.Module):
use_attention_bias=config.use_attention_bias,
use_attn_proj_bias=config.use_attn_proj_bias,
do_layernorm_before=config.do_layernorm_before,
qk_layernorm=config.qk_layernorm,
apply_rotary=config.apply_rotary,
rotary_base=config.rotary_base,
rotary_interleaved=config.rotary_interleaved,
@ -281,8 +281,8 @@ class VocabPositionEmbedding(nn.Module):
self.n_positions = config.n_positions
self.hidden_dim = config.hidden_dim
model_parallel = constants.model_parallel_world_size() > 1
if model_parallel:
tensor_parallel = constants.tensor_parallel_world_size() > 1
if tensor_parallel:
embed_cls = ParallelEmbedding
else:
embed_cls = nn.Embedding
@ -295,7 +295,7 @@ class VocabPositionEmbedding(nn.Module):
if self.apply_abs_pos_embed:
p_embed_cls = (
OffsetParallelPositionalEmbedding
if model_parallel
if tensor_parallel
else OffsetPositionalEmbedding
)
self.wpe = p_embed_cls(
@ -416,7 +416,7 @@ class ParallelActorHead(ColumnParallelLinear):
def _forward(self, x: torch.Tensor):
weight = self.weight
if self._norm_head:
from realhf.impl.model.parallelism.model_parallel.mappings import (
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
)
@ -431,7 +431,7 @@ class ParallelActorHead(ColumnParallelLinear):
).transpose(1, 0)
head_norm = unnormed_head.norm(dim=0, keepdim=True, p=2)
normed_head = unnormed_head / (head_norm + 1e-7)
weight = scatter_to_sequence_parallel_region(normed_head)
weight = tensor_parallel.scatter_to_sequence_parallel_region(normed_head)
output = parallel_lm_logits(
x,
@ -486,6 +486,12 @@ class ReaLModelParamKeys:
keys += [f"{idx + 1}.attn.c_proj.weight"]
if config.use_attn_proj_bias:
keys += [f"{idx + 1}.attn.c_proj.bias"]
if config.qk_layernorm:
keys += [f"{idx + 1}.attn.q_ln.weight"]
keys += [f"{idx + 1}.attn.k_ln.weight"]
if config.layer_norm_type is None:
keys += [f"{idx + 1}.attn.q_ln.bias"]
keys += [f"{idx + 1}.attn.k_ln.bias"]
keys += [f"{idx + 1}.mlp.ln.weight"]
if config.layer_norm_type is None:
keys += [f"{idx + 1}.mlp.ln.bias"]


@ -55,8 +55,8 @@ def genstep(
unfinished_sequences: Bool tensor indicator of whether a sequence is finished.
Shape [bs].
"""
if constants.model_parallel_world_size() > 1:
from realhf.impl.model.parallelism.model_parallel.mappings import (
if constants.tensor_parallel_world_size() > 1:
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
gather_from_tensor_model_parallel_region,
)
@ -95,20 +95,20 @@ def genstep(
next_tokens = distrb.mode if gconfig.greedy else distrb.sample()
logprob = distrb.log_prob(next_tokens)
if constants.model_parallel_world_size() > 1:
if constants.model_parallel_rank() > 0:
if constants.tensor_parallel_world_size() > 1:
if constants.tensor_parallel_rank() > 0:
logprob[:] = 0
next_tokens[:] = 0
handle = torch.distributed.all_reduce(
logprob,
torch.distributed.ReduceOp.SUM,
async_op=True,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
)
torch.distributed.all_reduce(
next_tokens,
torch.distributed.ReduceOp.SUM,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
)
if tokenizer.eos_token_id is not None:
@ -139,7 +139,7 @@ def genstep(
if not logits_mask.any():
logits_mask = None
if constants.model_parallel_world_size() > 1:
if constants.tensor_parallel_world_size() > 1:
handle.wait()
return next_tokens, logprob, logits_mask, terminate, unfinished_sequences
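
The hunk above broadcasts rank 0's sampled tokens and log-probs by zeroing them on every other tensor-parallel rank and then summing with all-reduce. A single-process simulation of that broadcast-by-sum trick:

```python
import torch

# Single-process simulation of the broadcast-by-sum trick above: every
# tensor-parallel rank except rank 0 zeroes its sampled tokens, so a SUM
# all-reduce leaves all ranks holding rank 0's samples.
tp_size = 4
sampled = [torch.randint(0, 100, (6,)) for _ in range(tp_size)]   # per-rank samples

contrib = [t if rank == 0 else torch.zeros_like(t) for rank, t in enumerate(sampled)]
after_allreduce = torch.stack(contrib).sum(0)                     # what ReduceOp.SUM yields
assert torch.equal(after_allreduce, sampled[0])
```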


@ -42,27 +42,27 @@ if constants.use_te_impl():
def tensor_slice_partition_fn(
tensor: torch.Tensor,
mp_rank: Optional[int],
mp_world_size: int,
tp_rank: Optional[int],
tp_world_size: int,
dim: Optional[int],
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Partition a tensor by slicing along a dimension for tensor-model
parallelism."""
if dim is None:
splits = [tensor for _ in range(mp_world_size)]
splits = [tensor for _ in range(tp_world_size)]
else:
assert tensor.shape[dim] % mp_world_size == 0
splits = torch.split(tensor, tensor.shape[dim] // mp_world_size, dim=dim)
if mp_rank is None:
assert tensor.shape[dim] % tp_world_size == 0
splits = torch.split(tensor, tensor.shape[dim] // tp_world_size, dim=dim)
if tp_rank is None:
return [s.contiguous() for s in splits]
else:
return splits[mp_rank].contiguous()
return splits[tp_rank].contiguous()
def intervals_partition_fn(
shape: torch.Size,
mp_rank: Optional[int],
mp_world_size: int,
tp_rank: Optional[int],
tp_world_size: int,
dim: Optional[int],
) -> Union[List[torch.Tensor], torch.Tensor]:
"""Get the intervals of a MP-partitioned tensor in the flatten view.
@ -74,34 +74,34 @@ def intervals_partition_fn(
Used by parameter reallocation. Return a numpy array of shape [N,
2], where N is the number of intervals.
"""
assert mp_rank is not None
assert tp_rank is not None
param_size = int(np.prod(shape))
if dim is None:
return np.array([(0, param_size)], dtype=np.int64)
if dim < 0:
dim = len(shape) + dim
assert shape[dim] % mp_world_size == 0
assert shape[dim] % tp_world_size == 0
if len(shape) == 1:
assert dim == 0
partition_size = shape[0] // mp_world_size
partition_size = shape[0] // tp_world_size
return np.array(
[(partition_size * mp_rank, partition_size * (mp_rank + 1))],
[(partition_size * tp_rank, partition_size * (tp_rank + 1))],
dtype=np.int64,
)
else:
assert len(shape) == 2, shape
if dim == 0:
row_start = mp_rank * shape[0] // mp_world_size
row_end = (mp_rank + 1) * shape[0] // mp_world_size
row_start = tp_rank * shape[0] // tp_world_size
row_end = (tp_rank + 1) * shape[0] // tp_world_size
return np.array(
[(row_start * shape[1], row_end * shape[1])], dtype=np.int64
)
else:
assert dim == 1
col_start = mp_rank * shape[1] // mp_world_size
col_end = (mp_rank + 1) * shape[1] // mp_world_size
col_start = tp_rank * shape[1] // tp_world_size
col_end = (tp_rank + 1) * shape[1] // tp_world_size
return np.arange(shape[0], dtype=np.int64)[:, None] * shape[1] + np.array(
[(col_start, col_end)], dtype=np.int64
)
@ -109,32 +109,32 @@ def intervals_partition_fn(
def shape_partition_fn(
shape: torch.Size,
mp_rank: Optional[int],
mp_world_size: int,
tp_rank: Optional[int],
tp_world_size: int,
dim: Optional[int],
):
"""Get the partitioned shape of a tensor for tensor-model parallelism."""
if dim is None:
splits = [shape for _ in range(mp_world_size)]
splits = [shape for _ in range(tp_world_size)]
else:
if dim < 0:
dim = len(shape) + dim
assert shape[dim] % mp_world_size == 0
assert shape[dim] % tp_world_size == 0
splits = [
(*shape[:dim], shape[dim] // mp_world_size, *shape[dim + 1 :])
for _ in range(mp_world_size)
(*shape[:dim], shape[dim] // tp_world_size, *shape[dim + 1 :])
for _ in range(tp_world_size)
]
if mp_rank is None:
if tp_rank is None:
return [s for s in splits]
else:
return splits[mp_rank]
return splits[tp_rank]
def mp_partition_key(
def tp_partition_key(
key: str,
tensor_or_shape: torch.Tensor | torch.Size,
mp_rank: Optional[int],
mp_size: Optional[int],
tp_rank: Optional[int],
tp_size: Optional[int],
config: model_api.ReaLModelConfig,
partition_fn: Callable[
[torch.Tensor, Optional[int], int, Optional[int]],
@ -149,7 +149,7 @@ def mp_partition_key(
if any([ek in key for ek in EMBEDDING_KEYS]):
assert "weight" in key
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0)
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
elif key == f"{config.n_layers + 1}.weight": # output head
if (
isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape.shape[0] == 1
@ -157,88 +157,90 @@ def mp_partition_key(
not isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape[0] == 1
):
assert config.is_critic
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
else:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0)
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
elif any([ck in key for ck in COLUMN_LINEAR_KEYS]):
if (
("k_attn" in key) or ("v_attn" in key)
) and config.n_kv_heads % mp_size != 0:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
) and config.n_kv_heads % tp_size != 0:
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
# partition both weight and bias
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0)
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
elif any([rk in key for rk in ROW_LINEAR_KEYS]):
# only partition weight
if "weight" in key:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=1)
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=1)
else:
assert "bias" in key, key
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
else:
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
def mp_partition_real_model_state_dict(
def tp_partition_real_model_state_dict(
state_dict: Dict[str, torch.Tensor],
config: model_api.ReaLModelConfig,
mp_size: int,
mp_rank: Optional[int] = None,
tp_size: int,
tp_rank: Optional[int] = None,
) -> Union[Dict, List[Dict]]:
"""A helper function to partition a state dict using `mp_partition_key`."""
if mp_size == 1:
if mp_rank is None:
"""A helper function to partition a state dict using `tp_partition_key`."""
if tp_size == 1:
if tp_rank is None:
return [state_dict]
else:
return state_dict
new_state_dict = {}
for k, v in state_dict.items():
new_state_dict[k] = mp_partition_key(k, v, mp_rank, mp_size, config)
new_state_dict[k] = tp_partition_key(k, v, tp_rank, tp_size, config)
if mp_rank is None:
if tp_rank is None:
return [
{k: v[mp_rank] for k, v in new_state_dict.items()}
for mp_rank in range(mp_size)
{k: v[tp_rank] for k, v in new_state_dict.items()}
for tp_rank in range(tp_size)
]
else:
return new_state_dict
def get_real_model_param_shape(
k: str, config: model_api.ReaLModelConfig, mp_size: int
k: str, config: model_api.ReaLModelConfig, tp_size: int
) -> Tuple:
if "wte.weight" in k:
assert config.vocab_size % mp_size == 0
return (config.vocab_size // mp_size, config.hidden_dim)
assert config.vocab_size % tp_size == 0
return (config.vocab_size // tp_size, config.hidden_dim)
elif "wpe.weight" in k:
assert config.n_positions % mp_size == 0
if (config.n_positions + config.abs_position_embedding_offset) % mp_size != 0:
assert config.n_positions % tp_size == 0
if (config.n_positions + config.abs_position_embedding_offset) % tp_size != 0:
raise ValueError(
f"The dimenstion of position embedding "
f"({config.n_positions} + offset {config.abs_position_embedding_offset}) "
f"is not divisible by mp_size ({mp_size}). "
f"is not divisible by tp_size ({tp_size}). "
"Models like this (e.g. OPT-350m) inherently do not support tensor parallelism."
)
return (
(config.n_positions + config.abs_position_embedding_offset) // mp_size,
(config.n_positions + config.abs_position_embedding_offset) // tp_size,
config.hidden_dim,
)
elif ".ln." in k or ".ln_f." in k:
return (config.hidden_dim,)
elif ".q_ln." in k or ".k_ln." in k:
return (config.head_dim,)
elif k == f"{config.n_layers + 1}.weight": # output head
if config.is_critic:
return (1, config.hidden_dim)
elif mp_size > 1:
assert config.vocab_size % mp_size == 0
return (config.vocab_size // mp_size, config.hidden_dim)
elif tp_size > 1:
assert config.vocab_size % tp_size == 0
return (config.vocab_size // tp_size, config.hidden_dim)
else:
return (config.vocab_size, config.hidden_dim)
elif any([ck in k for ck in COLUMN_LINEAR_KEYS]):
if "k_attn" in k or "v_attn" in k:
if "weight" in k:
if config.n_kv_heads % mp_size == 0:
if config.n_kv_heads % tp_size == 0:
return (
config.head_dim * config.n_kv_heads // mp_size,
config.head_dim * config.n_kv_heads // tp_size,
config.hidden_dim,
)
else:
@ -248,27 +250,27 @@ def get_real_model_param_shape(
)
else:
assert "bias" in k
if config.n_kv_heads % mp_size == 0:
return (config.head_dim * config.n_kv_heads // mp_size,)
if config.n_kv_heads % tp_size == 0:
return (config.head_dim * config.n_kv_heads // tp_size,)
else:
return (config.head_dim * config.n_kv_heads,)
if "mlp" in k:
if "weight" in k:
return (config.intermediate_dim // mp_size, config.hidden_dim)
return (config.intermediate_dim // tp_size, config.hidden_dim)
else:
assert "bias" in k
return (config.intermediate_dim // mp_size,)
return (config.intermediate_dim // tp_size,)
if "weight" in k:
assert config.n_q_heads % mp_size == 0
return (config.n_q_heads * config.head_dim // mp_size, config.hidden_dim)
assert config.n_q_heads % tp_size == 0
return (config.n_q_heads * config.head_dim // tp_size, config.hidden_dim)
else:
assert "bias" in k
return (config.n_q_heads * config.head_dim // mp_size,)
return (config.n_q_heads * config.head_dim // tp_size,)
elif any([rk in k for rk in ROW_LINEAR_KEYS]):
if "mlp" in k and "weight" in k:
return (config.hidden_dim, config.intermediate_dim // mp_size)
return (config.hidden_dim, config.intermediate_dim // tp_size)
elif "attn" in k and "weight" in k:
return (config.hidden_dim, config.n_q_heads * config.head_dim // mp_size)
return (config.hidden_dim, config.n_q_heads * config.head_dim // tp_size)
elif "bias" in k:
return (config.hidden_dim,)
else:
@ -280,7 +282,7 @@ def get_real_model_param_shape(
raise NotImplementedError(f"unkown shape of key {k}.")
def mp_merge_key(
def tp_merge_key(
k: str,
tensors: List[torch.Tensor],
config: model_api.ReaLModelConfig,
@ -297,17 +299,17 @@ def mp_merge_key(
return tensors[0]
def mp_merge_real_model_state_dict(
def tp_merge_real_model_state_dict(
state_dicts: List[Dict[str, torch.Tensor]],
config: model_api.ReaLModelConfig,
) -> Dict:
mp_size = len(state_dicts)
if mp_size == 1:
tp_size = len(state_dicts)
if tp_size == 1:
return state_dicts[0]
new_state_dict = {}
for k in state_dicts[0].keys():
new_state_dict[k] = mp_merge_key(k, [sd[k] for sd in state_dicts], config)
new_state_dict[k] = tp_merge_key(k, [sd[k] for sd in state_dicts], config)
return new_state_dict
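
A quick sketch of the partition/merge round trip that `tp_partition_key` and `tp_merge_key` implement for the two linear layouts: column-parallel weights are split along dim 0, row-parallel weights along dim 1, and concatenating the shards along the same dim recovers the original tensor (sizes are illustrative):

```python
import torch

# Round-trip sketch of the two linear layouts handled by tp_partition_key /
# tp_merge_key: column-parallel weights split along dim 0, row-parallel weights
# along dim 1; concatenating the shards along the same dim restores the tensor.
tp_size = 4
w_col = torch.randn(1024, 512)   # e.g. a q/k/v projection (column-parallel)
w_row = torch.randn(512, 1024)   # e.g. attn.c_proj or mlp down_proj (row-parallel)

col_shards = torch.chunk(w_col, tp_size, dim=0)
row_shards = torch.chunk(w_row, tp_size, dim=1)
assert torch.equal(torch.cat(col_shards, dim=0), w_col)
assert torch.equal(torch.cat(row_shards, dim=1), w_row)
```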
@ -317,37 +319,37 @@ class ReaLModelParamCount:
@staticmethod
def _derive_count_from_keys(
keys: List[str], config: model_api.ReaLModelConfig, mp_size: int
keys: List[str], config: model_api.ReaLModelConfig, tp_size: int
) -> int:
count = 0
for k in keys:
count += np.prod(get_real_model_param_shape(k, config, mp_size))
count += np.prod(get_real_model_param_shape(k, config, tp_size))
return int(count)
@staticmethod
def embed(config: model_api.ReaLModelConfig, mp_size: int) -> int:
def embed(config: model_api.ReaLModelConfig, tp_size: int) -> int:
return ReaLModelParamCount._derive_count_from_keys(
ReaLModelParamKeys.embed(config), config, mp_size
ReaLModelParamKeys.embed(config), config, tp_size
)
@staticmethod
def tblock(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int:
def tblock(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int:
return ReaLModelParamCount._derive_count_from_keys(
ReaLModelParamKeys.tblock(config, idx), config, mp_size
ReaLModelParamKeys.tblock(config, idx), config, tp_size
)
@staticmethod
def head(config: model_api.ReaLModelConfig, mp_size: int) -> int:
def head(config: model_api.ReaLModelConfig, tp_size: int) -> int:
return ReaLModelParamCount._derive_count_from_keys(
ReaLModelParamKeys.head(config), config, mp_size
ReaLModelParamKeys.head(config), config, tp_size
)
@staticmethod
def total(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int:
def total(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int:
return (
config.n_layers * ReaLModelParamCount.tblock(config, idx, mp_size)
+ ReaLModelParamCount.head(config, mp_size)
+ ReaLModelParamCount.embed(config, mp_size)
config.n_layers * ReaLModelParamCount.tblock(config, idx, tp_size)
+ ReaLModelParamCount.head(config, tp_size)
+ ReaLModelParamCount.embed(config, tp_size)
)
@ -356,7 +358,7 @@ def partition_pipeline_layers(
num_stages: int,
method: str = "parameters",
) -> Dict[int, Tuple[int, int]]:
# Ignoring mp_size in param count because tensor parallel equally partitions parameters.
# Ignoring tp_size in param count because tensor parallel equally partitions parameters.
# It is irrelevant to how we partition pipeline stages.
param_counts = (
[ReaLModelParamCount.embed(config, 1)]


@ -13,11 +13,11 @@ def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if constants.model_parallel_world_size() == 1:
if constants.tensor_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=constants.model_parallel_group())
torch.distributed.all_reduce(input_, group=constants.tensor_parallel_group())
return input_
@ -25,7 +25,7 @@ def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the corresponding
slice."""
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
@ -34,7 +34,7 @@ def _split_along_last_dim(input_):
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = constants.model_parallel_rank()
rank = constants.tensor_parallel_rank()
output = input_list[rank].contiguous()
return output
@ -44,7 +44,7 @@ def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the corresponding
slice."""
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
@ -55,7 +55,7 @@ def _split_along_first_dim(input_):
dim_size % world_size == 0
), "First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = constants.model_parallel_rank()
rank = constants.tensor_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset : dim_offset + local_dim_size].contiguous()
@ -66,19 +66,19 @@ def _split_along_first_dim(input_):
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = constants.model_parallel_rank()
rank = constants.tensor_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(
tensor_list, input_, group=constants.model_parallel_group()
tensor_list, input_, group=constants.tensor_parallel_group()
)
# Note: torch.cat already creates a contiguous tensor.
@ -90,7 +90,7 @@ def _gather_along_last_dim(input_):
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
@ -102,7 +102,7 @@ def _gather_along_first_dim(input_):
dim_size, dtype=input_.dtype, device=constants.current_device()
)
torch.distributed._all_gather_base(
output, input_.contiguous(), group=constants.model_parallel_group()
output, input_.contiguous(), group=constants.tensor_parallel_group()
)
return output
@ -110,7 +110,7 @@ def _gather_along_first_dim(input_):
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
@ -128,7 +128,7 @@ def _reduce_scatter_along_first_dim(input_):
dim_size, dtype=input_.dtype, device=constants.current_device()
)
torch.distributed._reduce_scatter_base(
output, input_.contiguous(), group=constants.model_parallel_group()
output, input_.contiguous(), group=constants.tensor_parallel_group()
)
return output
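
A single-process sketch of the scatter/gather pair above: slicing along the last dimension and concatenating the per-rank slices back is the identity, which is what `_split_along_last_dim` followed by `_gather_along_last_dim` computes across the tensor-parallel group.

```python
import torch

# Single-process sketch: splitting along the last dim and re-concatenating the
# per-rank slices is the identity that _split_along_last_dim followed by
# _gather_along_last_dim computes across the tensor-parallel group.
world_size = 4
x = torch.randn(3, 8, 16)

slices = torch.split(x, x.shape[-1] // world_size, dim=-1)   # one slice per rank
local = [s.contiguous() for s in slices]                     # torch.split returns views
assert torch.equal(torch.cat(local, dim=-1), x)
```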


@ -44,7 +44,7 @@ except ImportError:
import realhf.base.logging as logging
logger = logging.getLogger("model_parallel.modules")
logger = logging.getLogger("tensor_parallel.modules")
def get_activation_fn(activation_function: str) -> Callable:
@ -95,12 +95,12 @@ class ParallelEmbedding(torch.nn.Module):
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = constants.model_parallel_world_size()
self.tensor_model_parallel_size = constants.tensor_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings,
constants.model_parallel_rank(),
constants.tensor_parallel_rank(),
self.tensor_model_parallel_size,
)
)
@ -110,7 +110,7 @@ class ParallelEmbedding(torch.nn.Module):
logger.debug(
f"ParallelEmbedding: num_embeddings={num_embeddings}, per_partition={self.num_embeddings_per_partition}, embedding_dim={embedding_dim},"
f"tp_rank={constants.model_parallel_rank()},tp_world_size={constants.model_parallel_world_size()}"
f"tp_rank={constants.tensor_parallel_rank()},tp_world_size={constants.tensor_parallel_world_size()}"
)
# Allocate weights and initialize.
self.weight = Parameter(
@ -264,7 +264,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
assert (
not ctx.async_grad_allreduce
), "async_grad_allreduce and sequence_parallel can not be both True"
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
@ -272,7 +272,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
dim_size, input.dtype, "mpu"
)
torch.distributed._all_gather_base(
all_gather_buffer, input, group=constants.model_parallel_group()
all_gather_buffer, input, group=constants.tensor_parallel_group()
)
total_input = all_gather_buffer
else:
@ -290,7 +290,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
use_bias = ctx.use_bias
if ctx.sequence_parallel:
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
@ -300,7 +300,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
async_op=True,
)
@ -327,7 +327,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
async_op=True,
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -346,7 +346,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
handle = torch.distributed._reduce_scatter_base(
sub_grad_input,
grad_input,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
async_op=True,
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -525,7 +525,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
assert (
not ctx.async_grad_allreduce
), "async_grad_allreduce and sequence_parallel can not be both True"
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
@ -533,7 +533,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
dim_size, input.dtype, "mpu"
)
torch.distributed._all_gather_base(
all_gather_buffer, input, group=constants.model_parallel_group()
all_gather_buffer, input, group=constants.tensor_parallel_group()
)
total_input = all_gather_buffer
else:
@ -557,7 +557,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
is_w_parallel = ctx.is_w_parallel
if ctx.sequence_parallel:
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
@ -567,7 +567,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
async_op=True,
)
@ -578,7 +578,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
total_input = input
grad_input = 0
for w, is_parallel, grad in zip(weights, is_w_parallel, grads):
if is_parallel or constants.model_parallel_rank() == 0:
if is_parallel or constants.tensor_parallel_rank() == 0:
grad_input = grad_input + grad.matmul(w)
if ctx.sequence_parallel:
@ -597,7 +597,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
async_op=True,
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -616,7 +616,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
handle = torch.distributed._reduce_scatter_base(
sub_grad_input,
grad_input,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
async_op=True,
)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
@ -785,7 +785,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
self.is_expert = is_expert
@ -852,7 +852,7 @@ class ColumnParallelLinear(torch.nn.Module):
# in expert MLPs always behave as sequence parallel is not enabled.
sequence_parallel = constants.sequence_parallel() and not self.is_expert
async_tensor_model_parallel_allreduce = (
constants.model_parallel_world_size() > 1 and not sequence_parallel
constants.tensor_parallel_world_size() > 1 and not sequence_parallel
)
if sequence_parallel:
@ -942,7 +942,7 @@ class RowParallelLinear(torch.nn.Module):
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = constants.model_parallel_world_size()
world_size = constants.tensor_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.gradient_accumulation_fusion = gradient_accumulation_fusion
@ -1030,9 +1030,9 @@ def parallel_lm_logits(
bias=None,
):
"""LM logits using word embedding weights."""
model_parallel = constants.model_parallel_world_size() > 1
tensor_parallel = constants.tensor_parallel_world_size() > 1
sequence_parallel = constants.sequence_parallel()
async_grad_allreduce = not sequence_parallel and model_parallel
async_grad_allreduce = not sequence_parallel and tensor_parallel
# Parallel logits.
if sequence_parallel:
input_parallel = input_
@ -1066,7 +1066,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
torch.distributed.all_reduce(
logits_max,
op=torch.distributed.ReduceOp.MAX,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
)
# Subtract the maximum value.
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
@ -1074,8 +1074,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = constants.model_parallel_rank()
world_size = constants.model_parallel_world_size()
rank = constants.tensor_parallel_rank()
world_size = constants.tensor_parallel_world_size()
vocab_start_index, vocab_end_index = get_vocab_range(
partition_vocab_size, rank, world_size
)
@ -1101,7 +1101,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
torch.distributed.all_reduce(
predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
)
# Sum of exponential of logits along vocab dimension across all GPUs.
@ -1111,7 +1111,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
torch.distributed.all_reduce(
sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=constants.model_parallel_group(),
group=constants.tensor_parallel_group(),
)
# Loss = log(sum(exp(logits))) - predicted-logit.
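
The vocab-parallel cross entropy above relies on each rank owning a contiguous slice of the vocabulary. A sketch of the per-rank range computation (mirroring the `vocab_range_from_per_partition_vocab_size` semantics; sizes are illustrative):

```python
# Sketch of the per-rank vocabulary range assumed by the vocab-parallel cross
# entropy above: rank r owns global token ids [r * partition, (r + 1) * partition).
def vocab_range(partition_vocab_size: int, rank: int):
    start = rank * partition_vocab_size
    return start, start + partition_vocab_size

world_size, vocab_size = 4, 32000
partition = vocab_size // world_size
print([vocab_range(partition, r) for r in range(world_size)])
# [(0, 8000), (8000, 16000), (16000, 24000), (24000, 32000)]
```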


@ -6,11 +6,7 @@ from typing import List, Sequence
import numpy as np
import torch
from realhf.base.constants import (
model_parallel_group,
model_parallel_rank,
model_parallel_world_size,
)
import realhf.base.constants as constants
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
"tensor_model_parallel": False,
@ -22,7 +18,7 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
def param_is_not_model_parallel_duplicate(param):
return (
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
) or (model_parallel_rank() == 0)
) or (constants.tensor_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
@ -110,8 +106,8 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
If False, returns a view into the existing Tensor.
Default is False
"""
partition_size = torch.numel(tensor) // model_parallel_world_size()
start_index = partition_size * model_parallel_rank()
partition_size = torch.numel(tensor) // constants.tensor_parallel_world_size()
start_index = partition_size * constants.tensor_parallel_rank()
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(
@ -135,7 +131,7 @@ def gather_split_1d_tensor(tensor):
Arguments:
tensor: A Tensor or view of this rank's portion of the data.
"""
numel_gathered = torch.numel(tensor) * model_parallel_world_size()
numel_gathered = torch.numel(tensor) * constants.tensor_parallel_world_size()
gathered = torch.empty(
numel_gathered,
dtype=tensor.dtype,
@ -147,7 +143,9 @@ def gather_split_1d_tensor(tensor):
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch.distributed._all_gather_base(gathered, tensor, group=model_parallel_group())
torch.distributed._all_gather_base(
gathered, tensor, group=constants.tensor_parallel_group()
)
return gathered


@ -210,11 +210,11 @@ def gather_packed_shifted_log_probs(
"""
labels = torch.nn.functional.pad(labels[1:], (0, 1), value=0)
leave_one_indices = build_leave_one_indices(logits, cu_seqlens)
if constants.model_parallel_world_size() > 1:
if constants.tensor_parallel_world_size() > 1:
# NOTE: logprobs is freaking sensitive to input_ids. If the input sequence is a natural sequence, everything will be fine.
# However, if we input random token IDs, parallel cross entropy can produce VERY different results than the normal
# torch.gather based version (e.g., the maximum absolute difference can reach ~50).
from realhf.impl.model.parallelism.model_parallel.modules import (
from realhf.impl.model.parallelism.tensor_parallel.modules import (
vocab_parallel_cross_entropy,
)
@ -239,14 +239,16 @@ def gather_packed_shifted_log_probs(
def apply_logits_mask(logits: torch.HalfTensor, mask: torch.BoolTensor):
assert mask.shape[-1] == logits.shape[-1] * constants.model_parallel_world_size(), (
constants.model_parallel_world_size(),
assert (
mask.shape[-1] == logits.shape[-1] * constants.tensor_parallel_world_size()
), (
constants.tensor_parallel_world_size(),
logits.shape,
mask.shape,
)
parallel_vocab_size = logits.shape[-1]
mp_rank = constants.model_parallel_rank()
mask = mask[:, mp_rank * parallel_vocab_size : (mp_rank + 1) * parallel_vocab_size]
tp_rank = constants.tensor_parallel_rank()
mask = mask[:, tp_rank * parallel_vocab_size : (tp_rank + 1) * parallel_vocab_size]
logits.masked_fill_(mask, torch.finfo(logits.dtype).min)
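
`apply_logits_mask` above receives a full-vocab mask but local, vocab-partitioned logits, so it slices the mask to this rank's vocab range before masking. A single-process sketch with illustrative sizes:

```python
import torch

# Sketch: the full-vocab mask is sliced to this rank's vocab partition before
# masked_fill_, which is why its width must equal parallel_vocab_size * tp_world_size.
tp_world_size, tp_rank, parallel_vocab = 4, 1, 8
logits = torch.randn(5, parallel_vocab)                          # this rank's vocab slice
mask = torch.zeros(5, parallel_vocab * tp_world_size, dtype=torch.bool)
mask[:, 10] = True                                               # ban global token id 10

local = mask[:, tp_rank * parallel_vocab:(tp_rank + 1) * parallel_vocab]
logits.masked_fill_(local, torch.finfo(logits.dtype).min)        # id 10 lives on rank 1 (ids 8..15)
```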


@ -250,7 +250,7 @@ def pad_sequence_parallel_input(
):
"""Sequence parallel requires packed_input_ids has a shape of 1 dimension
[total_seq_len], and total_seq_len should be divisible by
model_parallel_world_size. This function is used to pad packed_input_ids to
tensor_parallel_world_size. This function is used to pad packed_input_ids to
suitable length with an empty sequence, and return new packed_input_ids,
cu_seqlens and max_seqlen.
@ -262,10 +262,10 @@ def pad_sequence_parallel_input(
Returns:
(torch.Tensor, torch.Tensor, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size)
"""
mp_world_size = constants.model_parallel_world_size()
tp_world_size = constants.tensor_parallel_world_size()
pad_size = 0
if len(packed_input_ids) % mp_world_size != 0:
pad_size = mp_world_size - len(packed_input_ids) % mp_world_size
if len(packed_input_ids) % tp_world_size != 0:
pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
packed_input_ids = torch.nn.functional.pad(
packed_input_ids, (0, pad_size), value=1
)
@ -281,9 +281,9 @@ def pad_sequence_parallel_generate_input(
):
"""Only for pipeline generate input when model+seq parallel is enabled. To
make sure inputs for seq parallel model have a shape with first dimension
divisible by model_parallel_world_size, the packed_input_ids should have
length divisible by model_parallel_world_size, and contains number of
sequences divisible by model_parallel_world_size.
divisible by tensor_parallel_world_size, the packed_input_ids should have
length divisible by tensor_parallel_world_size, and contain a number of
sequences divisible by tensor_parallel_world_size.
Args:
packed_input_ids (torch.Tensor): unpadded packed_input_ids
@ -293,16 +293,16 @@ def pad_sequence_parallel_generate_input(
Returns:
(torch.Tensor, torch.Tensor, int, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size, pad_seq_size)
"""
mp_world_size = constants.model_parallel_world_size()
tp_world_size = constants.tensor_parallel_world_size()
pad_size, pad_seq_size = 0, 0
if (
len(packed_input_ids) % mp_world_size != 0
or (len(cu_seqlens) - 1) % mp_world_size != 0
len(packed_input_ids) % tp_world_size != 0
or (len(cu_seqlens) - 1) % tp_world_size != 0
):
pad_size = mp_world_size - len(packed_input_ids) % mp_world_size
pad_seq_size = mp_world_size - (len(cu_seqlens) - 1) % mp_world_size
pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
pad_seq_size = tp_world_size - (len(cu_seqlens) - 1) % tp_world_size
if pad_size < pad_seq_size:
pad_size += mp_world_size
pad_size += tp_world_size
pad_cu_seqlens = torch.tensor(list(range(1, pad_seq_size)) + [pad_size]) + len(
packed_input_ids
)
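
A small sketch of the padding arithmetic above (illustrative lengths): `packed_input_ids` is padded until its length becomes a multiple of the tensor-parallel world size, since sequence parallelism splits the token dimension evenly across TP ranks.

```python
import torch

# Sketch of the pad_size arithmetic above: pad packed_input_ids so its length is
# a multiple of the tensor-parallel world size.
tp_world_size = 4
packed_input_ids = torch.randint(0, 100, (10,))

pad_size = 0
if len(packed_input_ids) % tp_world_size != 0:
    pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
padded = torch.nn.functional.pad(packed_input_ids, (0, pad_size), value=1)
assert len(padded) % tp_world_size == 0     # 10 tokens -> 12 with pad_size == 2
```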


@ -55,6 +55,7 @@ def actor_loss_fn(
eps_clip: float,
loss_mask: Optional[torch.BoolTensor] = None,
c_clip: Optional[float] = None,
proximal_logprobs: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, Dict]:
"""Compute PPO actor loss function.
@ -83,13 +84,22 @@ def actor_loss_fn(
old_logprobs = old_logprobs.clone()
if advantages.is_inference():
advantages = advantages.clone()
if proximal_logprobs is not None:
assert proximal_logprobs.dtype == torch.float32
if proximal_logprobs.is_inference():
proximal_logprobs = proximal_logprobs.clone()
denorm_logprobs = proximal_logprobs
else:
denorm_logprobs = old_logprobs
# create mask
if loss_mask is None:
loss_mask = torch.ones_like(logprobs, dtype=torch.bool)
loss_mask: torch.BoolTensor
if loss_mask is not None:
loss_mask_count = loss_mask.count_nonzero() or 1
# For numerical stability.
ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0)
else:
ratio = torch.exp(logprobs - old_logprobs)
ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)
clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss1 = -advantages * ratio
@ -104,24 +114,34 @@ def actor_loss_fn(
pg_loss = torch.min(pg_loss, pg_loss3)
else:
dual_clip_mask = torch.zeros_like(clip_mask)
if proximal_logprobs is not None:
behav_kl = proximal_logprobs - old_logprobs
behav_imp_weight = behav_kl.exp()
if c_clip is not None:
behav_mask = (behav_imp_weight <= c_clip).logical_and(loss_mask)
else:
behav_mask = loss_mask
behav_kl = torch.where(behav_mask, behav_kl, 0.0)
behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
pg_loss = pg_loss * behav_imp_weight
logging_loss = pg_loss.detach()
if loss_mask is not None:
pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
else:
pg_loss = pg_loss.mean()
if loss_mask is not None:
clip_mask.logical_and_(loss_mask)
dual_clip_mask.logical_and_(loss_mask)
# Keep CUDA tensors here for the all-reduce after the train step.
stat = dict(
loss=logging_loss,
importance_weight=ratio.detach(),
approx_kl=(logprobs - old_logprobs).detach(),
approx_kl=(logprobs - denorm_logprobs).detach(),
clip_mask=clip_mask,
dual_clip_mask=dual_clip_mask,
)
if proximal_logprobs is not None:
stat["behave_imp_weight"] = behav_imp_weight
stat["behave_approx_kl"] = behav_kl
stat["behave_mask"] = behav_mask
return pg_loss, stat
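
To make the decoupled-loss change easier to follow, here is a stripped-down sketch of the idea: the PPO ratio and approximate KL are denominated by the proximal log-probabilities, and the resulting loss is reweighted by the behavior importance weight exp(proximal - old). Tensor shapes and values are made up, and masking/dual clipping are omitted; this is an illustration, not the exact AReaL implementation.

```python
import torch

torch.manual_seed(0)
T = 8                                       # assumed number of tokens
logprobs = 0.1 * torch.randn(T)             # current policy
proximal_logprobs = 0.1 * torch.randn(T)    # recent proximal policy
old_logprobs = 0.1 * torch.randn(T)         # stale behavior policy (rollout time)
advantages = torch.randn(T)
eps_clip = 0.2

# PPO ratio computed against the proximal policy instead of the behavior policy.
ratio = torch.exp(logprobs - proximal_logprobs)
clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
pg_loss = torch.max(-advantages * ratio, -advantages * clipped)

# Decoupled correction: importance weight of the proximal vs. behavior policy.
behav_imp_weight = torch.exp(proximal_logprobs - old_logprobs)
pg_loss = (pg_loss * behav_imp_weight).mean()
```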

View File

@ -13,7 +13,7 @@ from torch.cuda import device as device_ctx_manager
from torch.utils.checkpoint import detach_variable
import realhf.base.constants as constants
from realhf.impl.model.parallelism.model_parallel.utils import (
from realhf.impl.model.parallelism.tensor_parallel.utils import (
divide,
gather_split_1d_tensor,
safely_set_viewless_tensor_data,
@ -169,10 +169,10 @@ def model_parallel_cuda_manual_seed(seed):
tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used for example for dropout in model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
model_parallel_rank = constants.model_parallel_rank()
tensor_parallel_rank = constants.tensor_parallel_rank()
expert_parallel_rank = 0
offset = seed + 2718
tensor_model_parallel_seed = offset + model_parallel_rank
tensor_model_parallel_seed = offset + tensor_parallel_rank
# Data parallel gets the original seed.
data_parallel_seed = seed
@ -187,7 +187,7 @@ def model_parallel_cuda_manual_seed(seed):
)
expert_parallel_seed = (
seed + 1024 + 100 * expert_parallel_rank + model_parallel_rank
seed + 1024 + 100 * expert_parallel_rank + tensor_parallel_rank
)
_CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
@ -331,8 +331,8 @@ def _initialize_affine_weight_cpu(
weight_list = torch.split(
master_weight, per_partition_per_stride_size, dim=partition_dim
)
rank = constants.model_parallel_rank()
world_size = constants.model_parallel_world_size()
rank = constants.tensor_parallel_rank()
world_size = constants.tensor_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():

View File

@ -5,13 +5,18 @@
import fcntl
import os
import re
import select
import subprocess
import threading
import time
from collections import defaultdict
from typing import Dict, List, Literal, Optional, Tuple
import colorama
import realhf.base.logging as logging
from realhf.base.cluster import spec as cluster_spec
from realhf.base.constants import LOG_ROOT
from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME
from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient
from realhf.scheduler.evaluator import AutomaticEvaluator
@ -29,6 +34,49 @@ SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
def monitor_log(
job_name: str, log_path: str, output_file: str, stop_event: threading.Event
):
"""Monitor a log file and write its contents to the output file with job name prefix."""
# Wait for log file to be created
while not os.path.exists(log_path) and not stop_event.is_set():
time.sleep(0.1)
if stop_event.is_set():
return
# Open the log file and follow it
with open(log_path, "r") as log_file, open(output_file, "a") as out_file:
# Store last position
position = 0
line_pos = 0
while not stop_event.is_set():
log_file.seek(position)
try:
new_lines = log_file.readlines()
except UnicodeDecodeError:
time.sleep(0.5)
continue
if new_lines:
# Update position
position = log_file.tell()
worker_type = job_name.split(":")[1]
# Write new lines to output file with job name prefix
for line in new_lines:
if line.strip(): # Skip empty lines
out_file.write(
f"{colorama.Fore.YELLOW + colorama.Style.DIM}({worker_type} Line {line_pos}){colorama.Style.RESET_ALL} {line}"
)
line_pos += 1
out_file.flush()
# Sleep briefly to avoid CPU spinning
time.sleep(0.1)
class SlurmSchedulerClient(SchedulerClient):
"""Uses Slurm (https://slurm.schedmd.com/overview.html)."""
@ -248,6 +296,26 @@ class SlurmSchedulerClient(SchedulerClient):
# before wait, commit all remaining pending jobs
# TODO: grab global file lock to avoid multi-experiment deadlocks
self.__allocate_and_commit_pending_jobs()
# Start monitoring threads
threads = []
stop_events = []
merged_log_path = os.path.join(
LOG_ROOT, self.expr_name, self.trial_name, "main.log"
)
for job_name, launch_info in self.__committed_jobs.items():
stop_event = threading.Event()
stop_events.append(stop_event)
# Thread for monitoring the log file
log_thread = threading.Thread(
target=monitor_log,
args=(job_name, launch_info.log_path, merged_log_path, stop_event),
)
threads.append(log_thread)
log_thread.start()
# begin wait
deadline = None if timeout is None else time.time() + timeout
left = set(self.__committed_jobs.keys())
@ -256,6 +324,11 @@ class SlurmSchedulerClient(SchedulerClient):
f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
f"{','.join(sorted([x.job_info.slurm_id for x in self.__committed_jobs.values()]))}."
)
logger.info(
f"All slurm logs will be merged. To check the real-time output, "
f"run\n\t`tail -f {merged_log_path}`."
)
try:
while len(left) > 0:
if len(left) < num_jobs_left:
num_jobs_left = len(left)
@ -294,6 +367,9 @@ class SlurmSchedulerClient(SchedulerClient):
if update:
self.__committed_jobs.pop(job_slurm_name)
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
finally:
[s.set() for s in stop_events]
[t.join() for t in threads]
def __update_all(self):
states = []
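
The merged-log mechanism is essentially a `tail -f` implemented per job; a simplified, standalone sketch of the follow loop (paths and the prefix are placeholders):

```python
import threading
import time


def tail_into(log_path: str, out_path: str, stop: threading.Event, prefix: str):
    """Follow `log_path` and append new, prefixed lines to `out_path` until stopped."""
    position = 0
    with open(log_path, "r") as src, open(out_path, "a") as dst:
        while not stop.is_set():
            src.seek(position)
            for line in src.readlines():
                if line.strip():
                    dst.write(f"({prefix}) {line}")
            position = src.tell()
            dst.flush()
            time.sleep(0.1)  # avoid busy-waiting
```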

View File

@ -294,21 +294,29 @@ class SlurmLaunchInfo:
@property
def multiprog_path(self) -> str:
return os.path.join(
path = os.path.join(
LOG_ROOT,
self.exper_name,
self.trial_name,
"slurm",
"multiprog",
f"{self.worker_type}-{self.worker_submission_idx}.multiprog",
)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
@property
def hostfile_path(self) -> str:
return os.path.join(
path = os.path.join(
LOG_ROOT,
self.exper_name,
self.trial_name,
"slurm",
"hostfile",
f"{self.worker_type}-{self.worker_submission_idx}.hostfile",
)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
def show_log(self):
try:

View File

@ -1,12 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
def import_profiler_registers():
import realhf.search_engine.enumerate
import realhf.search_engine.estimate
import realhf.search_engine.layers
import realhf.search_engine.param_realloc
import realhf.search_engine.search
import realhf.search_engine.utils

View File

@ -1,158 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
from typing import List
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
from realhf.api.core.dfg import build_graph as build_dfg
from realhf.api.quickstart.device_mesh import DeviceMesh, find_parallel_strategies
from realhf.api.quickstart.search import MFCDef, RPCExecution, RPCInstance
from realhf.search_engine.estimate import (
estimate_rpc_memory_cost,
estimate_rpc_time_cost,
)
MEM_INDEX = 1.0 # heuristic value to scale estimated memory
def enumerate_rpc_executions(
rpc: MFCDef,
device_mesh: DeviceMesh,
seq_len: int,
num_gen_tokens: int,
n_ppo_minibatches: int,
gradient_checkpointing: bool,
) -> List[RPCExecution]:
sub_device_meshes = device_mesh.sub_device_meshes()
import pprint
feasible = []
for sub_device_mesh in sub_device_meshes:
ps = find_parallel_strategies(sub_device_mesh)
for parallel in ps:
num_dp = parallel.data_parallel_size
num_pp = parallel.pipeline_parallel_size
num_mp = parallel.model_parallel_size
bs = rpc.n_seqs
# seq_len = seq_len
min_bs = (
2 * num_dp * num_pp * n_ppo_minibatches
if rpc.interface_type == ModelInterfaceType.TRAIN_STEP
else num_dp * num_pp
)
if min_bs > bs:
# batch size too small
continue
# heuristic to filter out inherent slow configurations
if (
num_mp * num_dp > device_mesh.n_gpus_per_node
and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
):
continue
if num_mp > 8:
continue
if num_pp > max(device_mesh.n_nodes, 8):
continue
# memory and time estimation
mem_cost, static_mem = estimate_rpc_memory_cost(
rpc,
parallel,
bs,
seq_len,
gradient_checkpointing=gradient_checkpointing,
n_ppo_minibatches=n_ppo_minibatches,
num_gen_tokens=num_gen_tokens,
offload=rpc.model_name.role in ["ref", "reward"],
)
mem_cost = int(mem_cost * MEM_INDEX)
static_mem = int(static_mem * MEM_INDEX)
time_cost = estimate_rpc_time_cost(
rpc,
parallel,
bs=bs,
seq_len=seq_len,
num_gen_tokens=num_gen_tokens,
gradient_checkpointing=gradient_checkpointing,
n_ppo_minibatches=n_ppo_minibatches,
)
time_cost = int(time_cost)
if mem_cost < device_mesh.gpu_memory_capacity:
feasible.append(
RPCExecution(
rpc,
sub_device_mesh,
parallel,
time_cost,
mem_cost,
static_mem,
)
)
return feasible
def build_graph(
rpcs: List[MFCDef],
num_epoch: int = 5,
epoch_dependency_interval: int = 1,
if_print=False,
) -> List[RPCInstance]:
"""Build model function call graph of multiple training epochs,
args:
exp: ProfileExperiment, the experiment object
num_epoch: int, number of training epochs
epoch_dependency_interval: int, the interval of epoch dependency,
e.g. if epoch_dependency_interval = 2, then the graph will have
edges between epoch i and epoch i+2, i+4, ...
returns:
rpc_instances: List[RPCInstance], the list of RPCInstance objects
"""
# one epoch dependency graph
rpcs, edges = build_dfg(rpcs)
rpc_names_mapping = {rpc.name: rpc for rpc in rpcs}
rpc_instances = []
# multi epoch graph
for epoch_id in range(num_epoch):
for rpc in rpcs:
children = []
parents = []
if rpc.is_src and epoch_id >= epoch_dependency_interval:
for other in rpcs:
if other.is_dst and other.model_name.role == rpc.model_name.role:
parents.append(
RPCInstance(
rpc,
epoch_id - epoch_dependency_interval,
[],
[],
)
)
if rpc.is_dst and rpc.model_name.role == rpc.model_name.role:
for other in rpcs:
if (
other.is_src
and epoch_id + epoch_dependency_interval < num_epoch
):
children.append(
RPCInstance(
rpc,
epoch_id + epoch_dependency_interval,
[],
[],
)
)
for parent in rpc.parents:
p = rpc_names_mapping[parent]
parents.append(RPCInstance(p, epoch_id, [], []))
for child in rpc.children:
c = rpc_names_mapping[child]
children.append(RPCInstance(c, epoch_id, [], []))
rpc_instance = RPCInstance(rpc, epoch_id, parents, children)
rpc_instances.append(rpc_instance)
if if_print:
for ri in rpc_instances:
print(ri)
return rpc_instances

View File

@ -1,499 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
# Estimate a function-call level execution time for device mesh enumeration pruning
# assume one batch of data passes through all rpcs once
import argparse
import getpass
import itertools
import os
import pickle
from collections import defaultdict
from typing import Optional
import numpy as np
import pandas as pd
import realhf.base.cluster
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.api.cli_args import ParallelismConfig
from realhf.api.core.dfg import MFCDef, ModelFamily, ModelInterfaceType
from realhf.api.core.model_api import ReaLModelConfig
from realhf.search_engine.param_realloc import estimate_param_realloc_time_cost
from realhf.search_engine.utils import load_model_config
logger = logging.getLogger("estimate", "benchmark")
PROFILE_RESULT_PATH = os.path.join(
realhf.base.cluster.spec.fileroot,
"logs",
getpass.getuser(),
"profile",
"profile",
"layer_stats",
)
def get_param_realloc_stats(
model_family: ModelFamily,
model_path: str,
n_nodes: int,
use_cache: bool = True,
):
non_critic = ModelFamily(model_family._class, model_family.size, False)
table_path = os.path.join(
constants.PROFILER_CACHE_PATH,
"param_realloc",
f"prtc_{non_critic}_n{n_nodes}.pkl",
)
if not os.path.exists(table_path):
print(
f"Calculating estimation of param realloc time cost for {model_family} at {model_path}"
)
estimate_param_realloc_time_cost(n_nodes, {non_critic: model_path})
print(f"Loading param realloc stats from {table_path}")
return pickle.load(open(table_path, "rb"))
def get_organized_op_stats(
model_family: ModelFamily, model_path: str, use_cache: bool = True
):
non_critic = ModelFamily(model_family._class, model_family.size, False)
# parse raw stats into list of OpInfo used for estimation
cache_path = os.path.join(
constants.PROFILER_CACHE_PATH, "organized_stats", f"{non_critic}.pkl"
)
if use_cache and os.path.exists(cache_path):
with open(cache_path, "rb") as f:
return pickle.load(f)
raw_result_path = os.path.join(
constants.PROFILER_CACHE_PATH, "layer_stats", str(non_critic)
)
if not os.path.exists(raw_result_path):
from realhf.apps.main import _main_profile_layers
_main_profile_layers(non_critic, model_path)
raw_stats_list = []
for fn in os.listdir(raw_result_path):
if not fn.startswith("layer-stats"):
continue
num_mp = int(fn.replace(".pkl", "").split("_")[1])
rank = int(fn.replace(".pkl", "").split("_")[2])
with open(os.path.join(raw_result_path, fn), "rb") as f:
stats = pickle.load(f)
if isinstance(stats, dict):
stats = pd.DataFrame(stats)
elif isinstance(stats, pd.DataFrame):
pass
else:
raise ValueError(f"Unsupported stats type {type(stats)}")
stats["num_mp"] = num_mp
stats["rank"] = rank
raw_stats_list.append(stats)
raw_stats = pd.concat(raw_stats_list)
bs_list = raw_stats["bs"].unique()
seq_len_list = raw_stats["seq_len"].unique()
op_name_list = raw_stats["op_name"].unique()
layer_name_list = raw_stats["layer_name"].unique()
num_mp_list = raw_stats["num_mp"].unique()
organized_stats = defaultdict(list)
for op_name, bs, seq_len, layer_name, num_mp in itertools.product(
op_name_list, bs_list, seq_len_list, layer_name_list, num_mp_list
):
filter_cond = (
(raw_stats["op_name"] == op_name)
& (raw_stats["bs"] == bs)
& (raw_stats["seq_len"] == seq_len)
& (raw_stats["layer_name"] == layer_name)
& (raw_stats["num_mp"] == num_mp)
)
avg_time_ns = raw_stats[filter_cond]["time_ns"].mean()
x = int(bs) if op_name == "fwd_gen_1" else int(bs * seq_len)
organized_stats["op_name"].append(op_name)
organized_stats["layer_name"].append(layer_name)
organized_stats["bs"].append(bs)
organized_stats["seq_len"].append(seq_len)
organized_stats["num_mp"].append(num_mp)
organized_stats["avg_time_ns"].append(avg_time_ns)
organized_stats["x"].append(x)
df = pd.DataFrame(organized_stats)
if use_cache:
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
with open(cache_path, "wb") as f:
pickle.dump(df, f)
return df
def computation_instruction_time_cost(
op_stats: pd.DataFrame,
op_name: str,
num_layers: int,
parallel_strategy: ParallelismConfig,
bs: int,
seqlen: int,
):
# inst cost unit: ns
layer_names = ["embedding_layer", "block_0", "head"]
num_pp = parallel_strategy.pipeline_parallel_size
num_mp = parallel_strategy.model_parallel_size
op_stats = op_stats[
(op_stats["op_name"] == op_name) & (op_stats["num_mp"] == num_mp)
]
op_cost = {}
embed_stats = op_stats[op_stats["layer_name"] == "embedding_layer"]
if embed_stats[
(embed_stats["bs"] == bs) & (embed_stats["seq_len"] == seqlen)
].empty:
# do linear interpolation for data points that do not exist
for layer_name in layer_names:
layer_stats = op_stats[op_stats["layer_name"] == layer_name]
assert layer_stats[
(layer_stats["bs"] == bs) & (layer_stats["seq_len"] == seqlen)
].empty
assert not layer_stats.empty, (
layer_name,
op_name,
num_mp,
op_stats,
)
xs = layer_stats["x"]
ys = layer_stats["avg_time_ns"]
x = int(bs) if op_name == "fwd_gen_1" else int(bs * seqlen)
y = np.interp(x, xs, ys)
if max(xs) < x or min(xs) > x:
logger.warning(
f"Interpolated value outside profiling range, "
f"parallel strategy {parallel_strategy}: "
f"{x} in {sorted(list(set(xs)))}"
)
# estimate using largest or smallest value
if max(xs) < x:
y = ys.max() * (x / xs[ys.idxmax()])
else:
y = ys.min() * (x / xs[ys.idxmin()])
op_cost[layer_name] = y
else:
for layer_name in layer_names:
assert not op_stats[op_stats["layer_name"] == layer_name].empty
required_stats = op_stats[
(op_stats["layer_name"] == layer_name)
& (op_stats["bs"] == bs)
& (op_stats["seq_len"] == seqlen)
]
assert required_stats.shape[0] == 1
op_cost[layer_name] = required_stats["avg_time_ns"].values[0]
embedding_layer_cost = op_cost["embedding_layer"]
block_0_cost = op_cost["block_0"]
head_cost = op_cost["head"]
cost = (embedding_layer_cost + num_layers * block_0_cost + head_cost) / num_pp
return cost
def communication_instruction_time_cost(comm_stats, size, comm_type):
return size / comm_stats[comm_type] # unit: ns
def estimate_instruction_time_costs(
model_family: ModelFamily,
model_path: str,
num_layers: int, # model configuration, num transformers layer
parallel_strategy: ParallelismConfig,
hidden_dim: int,
batch_size: int,
seq_len: int,
n_ppo_minibatches: int = 1,
):
comm_stats = default_communication_stats()
op_stats = get_organized_op_stats(model_family, model_path, use_cache=True)
num_mp = parallel_strategy.model_parallel_size
num_pp = parallel_strategy.pipeline_parallel_size
num_dp = parallel_strategy.data_parallel_size
num_gpus = num_dp * num_mp * num_pp
train_mbs = (
batch_size / (2 * num_pp * num_dp * n_ppo_minibatches)
if num_pp > 1
else batch_size / (num_dp * n_ppo_minibatches)
)
gen_mbs = batch_size / (num_pp * num_dp)
# pprint.pprint(op_cost, indent=4)
inst_keys = [
"gen_fwd_0",
"gen_fwd_1",
"inf_fwd",
"train_fwd",
"train_bwd",
"train_opt",
]
op_names = ["fwd_gen_0", "fwd_gen_1", "fwd", "fwd", "bwd", "opt"]
inst_stats = {}
for inst_key, op_name in zip(inst_keys, op_names):
mbs = train_mbs if "train" in inst_key else gen_mbs
inst_stats[inst_key] = computation_instruction_time_cost(
op_stats, op_name, num_layers, parallel_strategy, mbs, seq_len
)
comm_type = "remote_send" if num_gpus // num_pp >= 8 else "local_send"
inst_stats["act_p2p"] = communication_instruction_time_cost(
comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type
)
inst_stats["grad_p2p"] = communication_instruction_time_cost(
comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type
)
inst_stats["gen_act_p2p"] = communication_instruction_time_cost(
comm_stats, 2 * hidden_dim * gen_mbs, comm_type
)
return inst_stats
def _estimate_rpc_time_cost(
inst_stats,
parallel_strategy: ParallelismConfig,
model_interface_type: ModelInterfaceType,
# model function call args
num_gen_tokens: int,
gradient_checkpointing: bool,
n_ppo_minibatches: int = 1,
):
# TODO: improve/remove heuristic
num_pp = parallel_strategy.pipeline_parallel_size
num_dp = parallel_strategy.data_parallel_size
num_mp = parallel_strategy.model_parallel_size
if model_interface_type == ModelInterfaceType.INFERENCE:
num_micro_batches = num_pp
compute_cost = inst_stats["inf_fwd"] * (num_pp + num_micro_batches - 1)
comm_cost = inst_stats["act_p2p"] * (num_pp + num_micro_batches - 2) * 2
if num_mp > 1: # and num_pp * num_dp > 1:
compute_cost = compute_cost * (1 - num_mp * 0.03)
elif model_interface_type == ModelInterfaceType.TRAIN_STEP:
# TODO: add reduce grads, add ppo micro batches
num_micro_batches = num_pp * 2 if num_pp > 1 else 1
compute_cost = (inst_stats["train_fwd"] + inst_stats["train_bwd"]) * (
num_pp + num_micro_batches - 1
) + inst_stats["train_opt"]
if gradient_checkpointing:
compute_cost += inst_stats["train_fwd"] * (num_pp + num_micro_batches - 1)
comm_cost = (
(inst_stats["grad_p2p"] + inst_stats["act_p2p"])
* (num_pp + num_micro_batches - 2)
* 2
)
compute_cost = compute_cost * n_ppo_minibatches
comm_cost = comm_cost * n_ppo_minibatches
if num_pp * num_dp <= 1:
compute_cost = compute_cost * (1 - num_mp * 0.04)
if num_mp > 1: # and num_pp * num_dp > 1:
compute_cost = compute_cost * (1 - num_mp * 0.03)
elif model_interface_type == ModelInterfaceType.GENERATE:
num_micro_batches = num_pp
num_gen_tokens = num_gen_tokens
compute_cost = (
inst_stats["gen_fwd_0"] * (num_pp + num_micro_batches - 1)
+ inst_stats["gen_fwd_1"] * (num_gen_tokens - 1) * num_micro_batches
)
if num_dp * num_mp > 1:
compute_cost = compute_cost * (1 - min(num_dp * num_mp, 8) * 0.03)
comm_cost = 0
# dirty heuristic
if num_pp > 8:
compute_cost = compute_cost * (1 + num_pp * 0.01)
# FIXME: disable comm cost because it is not accurate
comm_cost = 0
return compute_cost + comm_cost
def estimate_rpc_time_cost(
rpc: MFCDef,
parallel_strategy: ParallelismConfig,
bs: int,
seq_len: int,
num_gen_tokens: int = 256,
gradient_checkpointing: bool = False,
n_ppo_minibatches: int = 1,
):
# time unit: milliseconds
# FIXME: n_ppo_minibatches > 1 will result in bad estimation
# when batch size is large enough, n_ppo_minibatches > 1 will not affect the estimation
n_ppo_minibatches = 1
model_type = rpc.model_type
model_path = rpc.model_path
model_config = load_model_config(rpc.model_type._class, rpc.model_path)
inst_cost = estimate_instruction_time_costs(
model_type,
model_path,
model_config.n_layers,
parallel_strategy,
model_config.hidden_dim,
bs,
seq_len,
n_ppo_minibatches=n_ppo_minibatches,
)
return (
_estimate_rpc_time_cost(
inst_cost,
parallel_strategy,
rpc.interface_type,
num_gen_tokens=num_gen_tokens,
gradient_checkpointing=gradient_checkpointing,
n_ppo_minibatches=n_ppo_minibatches,
)
) / 1e6
def default_communication_stats(if_print=False):
# use default comm stats of cluster
r = dict(
# between GPUs on the same node
local_send=170, # unit: GB/s
local_recv=170,
# IB between GPUs on different nodes
remote_send=20, # unit: GB/s
remote_recv=20,
)
if if_print:
print(r)
return r
def estimate_model_size(model_config: ReaLModelConfig):
h = model_config.hidden_dim
i = model_config.intermediate_dim
v = model_config.vocab_size
L = model_config.n_layers
# for llama actor only
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
return 2 * n_params
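
As a quick sanity check of the formula above, plugging in approximate LLaMA-7B-like dimensions (assumed values, not taken from this repository) gives the expected order of magnitude:

```python
h, i, v, L = 4096, 11008, 32000, 32  # assumed hidden, intermediate, vocab, layers
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
size_bytes = 2 * n_params             # fp16: 2 bytes per parameter
print(n_params / 1e9)                 # ~6.87e9 parameters
print(size_bytes / 1024**3)           # ~12.8 GiB
```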
def estimate_rpc_memory_cost(
rpc: MFCDef,
parallel_strategy: ParallelismConfig,
batch_size: int,
seq_len: int,
offload: bool = False,
gradient_checkpointing: bool = False,
offload_optimizer: bool = False,
n_ppo_minibatches: int = 1,
num_gen_tokens: int = 128,
):
# TODO: improve heuristic
interface_type = rpc.interface_type
model_config = load_model_config(rpc.model_type._class, rpc.model_path)
h = model_config.hidden_dim
i = model_config.intermediate_dim
v = model_config.vocab_size
s = seq_len
gs = num_gen_tokens
b = batch_size
L = model_config.n_layers
# for llama actor only
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
param_mem = 2 * n_params
grad_mem = 2 * n_params
optimizer_mem = 20 * n_params if not offload_optimizer else 0
num_pp = parallel_strategy.pipeline_parallel_size
num_mp = parallel_strategy.model_parallel_size
num_dp = parallel_strategy.data_parallel_size
# zero1, pp and mp divide evenly
# enable sequence parallel
if interface_type == ModelInterfaceType.TRAIN_STEP:
# gradient checkpointing is always enabled for flash attn
static_mem = (param_mem + grad_mem) // (num_pp * num_mp) + optimizer_mem // (
num_pp * num_dp * num_mp
)
micro_bs = b // (2 * num_pp * num_dp) if num_pp > 0 else b // (num_dp)
if gradient_checkpointing:
active_mem = (micro_bs * s * h * num_pp * 2) // (num_pp * num_mp)
else:
# FIXME: calculate other memory entries
active_mem = (micro_bs * s * h * num_pp * 2) * 2 * L // (num_pp * num_mp)
return static_mem + active_mem, static_mem
elif interface_type == ModelInterfaceType.INFERENCE:
static_mem = int(2 * param_mem // (num_pp * num_mp))
# if num_dp > 4:
# static_mem = static_mem * 1.25
if offload:
return static_mem, 0 # assume offload
else:
return static_mem, static_mem
elif interface_type == ModelInterfaceType.GENERATE:
static_mem = int(2 * param_mem // (num_pp * num_mp))
if num_dp > 4 and num_dp * num_mp * num_pp <= 16:
static_mem = static_mem * 1.25
if num_mp == 0 and num_pp == 0:
static_mem = static_mem * 1.25
active_mem = (
2 * (2 * b * (gs + s) * h) * L // (num_pp * num_mp * num_dp)
) # kv cache
return static_mem + active_mem, static_mem
def example(rpcs):
if args.model_size == 7:
n_nodes = 1
elif args.model_size == 13:
n_nodes = 2
elif args.model_size == 34:
n_nodes = 4
elif args.model_size == 70:
n_nodes = 8
expr_name = f"profile-s{args.model_size}p{n_nodes}m1d8"
rollout, inf, train = rpcs
bs = 128
seq_len = 1024
p1 = ParallelismConfig(
pipeline_parallel_size=1, model_parallel_size=4, data_parallel_size=8
)
rpc_cost = estimate_rpc_time_cost(
train,
p1,
gradient_checkpointing=True,
num_gen_tokens=896,
bs=bs,
seq_len=seq_len,
n_ppo_minibatches=4,
)
mem_cost, static_mem = estimate_rpc_memory_cost(rollout, p1, bs, seq_len)
print(f"{p1} rpc cost {rpc_cost:.2f} seconds mem cost {mem_cost/(1024**3):.2f} GB")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a profiling experiment.")
parser.add_argument(
"-s",
"--model_size",
type=int,
default=7,
)
args = parser.parse_args()
example(args)

View File

@ -1,310 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import os
import pickle
import time
from collections import defaultdict
from typing import List, Optional, Union
import pandas as pd
import torch
import torch.distributed as dist
import transformers
import realhf.api.core.model_api as model_api
import realhf.api.core.system_api as config_package
import realhf.base.constants as constants
import realhf.base.logging as logging
from realhf.api.core.model_api import ReaLModelConfig
from realhf.impl.model.utils.padding import unpad_input
logger = logging.getLogger("profile layers", "system")
def make_layers(config: ReaLModelConfig, dtype, device):
from realhf.impl.model.nn.real_llm_base import (
OutputHead,
ReaLModelBlock,
VocabPositionEmbedding,
)
embedding_layer = VocabPositionEmbedding(
config,
dtype=dtype,
device=device,
)
real_model_blocks = [
ReaLModelBlock(
config,
layer_index=i,
output_layernorm=(i == 1),
dtype=dtype,
device=device,
)
for i in range(1)
]
head = OutputHead(
config.hidden_dim,
1 if config.is_critic else config.vocab_size,
bias=False,
device=device,
dtype=dtype,
)
layer_names = ["embedding_layer", "block_0", "head"]
return [embedding_layer] + real_model_blocks + [head], layer_names
class ProfileLayers:
def __init__(
self,
model_name: str,
config: ReaLModelConfig,
tokenizer: transformers.PreTrainedTokenizerFast = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
):
self.model_name = model_name
self.config = config
self.backend_config = config_package.ModelBackend(
type_="deepspeed",
args=dict(
optimizer_name="adam",
optimizer_config=dict(lr=1e-5, weight_decay=0.0, betas=(0.9, 0.95)),
warmup_steps_proportion=0.0,
min_lr_ratio=0.0,
zero_stage=1,
bf16=False,
),
)
self.dtype = dtype
self.device = device
self.layers, self.layer_names = make_layers(config, dtype, device)
self.hidden_dim = config.hidden_dim
self.head_dim = config.head_dim
self.max_new_tokens = 128 # only useful in kv cache memory alloc
self.min_new_tokens = 128
self.stats = defaultdict(list)
self.num_layers = len(self.layers)
self.layers = [
model_api.Model(name, layer, tokenizer, device=device, dtype=dtype)
for layer, name in zip(self.layers, self.layer_names)
]
self.backend = model_api.make_backend(self.backend_config)
ft_spec = model_api.FinetuneSpec(10, 100, 10)
self.layers = [self.backend.initialize(layer, ft_spec) for layer in self.layers]
self.stats = defaultdict(list)
def reset_stats(self):
self.stats = defaultdict(list)
def insert_data_point(self, layer_name, name, bs, seq_len, time_ns):
self.stats["layer_name"].append(layer_name)
self.stats["op_name"].append(name)
self.stats["bs"].append(bs)
self.stats["seq_len"].append(seq_len)
self.stats["time_ns"].append(time_ns)
@torch.no_grad()
def fwd_gen(self, bs, seq_len):
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
input_ids = torch.randint(
0,
self.config.vocab_size,
(bs, seq_len),
dtype=torch.long,
device=self.device,
)
attention_mask = torch.ones_like(input_ids, device=self.device)
# fwd_gen_0
packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
input_ids, attention_mask
)
cu_seqlens = cu_seqlens.to(device=self.device)
packed_input_ids = packed_input_ids.to(device=self.device)
x = PipeTransferData(
cu_seqlens=cu_seqlens,
max_seqlen=int(max_seqlen),
store_kv_cache=True,
)
ys = [PipeCacheData() for _ in range(self.num_layers)]
ys[0].packed_input_ids = packed_input_ids
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
st = time.monotonic_ns()
x: PipeTransferData = layer.module(x, y)
x.pp_input = x.pp_output
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "fwd_gen_0", bs, seq_len, time.monotonic_ns() - st
)
prompt_logits = x.pp_output
logits = prompt_logits[cu_seqlens[1:] - 1]
input_lens = cu_seqlens[1:] - cu_seqlens[:-1]
cache_seqlens = input_lens.clone().to(dtype=torch.int32)
layer_indices = range(len(ys))
for y, layer_idx in zip(ys[1:-1], layer_indices[1:-1]):
assert (
y.k_cache is not None
and y.v_cache is not None
and y.cache_seqlens is not None
)
kvcache_seqlen = max(
max_seqlen + self.max_new_tokens,
self.hidden_dim // self.head_dim + 10,
)
# fix of a flash attention bug
k_cache = torch.zeros(
(bs, kvcache_seqlen, *y.k_cache.shape[1:]),
dtype=y.k_cache.dtype,
device=self.device,
)
v_cache = torch.zeros_like(k_cache)
indices = (
torch.arange(
kvcache_seqlen,
device=constants.current_device(),
dtype=torch.long,
)[None, :]
< input_lens[:, None]
)
k_cache[indices] = y.k_cache
v_cache[indices] = y.v_cache
y.k_cache = k_cache
y.v_cache = v_cache
y.cache_seqlens = cache_seqlens
x = PipeTransferData(store_kv_cache=True)
ys[0].cache_seqlens = cache_seqlens
# fwd_gen_1
new_tokens = torch.randint(
0,
self.config.vocab_size,
(bs,),
dtype=torch.long,
device=self.device,
)
ys[0].packed_input_ids = new_tokens
ys[0].packed_position_ids = None
x.cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=self.device)
x.max_seqlen = 1
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
st = time.monotonic_ns()
x = layer.module(x, y)
x.pp_input = x.pp_output
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "fwd_gen_1", bs, seq_len, time.monotonic_ns() - st
)
def fwd_bwd_opt(self, bs, seq_len):
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
input_ids = torch.randint(
0,
self.config.vocab_size,
(bs, seq_len),
dtype=torch.long,
device=self.device,
)
attention_mask = torch.ones_like(input_ids, device=self.device)
packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
input_ids, attention_mask
)
cu_seqlens = cu_seqlens.to(device=self.device)
packed_input_ids = packed_input_ids.to(device=self.device)
x = PipeTransferData(
cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, store_kv_cache=False
)
ys = [PipeCacheData() for _ in range(self.num_layers)]
ys[0].packed_input_ids = packed_input_ids
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
# fwd
st = time.monotonic_ns()
x: PipeTransferData = layer.module(x, y)
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "fwd", bs, seq_len, time.monotonic_ns() - st
)
# bwd
r = torch.rand(
*x.pp_output.shape,
device=x.pp_output.device,
dtype=x.pp_output.dtype,
)
loss = torch.max(x.pp_output * r)
st = time.monotonic_ns()
layer.module.backward(loss)
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "bwd", bs, seq_len, time.monotonic_ns() - st
)
# opt
st = time.monotonic_ns()
layer.module.step()
torch.cuda.synchronize()
self.insert_data_point(
layer_name, "opt", bs, seq_len, time.monotonic_ns() - st
)
x.pp_input = x.pp_output.clone().detach()
def make_dataframe_and_print(self):
df = pd.DataFrame(self.stats)
logger.info(f"Current Stats: \nstr{df}")
def dump_stats(self, world_size):
rank = dist.get_rank()
# dump full stats
dump_dir = os.path.join(
constants.PROFILER_CACHE_PATH,
"layer_stats",
)
dump_path = os.path.join(
dump_dir, self.model_name, f"layer-stats_{world_size}_{rank}.pkl"
)
os.makedirs(os.path.dirname(dump_path), exist_ok=True)
with open(dump_path, "wb") as f:
df = pd.DataFrame(self.stats)
pickle.dump(df, f)
def make_profile_layers(
device: torch.device,
model_path: str,
model_name: str,
dtype: Optional[str] = None,
hf_model_type: str = "llama",
):
from realhf.impl.model.nn.real_llm_api import ReaLModel
if dtype == "fp16" or dtype == None:
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp32":
dtype == torch.float32
else:
raise NotImplementedError(f"Unsupported dtype {dtype}")
tokenizer = None
config: ReaLModelConfig = getattr(ReaLModel, f"config_from_{hf_model_type}")(
model_path=model_path,
)
if tokenizer is None:
tokenizer = model_api.load_hf_tokenizer(model_path)
profile_layers = ProfileLayers(
model_name, config, tokenizer=tokenizer, dtype=dtype, device=device
)
return profile_layers

View File

@ -1,272 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import dataclasses
import itertools
import json
import time
from collections import defaultdict
from typing import *
import torch
import torch.distributed
import realhf.api.core.system_api as system_api
import realhf.base.constants as constants
import realhf.base.topology as topology
from realhf.api.core.config import ModelFamily, ModelName
from realhf.api.core.model_api import ReaLModelConfig
from realhf.base.topology import decompose_to_three_factors
def bcast_cost(
param_size: float, bw: float, src: int, dsts: List[int], n_nodes_per_gpu=8
):
src_node = src // n_nodes_per_gpu
dst_nodes = [dst // n_nodes_per_gpu for dst in dsts]
if src_node == dst_nodes[0] and all(
dst_node == dst_nodes[0] for dst_node in dst_nodes
):
return param_size * 2 * 8 / (1800 * 1024**3)
# return 0.0
else:
# param size is in float16, bw is in Gbps
return param_size * 2 * 8 / (bw * 1024**3) * len(set(dst_nodes))
def compute_cost(
world_size: int,
from_model_name: ModelName,
to_model_name: ModelName,
from_topo: topology.ProcessTopology,
to_topo: topology.ProcessTopology,
model_config: ReaLModelConfig,
bw: float, # Gbps
set_interval_cost: float,
) -> int:
from realhf.impl.model.comm.param_realloc import (
ParamReallocInfo,
ReparallelizeReceiverStep,
ReparallelizeSenderStep,
_create_param_realloc_groups,
_derive_reparallelize_comm_plan,
)
param_sync_groups = {}
param_sync_src_ranks = {}
param_sync_dst_ranks = {}
msid2mwid = {}
for i in range(from_topo.world_size()):
msid2mwid[
system_api.ModelShardID.from_parallelism_rank(from_model_name, from_topo, i)
] = i
for i in range(to_topo.world_size()):
msid2mwid[
system_api.ModelShardID.from_parallelism_rank(to_model_name, to_topo, i)
] = (i + world_size - to_topo.world_size())
_create_param_realloc_groups(
from_topo,
to_topo,
from_model_name,
to_model_name,
msid2mwid,
param_sync_groups,
param_sync_src_ranks,
param_sync_dst_ranks,
)
pg_info = ParamReallocInfo(
param_sync_groups,
param_sync_src_ranks,
param_sync_dst_ranks,
)
comm_plan = _derive_reparallelize_comm_plan(
from_model_name,
to_model_name,
from_topo,
to_topo,
model_config,
model_config,
pg_info,
)
# Run broadcast!
max_cost = max_comm_volume = max_bcast_cnt = 0
for _rank in range(world_size):
cost = comm_volume = bcast_cnt = 0
for step in comm_plan:
if isinstance(step, ReparallelizeReceiverStep) and step.rank == _rank:
if step.rank != step.src:
cost += bcast_cost(step.param_size, bw, step.src, step.dst_ranks)
comm_volume += step.param_size
bcast_cnt += 1
cost += set_interval_cost
if isinstance(step, ReparallelizeSenderStep) and step.rank == _rank:
if step.group is not None:
cost += bcast_cost(step.param_size, bw, step.rank, step.dst_ranks)
bcast_cnt += 1
max_cost = max(max_cost, cost)
max_comm_volume = max(max_comm_volume, comm_volume)
max_bcast_cnt = max(max_bcast_cnt, bcast_cnt)
return max_cost
def dump_table(
n_nodes: int,
model_family: ModelFamily,
model_path: str,
rank: int = 0,
parallel: int = 1,
):
from_model_name = ModelName("actor", 0)
to_model_name = ModelName("actor", 1)
def hash_tuple_into_str(t) -> str:
return ",".join([str(i) for i in t])
import tqdm
res = {}
device_mesh_sizes = [4] + [8 * i for i in range(1, n_nodes + 1)]
space = list(itertools.product(device_mesh_sizes, device_mesh_sizes))
sub_space = space[rank::parallel]
# for a, b in set(small_spaces + large_spaces):
for a, b in sub_space:
mtik = time.perf_counter()
all_configs = list(
itertools.product(
decompose_to_three_factors(a), decompose_to_three_factors(b)
)
)
all_configs = list(filter(lambda x: x[0][1] <= 8 and x[1][1] <= 8, all_configs))
all_configs = list(filter(lambda x: x[0][2] <= 8 and x[1][2] <= 8, all_configs))
all_configs = list(
filter(
lambda x: x[0][1] in [1, 2, 4, 8] and x[1][1] in [1, 2, 4, 8],
all_configs,
)
)
all_configs = list(
filter(lambda x: x[0][0] <= 16 and x[1][0] <= 16, all_configs)
)
all_configs = list(
filter(
lambda x: x[0][1] % x[1][1] == 0 or x[1][1] % x[0][1] == 0,
all_configs,
)
)
for config_id, (from_pp_mp_dp, to_pp_mp_dp) in tqdm.tqdm(
enumerate(all_configs)
):
world_size = max(a, b)
from_topo = topology.PipeDataModelParallelTopology(
*from_pp_mp_dp, False, False
)
to_topo = topology.PipeDataModelParallelTopology(*to_pp_mp_dp, False, False)
assert world_size >= from_topo.world_size()
assert world_size >= to_topo.world_size()
from realhf.search_engine.utils import load_model_config
mconfig = load_model_config(model_family._class, model_path)
cost = compute_cost(
world_size,
from_model_name,
to_model_name,
from_topo,
to_topo,
mconfig,
bw=200.0,
set_interval_cost=0.03,
)
res[
hash_tuple_into_str((model_family.size, *from_pp_mp_dp, *to_pp_mp_dp))
] = int(cost * 1000 * 1000)
print(
f"Time for model size {model_family.size} {a} -> {b} {rank}/{parallel}: "
f"{time.perf_counter() - mtik:.4f}, num res entries {len(res)}"
)
print(f"Rank {rank} of model {model_family} finished, res size {len(res)}.")
import os
import pickle
dump_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc")
fn = f"prtc_{model_family}_n{n_nodes}_{rank}_{parallel}.pkl"
if not os.path.exists(dump_path):
os.makedirs(dump_path, exist_ok=True)
with open(os.path.join(dump_path, fn), "wb") as f:
pickle.dump(res, f)
print(f"dumped table with {len(res)} entries to {model_family}-{rank}-{parallel}.")
def dump_table_parallel(
n_nodes: int,
model_family_to_path: Dict[ModelFamily, str],
parallel: int = 4,
):
import torch.multiprocessing as mp
mp.set_start_method("spawn", force=True)
rq = mp.Queue()
ps = []
for model_family, model_path in model_family_to_path.items():
for rank in range(parallel):
ps.append(
mp.Process(
target=dump_table,
args=(n_nodes, model_family, model_path, rank, parallel),
)
)
for p in ps:
p.start()
for p in ps:
p.join()
def merge_tables(
n_nodes: int,
model_family_to_path: Dict[ModelFamily, str],
parallel: int = 4,
):
import os
import pickle
res_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc")
for model_family in model_family_to_path.keys():
prefix = f"prtc_{model_family}_n{n_nodes}"
r = {}
counter = 0
for fn in os.listdir(res_path):
if fn.endswith(".pkl") and fn.startswith(prefix):
counter += 1
path = os.path.join(res_path, fn)
with open(path, "rb") as f:
r.update(pickle.load(f))
os.remove(path)
if counter < parallel:
raise RuntimeError(
"missing sub-tables, probably some sub-processes failed "
"during param realloc time cost estimation."
)
with open(os.path.join(res_path, f"{prefix}.pkl"), "wb") as f:
pickle.dump(r, f)
print(f"merged parallel tables into {prefix}.pkl, total entries {len(r)}")
def estimate_param_realloc_time_cost(
n_nodes: int,
model_family_to_path: Dict[ModelFamily, str],
parallel: int = 4,
):
dump_table_parallel(n_nodes, model_family_to_path, parallel)
merge_tables(n_nodes, model_family_to_path, parallel)

View File

@ -1,232 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
import argparse
import functools
import json
import os
import pickle
import pprint
import re
from typing import Any, Dict, List, Literal, Optional
import numpy as np
try:
import realhf._C.mdm_search as mdm_search
except ModuleNotFoundError:
mdm_search = None
import realhf.base.constants as constants
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
from realhf.api.core.config import ModelInterfaceType
from realhf.api.core.dfg import MFCDef
from realhf.api.quickstart.device_mesh import DeviceMesh, RPCAllocation
from realhf.api.quickstart.search import RPCExecution
def search_rpc_allocations(
device_mesh: DeviceMesh,
rpcs: List[MFCDef],
num_gen_tokens: int = 256,
n_ppo_minibatches: int = 1,
seq_len: int = 256,
gradient_checkpointing: bool = True,
use_cache: bool = False,
) -> List[RPCAllocation]:
from realhf.search_engine.enumerate import build_graph
from realhf.search_engine.estimate import get_param_realloc_stats
from_file = os.environ.get("REAL_IS_REMOTE", "0") == "1"
dump_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"device_mapping.pkl",
)
log_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"device_mapping",
)
rs_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"raw_search_result",
)
rpc_exe_dir = os.path.join(
constants.LOG_ROOT,
constants.experiment_name(),
constants.trial_name(),
"rpc_exe_info",
)
if from_file or (use_cache and os.path.exists(dump_dir)):
with open(dump_dir, "r") as f:
s = json.load(f)
rpc_allocs = [RPCAllocation.from_dict(d) for d in s]
return rpc_allocs
else:
os.makedirs(os.path.dirname(dump_dir), exist_ok=True)
n_nodes = device_mesh.n_nodes
table = {}
for rpc in rpcs:
print(f"Getting param realloc stats for {rpc.model_type} at {rpc.model_path}")
t = get_param_realloc_stats(rpc.model_type, rpc.model_path, n_nodes, True)
table.update(t)
rpc_exe_list = make_rpc_exe_list(
rpcs,
device_mesh,
num_gen_tokens=num_gen_tokens,
n_ppo_minibatches=n_ppo_minibatches,
seq_len=seq_len,
gradient_checkpointing=gradient_checkpointing,
log_dir=rpc_exe_dir,
if_print=False,
)
graph = build_graph(rpcs, 5, 1, if_print=False)
model_size_dict = make_model_size_dict(rpcs, if_print=False)
n_nodes = device_mesh.n_nodes
search_time = 120
rs: List[Dict[str, List]] = mdm_search.multi_mcmc_search(
rpcs,
rpc_exe_list,
graph,
table,
model_size_dict,
0.001, # beta min
0.002, # beta max
0.001, # beta step
search_time, # time limit for each search
1, # repeat
)
if not from_file:
with open(rs_dir, "w") as f:
pprint.pprint(rs, stream=f)
r: Dict[str, Dict[str, Any]] = rs[-1]
pprint.pprint(r)
rpc_name_to_rpcs = {rpc.name: rpc for rpc in rpcs}
rpc_allocs = []
for rpc_name, alloc_info in r.items():
if rpc_name in ["end_time", "mem_cost"]:
continue
# rpc = rpc_dict[rpc_name]
rpc = rpc_name_to_rpcs[rpc_name]
parallel = ParallelismConfig(
pipeline_parallel_size=alloc_info["num_pp"],
data_parallel_size=alloc_info["num_dp"],
model_parallel_size=alloc_info["num_mp"],
use_sequence_parallel=(
alloc_info["num_mp"] > 1
and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
),
)
sub_device_mesh = DeviceMesh(
n_nodes=device_mesh.n_nodes,
n_gpus_per_node=device_mesh.n_gpus_per_node,
mapping=alloc_info["device_mesh_mapping"],
name=alloc_info["device_mesh_name"],
global_mesh_name=device_mesh.global_mesh_name,
)
rpc_alloc = RPCAllocation(
rpc=rpc,
device_mesh=sub_device_mesh,
parallel=parallel,
)
rpc_allocs.append(rpc_alloc)
if not from_file:
with open(dump_dir, "w") as f:
json.dump([rpc_alloc.to_dict() for rpc_alloc in rpc_allocs], f, indent=4)
with open(log_dir, "w") as f:
pprint.pprint(rpc_allocs, stream=f)
return rpc_allocs
def make_rpc_exe_list(
rpcs: List[MFCDef],
device_mesh: DeviceMesh,
num_gen_tokens: int,
n_ppo_minibatches: int,
seq_len: int,
gradient_checkpointing: bool,
if_print: bool = False,
log_dir: Optional[str] = None,
) -> List[RPCExecution]:
from realhf.search_engine.enumerate import enumerate_rpc_executions
rpc_exe_list = []
log_flag = False
for rpc in rpcs:
# real_model_config = load_model_config(rpc)
feasible = enumerate_rpc_executions(
rpc,
device_mesh,
seq_len=seq_len,
num_gen_tokens=num_gen_tokens,
n_ppo_minibatches=n_ppo_minibatches,
gradient_checkpointing=gradient_checkpointing,
)
rpc_exe_list.extend(feasible)
if log_dir is not None:
mode = "w" if not log_flag else "a"
with open(log_dir, mode) as f:
f.write(f"{rpc.name} feasible: {len(feasible)}\n")
feasible.sort(key=lambda x: x.time_cost)
# feasible = feasible[:30]
for i, rpc_exe in enumerate(feasible):
f.write(
f"{i}: time_cost: {rpc_exe.time_cost} ms, {rpc_exe.time_cost} "
f"sub_device_mesh: {rpc_exe.device_mesh}, "
f"parallel_strategy: {rpc_exe.parallel_strategy}, "
f"mem_cost: {rpc_exe.mem/(1024*1024*1024):02f} GB, "
f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):02f} GB\n"
)
f.write("\n")
log_flag = True
if if_print:
print(f"{rpc.name} feasible: {len(feasible)}")
feasible.sort(key=lambda x: x.time_cost)
# feasible = feasible[:10]
for i, rpc_exe in enumerate(feasible):
print(
f"{i}: time_cost: {rpc_exe.time_cost} ms, "
f"sub_device_mesh: {rpc_exe.device_mesh}, "
f"parallel_strategy: {rpc_exe.parallel_strategy}, "
f"mem_cost: {rpc_exe.mem/(1024*1024*1024):02f} GB, "
f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):02f} GB"
)
return rpc_exe_list
def make_model_size_dict(rpcs: List[MFCDef], if_print: bool = False) -> Dict[str, int]:
model_size_dict = {}
for rpc in rpcs:
if rpc.model_name.role in model_size_dict:
continue
# model_configs = load_model_config(rpc)
# model_size_dict[rpc.model_name.role] = estimate_model_size(real_model_config)
model_size_dict[rpc.model_name.role] = rpc.model_type.size
if if_print:
print(
f"model_name: {rpc.model_name.role}, "
f"model_size: {rpc.model_type.size}"
)
return model_size_dict

View File

@ -1,28 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").
from realhf.api.core.model_api import ReaLModelConfig
def find_factors(n):
factors = []
for i in range(1, n + 1):
if n % i == 0:
factors.append(i)
return factors
def make_stats_key(rpc_name, bs, seq_len):
return f"{rpc_name}|{bs}|{seq_len}"
def parse_stats_key(key):
rpc_name, bs, seq_len = key.split("|")
return rpc_name, int(bs), int(seq_len)
def load_model_config(model_class: str, model_path: str) -> ReaLModelConfig:
from realhf.impl.model.nn.real_llm_api import ReaLModel
return getattr(ReaLModel, f"config_from_{model_class}")(model_path=model_path)

View File

@ -478,6 +478,8 @@ def run_ray_worker(
# NOTE: Importing these will initialize DeepSpeed/CUDA devices.
# profiler.import_profiler_registers()
if worker_type != "master_worker":
# The master_worker does not need these imports, and importing them there can raise errors.
import realhf.impl.dataset
import realhf.impl.model
import realhf.system

View File

@ -24,6 +24,17 @@ SCATTER_GROUPS = {}
logger = logging.getLogger("data_manager", "system")
def find_minimal_superset(A: List[Set[int]], B: Set[int]) -> Set[int] | None:
min_size = float("inf")
result = None
for S in A:
if B.issubset(S):
if len(S) < min_size:
min_size = len(S)
result = S
return result
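
A quick usage example of the helper above, assuming the illustrative groups below:

```python
groups = [{0, 1, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7}, {4, 5}]
print(find_minimal_superset(groups, {0, 1}))  # {0, 1, 2, 3} (smallest superset)
print(find_minimal_superset(groups, {0, 4}))  # {0, 1, 2, 3, 4, 5, 6, 7}
print(find_minimal_superset(groups, {8}))     # None (no superset exists)
```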
class DataManager:
def __init__(
@ -52,7 +63,7 @@ class DataManager:
mw_ranks: Dict[ModelName, List[int]] = {}
# Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name.
# Stores the dp_head (i.e., tp_rank=0, pp_rank=-1) ranks given a model_name.
mw_dp_head_ranks: Dict[ModelName, List[int]] = defaultdict(list)
assert msid2mwid is not None
@ -67,7 +78,7 @@ class DataManager:
topo,
msid2mwid,
pipe=topo.get_dim("pipe") - 1,
model=0,
tensor=0,
)
dp_size = topo.get_dim("data")
for dp_i in range(dp_size):
@ -87,7 +98,8 @@ class DataManager:
list(ranks), backend="nccl" if constants.use_cuda() else "gloo"
)
scatter_ranks = tuple(sorted(set([ranks[0]] + mw_ranks[dst])))
for rank in ranks:
scatter_ranks = tuple(sorted(set([rank] + mw_ranks[dst])))
SCATTER_GROUPS[scatter_ranks] = new_or_get_group(
list(scatter_ranks),
backend="nccl" if constants.use_cuda() else "gloo",
@ -228,7 +240,20 @@ class DataManager:
def _run_gather(
self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
):
if dist.get_rank() not in step.srcs:
# It's possible that some DP rank is not involved.
# Create dummy data so that the collective gather succeeds.
gather_ranks = find_minimal_superset(
[set(k) for k in GATHER_GROUPS.keys()], set(step.srcs)
)
assert gather_ranks is not None, (
set(step.srcs),
[set(k) for k in GATHER_GROUPS.keys()],
)
gather_ranks = sorted(list(gather_ranks))
pgroup = GATHER_GROUPS[tuple(gather_ranks)]
if dist.get_rank() not in gather_ranks:
return
maxlen = 0
@ -249,11 +274,13 @@ class DataManager:
torch.empty(
maxlen, device=constants.current_device(), dtype=torch.float32
)
for _ in range(len(step.srcs))
for _ in gather_ranks
]
is_valid_gather = [i in step.srcs for i in gather_ranks]
else:
gather_list = None
if dist.get_rank() in step.srcs:
local_gather_idx = step.srcs.index(dist.get_rank())
ids = step.ids[local_gather_idx]
for i in ids:
@ -267,18 +294,27 @@ class DataManager:
]
)
data = self._pad_data(data, maxlen)
else:
data = torch.empty(
maxlen, device=constants.current_device(), dtype=torch.float32
)
dist.gather(
data,
gather_list,
dst=step.root,
group=GATHER_GROUPS[tuple(sorted(step.srcs))],
group=pgroup,
)
if dist.get_rank() != step.root:
del data
return
for ids, buf in zip(step.ids, gather_list):
cnt = 0
for is_valid, buf in zip(is_valid_gather, gather_list):
if not is_valid:
continue
ids = step.ids[cnt]
offset = 0
for i in ids:
for key in step.keys:
@ -302,6 +338,9 @@ class DataManager:
self.storage[i].update_(s)
else:
self.storage[i] = s
cnt += 1
assert cnt == len(step.srcs) == len(step.ids)
del data
def _run_scatter(
self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
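
The bookkeeping behind the dummy-buffer gather can be sketched without `torch.distributed`: every rank in the chosen gather group contributes a buffer of the same length, and only buffers from ranks that actually hold data are unpacked on the root. Ranks, lengths, and names below are made up.

```python
import torch

gather_ranks = [0, 1, 2, 3]   # assumed minimal superset group
srcs = [0, 2]                 # only these ranks hold real data
maxlen = 6

is_valid_gather = [r in srcs for r in gather_ranks]  # [True, False, True, False]
# Every rank provides a buffer so the collective sees equal participation;
# ranks not in `srcs` would send dummy (uninitialized) data.
gather_list = [torch.empty(maxlen) for _ in gather_ranks]

# On the root, skip the dummy buffers when unpacking.
valid_buffers = [buf for ok, buf in zip(is_valid_gather, gather_list) if ok]
assert len(valid_buffers) == len(srcs)
```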

View File

@ -27,7 +27,7 @@ class FunctionExecutor:
rpcs: List[MFCDef],
msid2mwid: Dict[ModelShardID, int],
stream: NameResolvingRequestClient,
buffer: AsyncIOSequenceBuffer,
buffers: List[AsyncIOSequenceBuffer],
model_topos: Dict[str, ProcessTopology],
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
@ -58,14 +58,15 @@ class FunctionExecutor:
model_topos=model_topos,
model_configs=model_configs,
ctrl=ctrl,
buffer=buffer,
buffers=buffers,
redistrib_planner=self.redistrib_planner,
summary_writer=summary_writer,
)
self.func_calls[rpc.name] = func_call
self.stream = stream
self.buffer = buffer
self.buffers = buffers
self.buffer_id = 0
self.data_loading_dp_idx = -1
self.shuffle_dataset = shuffle_dataset
@ -111,18 +112,17 @@ class FunctionExecutor:
self.ctrl.ids_to_clear.clear()
async def load_data(self):
buffer = self.buffer
async def load_data(self, buffer_id: int):
buffer = self.buffers[buffer_id]
ctrl = self.ctrl
received_ids = set()
while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
while buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
resps = await self.stream.call_async(
handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)],
handle_type="fetch",
datas=[None for _ in range(self.src_dp_size)],
datas=[buffer_id for _ in range(self.src_dp_size)],
verbose=False,
)
@ -182,10 +182,13 @@ class FunctionExecutor:
logger.info("Waiting for the finish of the execution graph.")
loop = asyncio.get_event_loop()
tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [
tasks = [
loop.create_task(fc.run(self.buffer_id)) for fc in self.func_calls.values()
] + [
loop.create_task(self.flush_calls()),
loop.create_task(self.load_data()),
loop.create_task(self.load_data(self.buffer_id)),
loop.create_task(self.finish_traverse()),
]
loop.run_until_complete(asyncio.gather(*tasks))
self.buffer_id = (self.buffer_id + 1) % len(self.buffers)
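
The one-buffer-per-task rotation is just a modular counter over the buffer list; a tiny sketch with hypothetical task buffers:

```python
buffers = ["math_buffer", "code_buffer"]  # hypothetical: one buffer per task
buffer_id = 0
for _step in range(5):
    active = buffers[buffer_id]
    # ... run one traversal of the execution graph against `active` ...
    buffer_id = (buffer_id + 1) % len(buffers)
# Visiting order: math, code, math, code, math
```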

View File

@ -1,14 +1,111 @@
import os
import subprocess
import sys
import time
from pathlib import Path
import requests
from realhf.api.cli_args import SGLangConfig
from realhf.api.core.system_api import GenerationServer as GenerationServerConfig
from realhf.base import gpu_utils, logging, name_resolve, names, network, seeding
from realhf.base import (
gpu_utils,
logging,
name_resolve,
names,
network,
pkg_version,
seeding,
)
from realhf.base.cluster import spec as cluster_spec
from realhf.system.worker_base import PollResult, Worker
logger = logging.getLogger(__name__)
def execute_shell_command(command: str) -> subprocess.Popen:
"""
Execute a shell command and return its process handle.
"""
# Replace newline continuations and split the command string.
command = command.replace("\\\n", " ").replace("\\", " ")
parts = command.split()
return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)
def launch_server_cmd(command: str, port: int = 30000):
"""
Launch the server process using the given command and the specified port.
"""
p = Path(os.path.dirname(__file__))
patch_path = str(
p.parent.parent
/ "patch"
/ "sglang"
/ f"v{pkg_version.get_version('sglang')}.patch"
)
target_path = ""
sglang_meta = subprocess.check_output(
"python3 -m pip show sglang", shell=True
).decode("ascii")
for line in sglang_meta.split("\n"):
line = line.strip()
if line.startswith("Editable project location: "):
target_path = str(Path(line.split(": ")[1]).parent)
if target_path:
proc = subprocess.Popen(
["git", "apply", patch_path],
cwd=target_path,
stderr=sys.stdout,
stdout=sys.stdout,
)
proc.wait()
logger.info(f"Applied SGLang patch at {target_path}")
assert port is not None
full_command = f"{command} --port {port}"
process = execute_shell_command(full_command)
return process, port
def terminate_process(process, port=None):
"""
Terminate the process and, if a port was reserved, release it.
"""
from sglang.srt.utils import kill_process_tree
kill_process_tree(process.pid)
def wait_for_server(base_url: str, timeout: int = None) -> None:
"""Wait for the server to be ready by polling the /v1/models endpoint.
Args:
base_url: The base URL of the server
timeout: Maximum time to wait in seconds. None means wait forever.
"""
start_time = time.time()
while True:
try:
response = requests.get(
f"{base_url}/v1/models",
headers={"Authorization": "Bearer None"},
)
if response.status_code == 200:
time.sleep(5)
break
if timeout and time.time() - start_time > timeout:
raise TimeoutError("Server did not become ready within timeout period")
except requests.exceptions.RequestException:
time.sleep(1)
PORT_CLEARANCE_PERIOD = 90
class GenerationServer(Worker):
def _configure(self, config: GenerationServerConfig):
self.config = config
@ -36,20 +133,37 @@ class GenerationServer(Worker):
config = self.config
assert config.backend_type == "sglang"
host_ip = network.gethostip()
host = "localhost" if not config.backend_args.enable_metrics else host_ip
# NOTE: Ports returned by `find_multiple_free_ports` are unique,
# but SGLang servers still encounter conflicts.
# Use a clearance period to work around this issue.
servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size
idx_on_this_node = self.worker_index % servers_per_node
time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node)
ports = network.find_multiple_free_ports(
2,
low=10000,
high=60000,
experiment_name=self.experiment_name,
trial_name=self.trial_name,
)
server_port = ports[0]
nccl_port = ports[1]
cmd = SGLangConfig.build_cmd(
config.backend_args,
config.model_path,
tp_size=config.tp_size,
server_index=self.worker_index,
base_gpu_id=self.base_gpu_id,
dist_init_addr=f"{host}:{nccl_port}",
)
from sglang.utils import launch_server_cmd, wait_for_server
host_ip = network.gethostip()
host = "localhost" if not config.backend_args.enable_metrics else host_ip
# TODO: handle launching error and retry
self.server_process, self.server_port = launch_server_cmd(cmd)
self.server_process, self.server_port = launch_server_cmd(cmd, port=server_port)
self.server_addr = f"http://{host}:{self.server_port}"
wait_for_server(self.server_addr)
@ -80,6 +194,5 @@ class GenerationServer(Worker):
def _exit_hook(self, exit_status):
if self.server_process is not None and self.config.backend_type == "sglang":
from sglang.utils import terminate_process
terminate_process(self.server_process)

View File

@ -6,20 +6,46 @@ import shutil
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import List
import aiohttp
import numpy as np
from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq
from realhf.api.core.system_api import GserverManager as GserverManagerConfig
from realhf.base import constants, logging, name_resolve, names, network, recover
from realhf.system.worker_base import AsyncWorker, PollResult, Worker
from realhf.system.worker_base import PollResult, Worker
logger = logging.getLogger("Generation Manager", "colored")
logger = logging.getLogger("Generation Manager", "system")
STALENESS_WARNED = defaultdict(lambda: False)
@dataclass
class RolloutStat:
submit: int = 0
accepted: int = 0
running: int = 0
def inc(self):
self.submit += 1
self.accepted += 1
self.running += 1
def accept(self):
self.running -= 1
def reject(self):
self.running -= 1
self.accepted -= 1
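# Rollout accounting sketch (illustrative): a rollout is first submitted
# (`inc()` bumps submit/accepted/running), then either kept (`accept()`,
# which only decrements `running`) or dropped (`reject()`, which also rolls
# back `accepted`). For example:
#   stat = RolloutStat(); stat.inc()   -> submit=1, accepted=1, running=1
#   stat.accept()                      -> submit=1, accepted=1, running=0
#   stat.inc(); stat.reject()          -> submit=2, accepted=1, running=0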
@dataclass
class AllocateRolloutInput:
qid: str
class GserverManager(Worker):
"""This worker has the following functionalities:
1. As a router, it schedules generation requests and returns the
@ -37,23 +63,30 @@ class GserverManager(Worker):
assert self.config.worker_info.worker_count == 1
self.async_lock = asyncio.Lock()
self.threading_lock = threading.Lock()
self.n_total_rollouts = 0
self.n_running_rollouts = 0
self.accepted_rollouts = 0
self.rollout_stat = RolloutStat()
self.schedule_policy = config.schedule_policy
self._last_param_realloc_step = 0
self._qid_to_server_url = {}
self._server_token_usage = defaultdict(float)
self._server_request_counts = defaultdict(int)
self._last_thpt_output_time = time.time()
self._gen_tokens = 0
self.experiment_name = config.worker_info.experiment_name
self.trial_name = config.worker_info.trial_name
# manager server
self.server = None
self.manager_http_server = None
self.thread = None
self.server_urls = []
# recover info
self.__recover_run, self.__recover_info = recover.load_recover_info()
if self.__recover_run:
@ -67,10 +100,12 @@ class GserverManager(Worker):
name_resolve.add(name, self.__recover_info.last_step_info.global_step)
self._loaded_recover_weights = False
self.n_total_rollouts = self.accepted_rollouts = (
hist_rollouts = (
self.config.train_batch_size
* self.__recover_info.last_step_info.global_step
)
self.rollout_stat.submit = hist_rollouts
self.rollout_stat.accepted = hist_rollouts
return config.worker_info
@ -84,7 +119,7 @@ class GserverManager(Worker):
if cnt >= timeout:
raise TimeoutError("Waiting generation servers timeout.")
urls = name_resolve.get_subtree(name)
assert len(set(urls)) == len(urls), urls
assert len(set(urls)) == len(urls), (len(urls), len(set(urls)), urls)
return urls
def _get_recover_ckpt_path(self, role: str):
@ -146,42 +181,27 @@ class GserverManager(Worker):
async def flush_requests_and_update_weights(
self, server_url, new_param_path, update_weights_retries=5
):
# HACK: these URLs are SGLang-specific.
server_index = self.server_urls.index(server_url)
async with aiohttp.ClientSession(server_url) as session:
running_requests = None
tik = time.perf_counter()
while running_requests is None or running_requests > 0:
if time.perf_counter() - tik > self.config.flush_request_timeout:
raise RuntimeError(
f"Waiting for flush requests failed. {running_requests} requests "
f"remain after {self.config.flush_request_timeout} secs waiting. "
f"Please try to reduce `new_tokens_per_chunk`."
)
if running_requests is not None and running_requests > 0:
logger.info(
f"Waiting for {running_requests} requests on gen server {server_index}... "
f"Time taken so far: {time.perf_counter() - tik:.4f}s"
)
await asyncio.sleep(0.5)
async with session.get(f"/metrics") as resp:
resp.raise_for_status()
text = await resp.text()
for line in text.split("\n"):
if line.startswith("sglang:num_running_reqs"):
running_requests = float(line.split(" ")[1])
break
success = False
for _ in range(update_weights_retries):
async with aiohttp.ClientSession(
server_url,
timeout=aiohttp.ClientTimeout(
total=self.config.flush_request_timeout, sock_connect=30
),
) as session:
async with session.post(
f"/update_weights_from_disk",
json=dict(model_path=new_param_path),
json=dict(model_path=new_param_path, allow_interrupt=True),
) as resp:
if resp.status == 200:
res = await resp.json()
success = res["success"]
if success:
logger.info(
f"{res['num_paused_requests']} requests are interrupted "
f"during updateing weights for server {server_index}: {server_url}"
)
return
logger.warning(
f"Update weights failed: {res['message']}. Retrying."
@ -198,6 +218,16 @@ class GserverManager(Worker):
self.round_robin_idx %= self.config.n_servers
return r
def _least_requests_schedule(self, req_meta: GenReqMeta) -> int:
counts = [
self._server_request_counts[server_url] for server_url in self.server_urls
]
return int(np.argmin(counts))
def _least_token_usage_schedule(self, req_meta: GenReqMeta) -> int:
url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
return self.server_urls.index(url)
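# Scheduling sketch (illustrative numbers): with server_urls = [A, B, C],
# request counts {A: 3, B: 1, C: 2} and token usage {A: 9e3, B: 2e4, C: 5e3},
# `least_requests` picks index 1 (server B, fewest in-flight requests) while
# `least_token_usage` picks index 2 (server C, lowest estimated KV load).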
def _poll(self):
if not self.thread:
# Find addresses of generation servers
@ -228,6 +258,23 @@ class GserverManager(Worker):
loop.run_until_complete(asyncio.gather(*tasks))
logger.info(f"Generaion server updated weights from: {new_param_path}")
tasks = [
self._get_server_token_usage(server_url) for server_url in self.server_urls
]
loop = asyncio.get_event_loop()
token_usages = loop.run_until_complete(asyncio.gather(*tasks))
with self.threading_lock:
for server_url, token_usage in zip(self.server_urls, token_usages):
self._server_token_usage[server_url] = token_usage
if time.time() - self._last_thpt_output_time > 30:
interval = time.time() - self._last_thpt_output_time
logger.info(
f"Generation throughput: {self._gen_tokens / interval:.2f} tokens/s"
)
self._last_thpt_output_time = time.time()
self._gen_tokens = 0
# clear old weights
realloc_root = os.path.join(
constants.PARAM_REALLOC_PATH,
@ -237,9 +284,11 @@ class GserverManager(Worker):
)
if os.path.exists(realloc_root):
for realloc_version in os.listdir(realloc_root):
# Lock-free access is safe here.
# Keep one checkpoint around for recovery.
if (
os.path.isdir(os.path.join(realloc_root, realloc_version))
and int(realloc_version) < self._last_param_realloc_step
and int(realloc_version) < self._last_param_realloc_step - 1
):
shutil.rmtree(os.path.join(realloc_root, realloc_version))
logger.info(
@ -247,29 +296,56 @@ class GserverManager(Worker):
f"checkpoint: {os.path.join(realloc_root, realloc_version)}"
)
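# Cleanup sketch (illustrative): with realloc versions {3, 4, 5, 6} on disk and
# _last_param_realloc_step == 6, only versions < 6 - 1 = 5 are removed, i.e.
# 3 and 4 are deleted while 5 is kept as a recovery fallback alongside 6.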
# TODO: we may want to update server status
# in the main thread.
time.sleep(1)
time.sleep(5)
return PollResult(0, 0)
async def is_staled(self):
global_sample_cnt = self.n_total_rollouts
expected_version = global_sample_cnt // self.config.train_batch_size
staled = (
expected_version
> self.config.max_head_offpolicyness + self._last_param_realloc_step
async def _get_server_token_usage(self, server_url):
async with aiohttp.ClientSession(
server_url,
timeout=aiohttp.ClientTimeout(
total=self.config.flush_request_timeout, sock_connect=30
),
) as session:
async with session.get("/metrics") as resp:
resp.raise_for_status()
text = await resp.text()
for line in text.split("\n"):
if line.startswith("sglang:num_used_tokens"):
return float(line.split(" ")[1])
raise RuntimeError(f"Failed to get token usage metrics from {server_url}")
async def _get_server_num_running_requests(self, server_url):
async with aiohttp.ClientSession(
server_url,
timeout=aiohttp.ClientTimeout(
total=self.config.flush_request_timeout, sock_connect=30
),
) as session:
async with session.get(f"/metrics") as resp:
resp.raise_for_status()
text = await resp.text()
for line in text.split("\n"):
if line.startswith("sglang:num_running_reqs"):
return float(line.split(" ")[1])
raise RuntimeError(
f"Failed to get num running requests metrics from {server_url}"
)
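# Metrics parsing sketch (illustrative): the two helpers above scan SGLang's
# Prometheus-style /metrics text line by line; e.g. a response containing
#   sglang:num_used_tokens 12345.0
#   sglang:num_running_reqs 7.0
# yields token usage 12345.0 and 7.0 running requests. The exact metric lines
# may differ across SGLang versions.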
def is_staled(self):
global_sample_cnt = self.rollout_stat.accepted
expected_version = global_sample_cnt // self.config.train_batch_size
version = self._last_param_realloc_step
staled = expected_version > self.config.max_head_offpolicyness + version
global STALENESS_WARNED
if staled and not STALENESS_WARNED[self._last_param_realloc_step]:
if staled and not STALENESS_WARNED[version]:
logger.warning(
f"expected version ({expected_version}) = "
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
f"current version {self._last_param_realloc_step}, "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}. Staled? {staled}"
)
STALENESS_WARNED[self._last_param_realloc_step] = True
STALENESS_WARNED[version] = True
return staled
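# Staleness sketch (illustrative numbers): with train_batch_size = 512,
# max_head_offpolicyness = 4 and rollout_stat.accepted = 3072, the expected
# version is 3072 // 512 = 6; rollouts count as staled whenever
# 6 > 4 + version, i.e. whenever the latest realloc step is <= 1.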
def _run_routing_service(self):
@ -282,46 +358,69 @@ class GserverManager(Worker):
@self.app.post("/schedule_request")
async def schedule_request(req_meta: GenReqMeta):
with self.threading_lock:
async with self.async_lock:
version = self._last_param_realloc_step
# FIXME: We only implement a round-robin scheduler that
# ignores server status and request metadata
if (
req_meta.previous_server_url
and req_meta.previous_version == self._last_param_realloc_step
):
return dict(
url=req_meta.previous_server_url,
version=req_meta.previous_version,
)
if self.schedule_policy == "round_robin":
server_idx = self._round_robin_schedule(req_meta)
return dict(url=self.server_urls[server_idx], version=max(0, version))
elif self.schedule_policy == "least_token_usage":
server_idx = self._least_token_usage_schedule(req_meta)
elif self.schedule_policy == "least_requests":
server_idx = self._least_requests_schedule(req_meta)
else:
raise NotImplementedError(
f"Unknown schedule policy {self.schedule_policy}"
)
server_url = self.server_urls[server_idx]
# All samples of the same qid (prompt) are routed to the same destination server.
self._qid_to_server_url[req_meta.qid] = server_url
self._server_request_counts[server_url] += 1
self._server_token_usage[server_url] += (
req_meta.prompt_len
+ req_meta.new_token_budget * req_meta.group_size * 0.4
)
version = self._last_param_realloc_step
return dict(url=server_url, version=version)
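# Routing sketch (illustrative): a chunked request that already ran on some
# server stays pinned there as long as its recorded version matches the
# current one, so its KV cache can be reused. Otherwise a server is chosen by
# the configured policy and the bookkeeping above charges it an estimated
# prompt_len + 0.4 * new_token_budget * group_size tokens, e.g.
# prompt_len = 1024, new_token_budget = 2048, group_size = 8 adds roughly
# 1024 + 0.4 * 2048 * 8 = 7577.6 "tokens" to that server's usage counter.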
@self.app.post("/get_model_version")
async def get_model_version(req: ModelVersionReq):
with self.threading_lock:
async with self.async_lock:
# FIXME: we may have different versions for different servers
version = self._last_param_realloc_step
return dict(version=version)
@self.app.get("/allocate_rollout")
async def allocate_rollout():
@self.app.post("/allocate_rollout")
async def allocate_rollout(req: AllocateRolloutInput):
with self.threading_lock:
async with self.async_lock:
has_capacity = (
self.n_running_rollouts < self.config.max_concurrent_rollouts
self.rollout_stat.running < self.config.max_concurrent_rollouts
)
is_staled = await self.is_staled()
is_staled = self.is_staled()
reason = ""
if has_capacity and not is_staled:
self.n_running_rollouts += 1
self.n_total_rollouts += 1
self.rollout_stat.inc()
return dict(success=True, reason=reason)
else:
if not has_capacity:
reason += f"capacity: {self.n_running_rollouts} >= {self.config.max_concurrent_rollouts}"
reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}"
if is_staled:
global_sample_cnt = self.n_total_rollouts
global_sample_cnt = self.rollout_stat.accepted
expected_version = (
global_sample_cnt // self.config.train_batch_size
)
version = self._last_param_realloc_step
reason += (
f" and staled: expected version ({expected_version}) = "
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
f"current version {self._last_param_realloc_step}, "
f"current latest version {version}, "
f"offpolicyness {self.config.max_head_offpolicyness}."
)
return dict(success=False, reason=reason)
@ -329,13 +428,24 @@ class GserverManager(Worker):
@self.app.post("/finish_rollout")
async def finish_rollout(resp_meta: GenRespMeta):
with self.threading_lock:
async with self.async_lock:
self.n_running_rollouts -= 1
server_url = self._qid_to_server_url[resp_meta.qid]
self._server_request_counts[server_url] -= 1
assert (
self._server_request_counts[server_url] >= 0
), "server request count < 0"
self._qid_to_server_url.pop(resp_meta.qid)
self._gen_tokens += resp_meta.n_tokens
if resp_meta.accepted:
self.accepted_rollouts += 1
self.rollout_stat.accept()
else:
self.rollout_stat.reject()
return dict(success=True)
self.manager_addr = f"{network.gethostip()}:{network.find_free_port()}"
port = network.find_free_port(
experiment_name=self.experiment_name,
trial_name=self.trial_name,
)
self.manager_addr = f"{network.gethostip()}:{port}"
config = uvicorn.Config(
self.app,
@ -343,12 +453,12 @@ class GserverManager(Worker):
port=int(self.manager_addr.split(":")[1]),
log_level="warning",
)
self.server = uvicorn.Server(config)
self.server.run()
self.manager_http_server = uvicorn.Server(config)
self.manager_http_server.run()
def _exit_hook(self, exit_status):
if self.server:
self.server.should_exit = True
if self.manager_http_server:
self.manager_http_server.should_exit = True
if self.thread:
self.thread.join(timeout=3)
logger.info("Server stopped")

View File

@ -145,6 +145,7 @@ class MasterWorker(worker_base.Worker):
# for benchmark
self.e2e_time_history = []
self.__benchmark_steps = config.exp_ctrl.benchmark_steps
self.__benchmark_n_seqs = config.exp_ctrl.benchmark_n_seqs
return config.worker_info
@ -210,13 +211,14 @@ class MasterWorker(worker_base.Worker):
src_rpc_dp_size = src_rpc_topo.get_dim("data")
# Request training specification from data workers.
self._dataset_size = sum(
self.__stream.call(
specs = self.__stream.call(
handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
datas=[None for i in range(src_rpc_dp_size)],
handle_type="spec",
),
)
assert all(x["n_datasets"] == specs[0]["n_datasets"] for x in specs), specs
self._dataset_size = sum(x["dataset_size"] for x in specs)
self._n_datasets = specs[0]["n_datasets"]
self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
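# Spec aggregation sketch (illustrative numbers): with src_rpc_dp_size = 4
# data workers each reporting dataset_size = 2500 and n_datasets = 2, the
# master sees _dataset_size = 10000; with src_rpc.n_seqs = 512 this gives
# _steps_per_epoch = 10000 // 512 = 19.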
@ -239,7 +241,7 @@ class MasterWorker(worker_base.Worker):
src_rpc_dp_size = src_rpc_topo.get_dim("data")
src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
for i in range(src_rpc_dp_size):
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0)
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, tensor=0)
handler_routing[f"__data{i}__"] = self.config.msid2mwid[
config_pkg.ModelShardID.from_parallelism_rank(
model_name=src_rpc.model_name,
@ -263,10 +265,13 @@ class MasterWorker(worker_base.Worker):
self.initialize_models()
self.__seqbuffer = AsyncIOSequenceBuffer(
self.__seqbuffers = [
AsyncIOSequenceBuffer(
self.__model_rpcs,
max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
)
for _ in range(self._n_datasets)
]
# wandb init, connect to remote wandb host
wandb.login()
@ -300,7 +305,7 @@ class MasterWorker(worker_base.Worker):
rpcs=self.__model_rpcs,
msid2mwid=self.config.msid2mwid,
stream=self.__stream,
buffer=self.__seqbuffer,
buffers=self.__seqbuffers,
model_topos=self.__model_topos,
model_configs=self.__model_configs,
ctrl=self.__rpc_ctrl,
@ -395,20 +400,33 @@ class MasterWorker(worker_base.Worker):
# Pause the worker if experiment or system-wise benchmark completes.
if (
(
self.__benchmark_steps is not None
and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
) or (
)
or (
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
>= self.__total_train_epochs * self._dataset_size
)
or (
self.__benchmark_n_seqs is not None
and self.__rpc_ctrl.step_info.global_step
* self._ft_spec.train_batch_size
>= self.__benchmark_n_seqs
)
):
# We don't know whether it is the last step of the current epoch,
# so we exit at the first step of the next epoch.
if self.__benchmark_steps is not None:
if (
self.__benchmark_steps is not None
or self.__benchmark_n_seqs is not None
):
logger.info(
f"Finished benchmark {self.__benchmark_steps}. "
f"Time consumption of this setup: {time_since_configure:.3f}"
)
logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
# TODO: inform generation workers to exit
return self.experiment_complete_exit()
return worker_base.PollResult(sample_count=1, batch_count=1)
@ -439,9 +457,7 @@ class MasterWorker(worker_base.Worker):
s += f"(global step {global_step}) finishes. "
s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
s += f"Total time consumption: {time_since_configure:.3f}s. "
logging.log_wandb_tensorboard(
{"timeperf/e2e": e2e_time}, step=self.__rpc_ctrl.step_info.global_step
)
logging.log_wandb_tensorboard({"timeperf/e2e": e2e_time})
if len(self.e2e_time_history) > 2:
remaining_steps = self._steps_per_epoch - epoch_step
remaining_epochs = self.__total_train_epochs - epoch

View File

@ -63,7 +63,7 @@ class ModelFunctionCall:
model_topos: Dict[str, topology.ProcessTopology],
model_configs: Dict[str, None | ReaLModelConfig],
ctrl: RPCCorountineControl,
buffer: AsyncIOSequenceBuffer,
buffers: List[AsyncIOSequenceBuffer],
redistrib_planner: RedistribPlanner,
summary_writer: SummaryWriter | None,
):
@ -89,7 +89,7 @@ class ModelFunctionCall:
)
self.rpc_ctrl = ctrl
self.buffer = buffer
self.buffers = buffers
self.redistrib_planner = redistrib_planner
self.summary_writer = summary_writer
@ -306,7 +306,7 @@ class ModelFunctionCall:
).partitions
return buf_indices, sample, partitions
async def run_step(self, buf_indices, sample):
async def run_step(self, buf_indices, sample, buffer_id: int):
rpc = self.rpc
topo = self.model_topos[rpc.model_name]
ctrl = self.rpc_ctrl
@ -317,7 +317,7 @@ class ModelFunctionCall:
]
dp_head_indices = [
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0)
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, tensor=0)
for i in range(self.dp_size)
]
@ -348,11 +348,6 @@ class ModelFunctionCall:
if i not in dests:
dests[i] = []
# NOTE: The data loaded from the dataset may be unevenly distributed across DP ranks.
# Only bcast works in this case.
if rpc.is_src:
pattern = "bcast"
else:
pattern = "gather-scatter"
data_transfer_plan = self.redistrib_planner.derive_plan(
dests,
@ -362,14 +357,14 @@ class ModelFunctionCall:
blogger.info(f"Data tranfer plan for `{rpc.name}`: {data_transfer_plan}.")
# Update storage tracker for transferred data.
if rpc.is_src:
if pattern == "bcast":
# NOTE: since the data we loaded may be unevenly distributed across DP ranks,
# we should change the owner of the data to the src RPC.
for i in range(topo.world_size()):
h = ModelShardID.from_parallelism_rank(
model_name=rpc.model_name, topo=topo, parallelism_rank=i
)
is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
is_dp_head = h.tp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
gpu_id = self.msid2mwid[h]
for key in rpc.input_keys:
await self.redistrib_planner.storage_tracker.add_data(
@ -414,13 +409,13 @@ class ModelFunctionCall:
responses, time_records = list(zip(*[responses[i] for i in dp_head_indices]))
# If the returned data is a SequenceSample, it is the data returned by
# model function calls. The data shoulbe be amended into buffer.
# model function calls. The data should be amended into buffer.
# Otherwise, it's the train statistics and should be reduced and logged.
if isinstance(responses[-1], data_api.SequenceSample):
# Update storage tracker for generated data.
for dp_rank, x in enumerate(responses):
pp_size = topo.get_dim("pipe")
ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, model=0)
ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, tensor=0)
for rank in ranks:
h = config_pkg.ModelShardID.from_parallelism_rank(
model_name=rpc.model_name, topo=topo, parallelism_rank=rank
@ -434,8 +429,14 @@ class ModelFunctionCall:
is_owner=True,
)
res = data_api.SequenceSample.gather(responses)
else:
elif isinstance(responses[0], dict):
res = data_api.gather_stat(responses)
else:
assert isinstance(responses[0], list)
res = [
data_api.gather_stat([r[i] for r in responses])
for i in range(len(responses[0]))
]
if rpc.log_return_value:
if isinstance(res, dict):
@ -447,6 +448,17 @@ class ModelFunctionCall:
step=ctrl.step_info.global_step,
summary_writer=self.summary_writer,
)
elif isinstance(res, list):
for j, r in enumerate(res):
logger.info(
f"RPC name {rpc.name} returns ({j}/{len(res)})\n{data_api.tabulate_stats(r)}"
)
offset = len(res) * ctrl.step_info.global_step
logging.log_wandb_tensorboard(
r,
step=offset + j,
summary_writer=self.summary_writer,
)
else:
logger.info(f"RPC name {rpc.name} returns\n{res}")
@ -456,7 +468,6 @@ class ModelFunctionCall:
time_stats = stats_tracker.export()
logging.log_wandb_tensorboard(
time_stats,
step=ctrl.step_info.global_step,
summary_writer=self.summary_writer,
)
@ -475,7 +486,7 @@ class ModelFunctionCall:
await ctrl.train_count.put(1)
else:
logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
await self.buffer.amend_batch(buf_indices, res.unpack())
await self.buffers[buffer_id].amend_batch(buf_indices, res.unpack())
# Wait for all side-effect requests to finish.
# Side-effect or empty requests are required for data transfer
@ -483,20 +494,20 @@ class ModelFunctionCall:
# Wait for them after the main request so that the correct MFC time is logged.
await self.stream.gather_async(other_req_ids)
async def run(self):
async def run(self, buffer_id: int):
rpc = self.rpc
topo = self.model_topos[rpc.model_name]
logger.info(
f"Running Model RPC, interface_type=#{rpc.interface_type}# "
f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*"
f"(dp,tp,pp) = *({topo.get_dim('data')},{topo.get_dim('tensor')},{topo.get_dim('pipe')})*"
)
consumed = 0
while True:
buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc)
buf_indices, sample = await self.buffers[buffer_id].get_batch_for_rpc(rpc)
await self.run_step(buf_indices, sample)
await self.run_step(buf_indices, sample, buffer_id)
consumed += sample.bs
# Ensure that parent RPCs will not be over-consumed.

View File

@ -153,7 +153,7 @@ class ModelWorker(worker_base.Worker):
]
for s in self.config.shards:
_pp_size = s.id.topo.get_dim("pipe")
if not (s.id.mp_rank == 0 and s.id.pp_rank == _pp_size - 1):
if not (s.id.tp_rank == 0 and s.id.pp_rank == _pp_size - 1):
continue
if src_rpc.model_name == s.id.model_name:
self.__has_dataset = True
@ -195,8 +195,8 @@ class ModelWorker(worker_base.Worker):
return None
@property
def _mp_rank(self) -> int:
return constants.model_parallel_rank()
def _tp_rank(self) -> int:
return constants.tensor_parallel_rank()
@property
def _pp_rank(self) -> int:
@ -211,8 +211,8 @@ class ModelWorker(worker_base.Worker):
return constants.pipe_parallel_world_size()
@property
def _mp_size(self) -> int:
return constants.model_parallel_world_size()
def _tp_size(self) -> int:
return constants.tensor_parallel_world_size()
@property
def _dp_size(self) -> int:
@ -220,7 +220,7 @@ class ModelWorker(worker_base.Worker):
@property
def _is_dp_head(self) -> bool:
return self._mp_rank == 0 and self._pp_rank == self._pp_size - 1
return self._tp_rank == 0 and self._pp_rank == self._pp_size - 1
@property
def _model(self) -> model_api.Model:
@ -302,6 +302,7 @@ class ModelWorker(worker_base.Worker):
constants.set_grid(model_name_, grid)
# Set up training dataset for source RPCs.
self.__datasets = []
if self.__has_dataset:
datasets = [
data_api.make_dataset(
@ -321,31 +322,34 @@ class ModelWorker(worker_base.Worker):
)
for d in self.config.datasets
]
if len(self.config.datasets) == 1:
self.__dataset = datasets[0]
else:
self.__dataset = torch.utils.data.ConcatDataset(datasets)
self.__datasets = datasets
self.__dataloaders: List[
torch.utils.data.DataLoader[data_api.SequenceSample]
] = []
for i, d in enumerate(self.__datasets):
g = torch.Generator()
g.manual_seed(seeding.get_seed())
g.manual_seed(
self.config.base_seed + seeding._seed_from_key(f"__dataloader{i}__")
)
dataloader_kwargs = dict(
shuffle=self.config.shuffle_dataset,
generator=g,
)
if not isinstance(self.__dataset, PullerStreamDataset):
if not isinstance(d, PullerStreamDataset):
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
# NOTE: This is *NOT* the actual batch size for training.
# It is just a reasonable chunk size for loading data onto workers.
dataloader_kwargs["batch_size"] = 10240
else:
dataloader_kwargs["batch_size"] = None
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset, **dataloader_kwargs
self.__dataloaders.append(
torch.utils.data.DataLoader(d, **dataloader_kwargs)
)
self.dataset_size = len(self.__dataset)
self.dataset_size = sum(len(d) for d in self.__datasets)
self.__data_generator = enumerate(self.__dataloader)
self.__data_generators = [enumerate(d) for d in self.__dataloaders]
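# Seeding sketch (illustrative; assumes `_seed_from_key` hashes a string key
# into a stable integer offset): every dataloader gets a distinct but
# reproducible seed, e.g. with base_seed = 1 the generators are seeded with
# 1 + _seed_from_key("__dataloader0__"), 1 + _seed_from_key("__dataloader1__"),
# ..., so shuffling differs across datasets yet is identical across reruns.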
self.__models: Dict[ModelName, model_api.Model] = dict()
self.__model_is_handle: Dict[ModelName, bool] = dict()
@ -377,25 +381,26 @@ class ModelWorker(worker_base.Worker):
)
# Recover indices for dynamic dataset
for i, d in enumerate(self.__datasets):
if (
s.id.model_name == self.src_rpc.model_name
and self.__has_dataset
and hasattr(self.__dataset, "filter")
and hasattr(d, "filter")
):
dataset_indices_path = os.path.join(
constants.MODEL_SAVE_ROOT,
constants.experiment_name(),
constants.trial_name(),
"dataset_indices",
f"{self._dp_rank}.npy",
f"{self._dp_rank}_{i}.npy",
)
if os.path.exists(dataset_indices_path):
indices = np.load(dataset_indices_path).tolist()
logger.info(
f"DP rank {self._dp_rank} updating dataset indices upon recover, "
f"size {len(self.__dataset.active_indices)} -> {len(indices)}"
f"size {len(d.active_indices)} -> {len(indices)}"
)
self.__dataset.active_indices = indices
d.active_indices = indices
if constants.parallelism_rank() == 0:
self.logger.info(
@ -537,9 +542,13 @@ class ModelWorker(worker_base.Worker):
cache = []
while True:
try:
request, data, handled, res, time_record = (
self.__request_queue.get_nowait()
)
(
request,
data,
handled,
res,
time_record,
) = self.__request_queue.get_nowait()
request: request_reply_stream.Payload
if not handled:
while len(request.pre_hooks) > 0:
@ -582,9 +591,13 @@ class ModelWorker(worker_base.Worker):
elif request.handle_name == "fetch":
dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1))
assert self.__has_dataset
assert isinstance(request.data, int), request.data
dataset_id = request.data
# Fetch.
try:
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
self.__dataset_batch_counter, cur_sample = next(
self.__data_generators[dataset_id]
)
except StopIteration:
# When the current dataloader is exhausted, filter the dataset and recreate the dataloader.
eval_scores_path = os.path.join(
@ -598,39 +611,43 @@ class ModelWorker(worker_base.Worker):
constants.experiment_name(),
constants.trial_name(),
"dataset_indices",
f"{dp_rank}.npy",
f"{dp_rank}_{dataset_id}.npy",
)
os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True)
if hasattr(self.__dataset, "filter") and os.path.exists(
if hasattr(self.__datasets[dataset_id], "filter") and os.path.exists(
eval_scores_path
):
# Don't filter dataset on the first poll after recover.
with open(eval_scores_path, "r", encoding="utf-8") as f:
dataset_eval_scores = json.load(f)
self.__dataset.filter(dataset_eval_scores)
self.__datasets[dataset_id].filter(dataset_eval_scores)
# Save the dataset indices after filtering
np.save(
dataset_indices_path,
self.__dataset.active_indices,
self.__datasets[dataset_id].active_indices,
)
g = torch.Generator()
g = g.set_state(self.__dataloader.generator.get_state())
g = g.set_state(self.__dataloaders[dataset_id].generator.get_state())
dataloader_kwargs = dict(
shuffle=self.config.shuffle_dataset,
generator=g,
)
if not isinstance(self.__dataset, PullerStreamDataset):
if not isinstance(self.__datasets[dataset_id], PullerStreamDataset):
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
# NOTE: This is *NOT* the actual batch size for training.
# It is just a reasonable chunk size for loading data onto workers.
dataloader_kwargs["batch_size"] = 10240
else:
dataloader_kwargs["batch_size"] = None
self.__dataloader = torch.utils.data.DataLoader(
self.__dataset, **dataloader_kwargs
self.__dataloaders[dataset_id] = torch.utils.data.DataLoader(
self.__datasets[dataset_id], **dataloader_kwargs
)
self.__data_generators[dataset_id] = enumerate(
self.__dataloaders[dataset_id]
)
self.__dataset_batch_counter, cur_sample = next(
self.__data_generators[dataset_id]
)
self.__data_generator = enumerate(self.__dataloader)
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
if isinstance(cur_sample, data_api.SequenceSample):
samples = cur_sample.unpack()
@ -663,7 +680,10 @@ class ModelWorker(worker_base.Worker):
)
elif request.handle_name == "spec":
# Raw dataset without filtering.
res = self.dataset_size
res = {
"n_datasets": len(self.__datasets),
"dataset_size": self.dataset_size,
}
elif request.handle_name == "clear_data_cache":
with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc):
ids = request.data
@ -772,8 +792,10 @@ class ModelWorker(worker_base.Worker):
if hook == "evaluate":
assert request.handle_name == "train_step", request.handle_name
assert isinstance(ret, dict), ret
assert isinstance(res, dict), res
if isinstance(res, dict):
res.update(ret)
else:
res[0].update(ret)
time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/post-{hook}"
] += (time.perf_counter() - tik)
@ -803,13 +825,7 @@ class ModelWorker(worker_base.Worker):
with constants.model_scope(model_name):
dist.barrier(group=constants.cpu_parallelism_group())
if constants.parallelism_rank() == 0:
name_resolve.add(
name,
str(global_step),
delete_on_exit=False,
keepalive_ttl=30,
replace=True,
)
name_resolve.add(name, str(global_step), replace=True)
time_record[
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/param-sync-save"
] += (time.perf_counter() - tik)
@ -867,7 +883,7 @@ class ModelWorker(worker_base.Worker):
if len(self.__performance_recorder) == 0:
self.__performance_recorder["info"] = {
"pipeline_size": self._pp_size,
"model_size": self._mp_size,
"model_size": self._tp_size,
"data_size": self._dp_size,
"rank": constants.parallelism_rank(),
"sequence_parallel_enabled": constants.sequence_parallel(),
@ -1374,9 +1390,13 @@ class ModelWorker(worker_base.Worker):
rescheduled_requests = []
other_requests = []
for _ in range(self.__request_queue.qsize()):
request, data, handled, res, time_record = (
self.__request_queue.get_nowait()
)
(
request,
data,
handled,
res,
time_record,
) = self.__request_queue.get_nowait()
if request.handle_name not in ["inference", "generate", "train_step"]:
other_requests.append((request, data, handled, res, time_record))
else:
@ -1399,9 +1419,13 @@ class ModelWorker(worker_base.Worker):
# we can correctly log the time consumption in the master worker.
while True:
try:
request, data, handled, res, time_record = (
self.__request_queue.get_nowait()
)
(
request,
data,
handled,
res,
time_record,
) = self.__request_queue.get_nowait()
self.handle_blocking_request(
request, data, handled, res, time_record
)

View File

@ -103,9 +103,14 @@ class PartialRolloutManager:
):
from realhf.impl.model.backend.sglang import SGLangAPIClient
max_new_tokens = min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk)
max_new_tokens = min(
max_new_tokens,
raw_gconfig.max_new_tokens - len(input_ids) + len(prompt_ids),
)
gconfig = raw_gconfig.new(
n=1,
max_new_tokens=min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk),
max_new_tokens=max_new_tokens,
)
assert self.tokenizer.pad_token_id is not None
assert self.tokenizer.eos_token_id is not None
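# Chunk budget sketch (illustrative numbers): with raw_gconfig.max_new_tokens
# = 1024, new_tokens_per_chunk = 256, a 300-token prompt and 900 tokens in
# input_ids (600 generated so far), the next chunk is capped at
# min(256, 1024 - 900 + 300) = 256; once 900 tokens have been generated the
# cap shrinks to min(256, 1024 - 1200 + 300) = 124.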
@ -130,6 +135,7 @@ class PartialRolloutManager:
group_idx=group_idx,
raw_gconfig=raw_gconfig,
server_url=url,
version=cur_server_version,
),
),
stream=False,
@ -190,6 +196,7 @@ class PartialRolloutManager:
s: APIGenerateOutput = await task
group_idx = s.metadata["group_idx"]
raw_gconfig = s.metadata["raw_gconfig"]
previous_version = s.metadata["version"]
assert s.group_size == 1
no_eos = s.no_eos[0]
@ -202,20 +209,27 @@ class PartialRolloutManager:
if no_eos and gen_len < raw_gconfig.max_new_tokens:
# Unfinished request due to chunked generation.
# Send it back to continue.
async with aiohttp.ClientSession() as session:
async with session.post(
f"http://{self.gserver_manager_addr}/get_model_version",
json=dict(server_url=s.metadata["server_url"]),
timeout=ClientTimeout(total=self.timeout, sock_connect=30),
) as resp:
resp.raise_for_status()
cur_version = (await resp.json())["version"]
req_meta = GenReqMeta(
qid=s.qid,
prompt_len=s.prompt_len,
group_size=raw_gconfig.n,
new_token_budget=raw_gconfig.max_new_tokens,
predicted_new_tokens=None,
previous_server_url=s.metadata["server_url"],
previous_version=previous_version,
)
info = await self._schedule_request(req_meta)
cur_version = info["version"]
server_url = info["url"]
if len(s.output_logprobs) > 0:
prev_logprobs = s.prev_logprobs + s.output_logprobs[0]
else:
prev_logprobs = s.prev_logprobs
if prev_logprobs is None:
prev_logprobs = []
await self._issue_generation(
s.metadata["server_url"],
server_url,
s.qid,
group_idx,
s.prompt_ids,
@ -240,9 +254,10 @@ class PartialRolloutManager:
try:
qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
req_meta = GenReqMeta(
qid=qid,
prompt_len=len(prompt_token_ids),
group_size=gconfig.n,
new_token_budget=self.new_tokens_per_chunk,
new_token_budget=gconfig.max_new_tokens,
predicted_new_tokens=None,
)
dst_server_info = await self._schedule_request(req_meta)

View File

@ -171,7 +171,9 @@ class NameResolvingZmqPuller(ZMQJsonPuller):
name = names.push_pull_stream(
experiment_name, trial_name, stream_name=f"puller{puller_index}"
)
host, port = network.gethostip(), network.find_free_port()
host, port = network.gethostip(), network.find_free_port(
experiment_name=experiment_name, trial_name=trial_name
)
addr = f"{host}:{port}"
name_resolve.add(name, addr)
super().__init__(host, port, **kwargs)

View File

@ -189,10 +189,11 @@ class RolloutWorker(AsyncWorker):
assert data_id not in self.rollout_tasks
return cur_sample
async def allocate_new_rollout(self) -> bool:
async def allocate_new_rollout(self, qid) -> bool:
async with aiohttp.ClientSession() as session:
async with session.get(
async with session.post(
f"http://{self.gserver_manager_addr}/allocate_rollout",
json=dict(qid=qid),
timeout=ClientTimeout(
total=self.config.rollout_request_timeout, sock_connect=30
),
@ -231,10 +232,10 @@ class RolloutWorker(AsyncWorker):
self._cur_data = self.load_next_data()
if self._cur_data is not None:
can_rollout = await self.allocate_new_rollout()
if can_rollout:
data = self._cur_data
qid = data.ids[0]
can_rollout = await self.allocate_new_rollout(qid)
if can_rollout:
self.act_queues[qid] = asyncio.Queue(1024)
task = asyncio.create_task(self.rollout_task(qid, data))
@ -265,7 +266,11 @@ class RolloutWorker(AsyncWorker):
accepted = True
self.push_stream.push([traj.as_json_compatible() for traj in trajs])
info = dict(qid=qid, accepted=accepted)
n_tokens = 0
for traj in trajs:
seqlens = [sum(datapack.flat2d(ss)) for ss in traj.seqlens.values()]
n_tokens += max(seqlens)
info = dict(qid=qid, accepted=accepted, n_tokens=n_tokens)
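# Token counting sketch (illustrative keys): for a trajectory whose seqlens
# are {"packed_input_ids": [[512, 480]], "rewards": [[2]]}, flat2d + sum
# gives [992, 2] and the max (992) is taken as that trajectory's token count,
# so only the packed-token entry contributes to the reported throughput.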
async with aiohttp.ClientSession(
f"http://{self.gserver_manager_addr}"
) as session:

View File

@ -36,7 +36,6 @@ ray
redis
scipy
seaborn
setuptools>=61.0
tqdm
networkx==3.3
matplotlib
@ -59,3 +58,5 @@ protobuf<3.21
rich
orjson>=3.10.16
flask
setuptools>=62.3.0,<75.9
func_timeout

View File

@ -246,30 +246,6 @@ if not no_ext and _is_cuda():
ext_modules.append(interval_op_cuda)
if not no_ext:
search_extension = setuptools.Extension(
name="realhf._C.mdm_search",
sources=[
"csrc/search/search.cpp",
"csrc/search/rpc.cpp",
"csrc/search/device_mesh.cpp",
"csrc/search/simulate.cpp",
],
language="c++",
extra_compile_args=[
"-O3",
"-Wall",
"-shared",
"-std=c++11",
"-fPIC",
"-std=c++17",
],
include_dirs=[
os.path.join(os.path.abspath(os.path.dirname(__file__)), "csrc", "search"),
get_pybind11_include_path(),
],
)
ext_modules.append(search_extension)
interval_extension = setuptools.Extension(
name="realhf._C.interval_op",
sources=[
