mirror of https://github.com/inclusionAI/AReaL
Support asynchronous RL training, Qwen3, and the latest SGLang (#47)
* feat: one buffer for each task
* feat: support "one buffer for each task" for async
* make kv_cache_dtype configurable
  Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>
* style: use plural form
  fix: use _seed_from_key to set different seeds for data loaders
  fix: call load_data for one buffer each time
* PullRequest: 125 Support running async experiments in the 2407 image.
  Merge branch fw/async2407 of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/125
  Signed-off-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* fix: handle multiple datasets in recover indices
  fix: `isinstance(self.__datasets, PullerStreamDataset)`
  feat: use the "spec" request to obtain the number of datasets
  fix: revert rollout worker
* fix: revert async_rl_exp.py
* fix flag for list (cuda_graph_bs)
* format
* [FIX] fix async task reward [sglang bf16 -> fp16]
* fix: define `self.__datasets` in advance
* PullRequest: 130 [Refactor] Remove deprecated search related code
  Merge branch mzy/remove-search of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/130
  Signed-off-by: 博惟 <bowei.fw@antgroup.com>
* remove search related
* PullRequest: 131 [Refactor] Change terminology "model parallel" into "tensor parallel" to align with megatron.
  Merge branch mzy/mp-to-tp of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/131?tab=comment
  Signed-off-by: 博惟 <bowei.fw@antgroup.com>
* change mp to tp
* .
* .
* PullRequest: 142 Fix an error for megatron backend destroy
  Merge branch fw/fix-meagatron-destroy of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/142
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* PullRequest: 143 Fix the port conflict issue of generation servers
  Merge branch fw/fix-gen-port of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/143?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* somehow fix the port issue
* add clearance period
* .
* .
* PullRequest: 145 Add code environment
  Merge branch fw/code-env of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/145?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* add code env
* somehow fix the port issue
* fix
* PullRequest: 144 Add decoupled PPO loss
  Merge branch fw/decoupled-ppo-loss of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/144?tab=comment
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* fix ppo step logging, nan in stats tracker, and add decoupled loss
* .
* somehow fix the port issue
* fix typo
* PullRequest: 146 Merge SLURM logs and save experiment configs in yaml format.
  Merge branch fw/better-logging of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/146
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* merge all slurm logs into one
* write config to yaml
* PullRequest: 141 Merge changes during NeurIPS submission
  Merge branch fw/async-dev of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/141
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* .
* .
* .
* .
* .
* .
* .
* .
* .
* update script
* .
* .
* .
* .
* [ADD] add least req scheduling
* fix test genreq
* .
* .
* fix stats tracker nan
* .
* .
* .
* .
* .
* .
* .
* upper clip decoupled objective
* add throughput exp script
* .
* remove behav upper clip param
* .
* .
* .
* plot curve
* update thpt script
* .
* master worker raise error when exiting
* update script
* add gen throughput logging
* .
* .
* add decoupled wandb data
* .
* fix port issue and add no training option
* .
* enlarge ttl
* remove gserver manager await staled
* update weights in groups
* .
* .
* .
* add port clearance period
* .
* .
* .
* add plot script
* add sft throughput eval
* .
* log tokens in null interface
* Ablation experiments and interruptible generation
* Plotting scripts / run scripts / data results
* .
* remove scripts
* add port test
* remove force_sync_reward
* revert some changes
* .
* revert
* revert fix
* fix
* revert
* fix typo
* support qwen3 training
* PullRequest: 147 Support interruption in SGLang and fix a KeyError in gather-scatter communication
  Merge branch fw/sglang046-with-abort-request of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/147?tab=diff
  Reviewed-by: 晓雷 <meizhiyu.mzy@antgroup.com>
* fix ppo step logging, nan in stats tracker, and add decoupled loss
* .
* somehow fix the port issue
* initial commit
* add interrupt request
* fix data transfer issue
* max concurrent rollouts defaults to train batch size
* merge main
* add patch
* fix patch typo
* revert sglang
* fix typo
* fix minor typo
* .
* pip show editable sglang path
* PullRequest: 149 fix: code faas max_retries
  Merge branch xss/fix_code_verifier of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/149
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* fix: code faas max_retries
* PullRequest: 150 [Bug Fix] Fix key errors in `_run_scatter` in data transfer
  Merge branch mzy/fix-scatter-groups of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/150
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* fix scatter groups key error
* fix test
* .
* PullRequest: 151 Fix Qwen3 import error when using transformers with a lower version
  Merge branch fw/fix-qwen3 of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/151
  Reviewed-by: 温差 <xushusheng.xss@antgroup.com>
* merge all slurm logs into one
* write config to yaml
* .
* PullRequest: 152 Support sglang0.4.6 and fix master_worker import error
  Merge branch adopt_sglang046 of git@code.alipay.com:inclusionAI/AReaL.git into main
  https://code.alipay.com/inclusionAI/AReaL/pull_requests/152
  Reviewed-by: 博惟 <bowei.fw@antgroup.com>
* Support sglang0.4.6 and fix master_worker import error
* remove disable_mla option

---------

Signed-off-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: wanghuaijie.whj <wanghuaijie.whj@antgroup.com>
Co-authored-by: Tiwei Bie <tiwei.btw@antgroup.com>
Co-authored-by: kira.gw <kira.gw@antgroup.com>
Co-authored-by: shenxujie.sxj <shenxujie.sxj@antgroup.com>
Co-authored-by: 晓雷 <meizhiyu.mzy@antgroup.com>
Co-authored-by: sam.gjx <sam.gjx@antgroup.com>
Co-authored-by: 温差 <xushusheng.xss@antgroup.com>
Co-authored-by: 履渊 <yuhong.gyh@antgroup.com>
Parent: f7ab31d050
Commit: c60d128b14
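Several items in the commit message refer to a "decoupled PPO loss" with an upper clip on the behavior ratio. For orientation only (the exact variant implemented by these commits is not reproduced in this diff, which only removes the legacy search code), a commonly used decoupled PPO surrogate separates the behavior policy that produced the rollouts from the proximal policy that anchors the trust region:

J(\theta) = \mathbb{E}_{(s,a)\sim\pi_{\mathrm{behav}}}\!\left[ \frac{\pi_{\mathrm{prox}}(a\mid s)}{\pi_{\mathrm{behav}}(a\mid s)} \,\min\!\Big( r_\theta(s,a)\,\hat{A}(s,a),\ \mathrm{clip}\big(r_\theta(s,a),\,1-\epsilon,\,1+\epsilon\big)\,\hat{A}(s,a) \Big) \right], \qquad r_\theta(s,a) = \frac{\pi_\theta(a\mid s)}{\pi_{\mathrm{prox}}(a\mid s)}.

With synchronous on-policy training the behavior and proximal policies coincide and the expression reduces to the standard clipped PPO objective; it is the asynchronous generation described in the commits above that makes the distinction matter.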
@@ -1,87 +0,0 @@
#include <device_mesh.hpp>
#include <cassert>
#include <iostream>

// DeviceMesh::DeviceMesh()
//     : device_mesh_name(""), n_nodes(0), n_gpus(0), node_names({}), gpu_ids({}) {
// };

DeviceMesh::DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector<std::vector<int>> mapping,
                       std::string global_mesh_name, std::string name)
    : n_nodes(n_nodes),
      n_gpus_per_node(n_gpus_per_node),
      mapping(mapping),
      global_mesh_name(global_mesh_name),
      name(name) {
  assert(n_nodes == static_cast<int>(mapping.size()));
  for (int i = 0; i < n_nodes; i++) {
    assert(n_gpus_per_node == static_cast<int>(mapping[i].size()));
  }
};

bool is_all_overlap(std::vector<DeviceMesh *> device_meshes, DeviceMesh device_mesh) {
  for (DeviceMesh *other : device_meshes) {
    if (!device_mesh.overlap(*other)) return false;
  }
  return true;
};

bool is_all_overlap(std::unordered_set<DeviceMesh *> device_meshes, DeviceMesh device_mesh) {
  for (DeviceMesh *other : device_meshes) {
    if (!device_mesh.overlap(*other)) return false;
  }
  return true;
};

bool DeviceMesh::contain(const DeviceMesh &other) {
  // check whether one device mapping is contained by another by
  // checking 1. whether global_mesh_name is identical
  //          2. whether mapping of one device mesh is contained by the other one
  if (global_mesh_name != other.global_mesh_name) return false;
  for (int i = 0; i < n_nodes; i++) {
    for (int j = 0; j < n_gpus_per_node; j++) {
      if (mapping[i][j] == 0 && other.mapping[i][j] == 1) return false;
    }
  }
  return true;
};

bool DeviceMesh::contained_by(const DeviceMesh &other) {
  if (global_mesh_name != other.global_mesh_name) return false;
  for (int i = 0; i < n_nodes; i++) {
    for (int j = 0; j < n_gpus_per_node; j++) {
      if (mapping[i][j] == 1 && other.mapping[i][j] == 0) return false;
    }
  }
  return true;
};

bool DeviceMesh::overlap(const DeviceMesh &other) {
  if (global_mesh_name != other.global_mesh_name) return false;
  for (int i = 0; i < n_nodes; i++) {
    for (int j = 0; j < n_gpus_per_node; j++) {
      if (mapping[i][j] == 1 && other.mapping[i][j] == 1) return true;
    }
  }
  return false;
};

ModelParallelStrategy::ModelParallelStrategy(int num_pp, int num_dp, int num_mp)
    : num_pp(num_pp), num_dp(num_dp), num_mp(num_mp) {};

bool ModelParallelStrategy::operator==(const ModelParallelStrategy &other) const {
  return num_pp == other.num_pp && num_dp == other.num_dp && num_mp == other.num_mp;
};

bool DeviceMesh::operator==(const DeviceMesh &other) const {
  return name == other.name && global_mesh_name == other.global_mesh_name;
};

std::string ModelParallelStrategy::to_string() {
  return "num_pp:" + std::to_string(num_pp) + ";" + "num_dp:" + std::to_string(num_dp) + ";"
         + "num_mp:" + std::to_string(num_mp);
};

std::string ModelParallelStrategy::to_key() {
  return std::to_string(num_pp) + "," + std::to_string(num_mp) + "," + std::to_string(num_dp);
}
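A minimal sketch, not part of the original sources, of how the deleted DeviceMesh API composed. It assumes the device_mesh.hpp header shown in the next hunk is still on the include path; the mesh names and 0/1 mappings are made up for illustration.

#include <cassert>
#include <device_mesh.hpp>

int main() {
  // One node with 8 GPUs: mesh `a` occupies GPUs 0-3, mesh `b` occupies GPUs 2-7.
  DeviceMesh a(1, 8, {{1, 1, 1, 1, 0, 0, 0, 0}}, "cluster", "a");
  DeviceMesh b(1, 8, {{0, 0, 1, 1, 1, 1, 1, 1}}, "cluster", "b");
  DeviceMesh full(1, 8, {{1, 1, 1, 1, 1, 1, 1, 1}}, "cluster", "full");
  assert(a.overlap(b));            // GPUs 2 and 3 are shared
  assert(full.contain(a));         // `full` covers every GPU that `a` uses
  assert(a.contained_by(full));
  assert(!a.contain(b));           // `a` lacks GPUs 4-7
  return 0;
}

The `is_all_overlap` helpers in the same file apply this `overlap` check across a whole set of meshes before an execution is merged into an OverlapGroup (see rpc.cpp further down).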
@@ -1,49 +0,0 @@
#ifndef DEVICE_MESH_HPP
#define DEVICE_MESH_HPP

#include <string>
#include <vector>
#include <unordered_map>
#include <unordered_set>
// #include <rpc.hpp>

class RPCInstance;

class DeviceMesh {
 public:
  int n_nodes;
  int n_gpus_per_node;
  std::vector<std::vector<int>> mapping;
  std::string global_mesh_name;
  std::string name;
  RPCInstance *pre_task = nullptr;

  // DeviceMesh();
  DeviceMesh(int n_nodes, int n_gpus_per_node, std::vector<std::vector<int>> mapping,
             std::string global_mesh_name, std::string name);

  bool overlap(const DeviceMesh &other);
  bool contain(const DeviceMesh &other);
  bool contained_by(const DeviceMesh &other);

  bool operator==(const DeviceMesh &other) const;
};

bool is_all_overlap(std::vector<DeviceMesh *> device_meshes, DeviceMesh device_mesh);
bool is_all_overlap(std::unordered_set<DeviceMesh *> device_meshes, DeviceMesh device_mesh);

class ModelParallelStrategy {
 public:
  int num_pp, num_dp, num_mp;

  ModelParallelStrategy(int num_pp, int num_dp, int num_mp);

  bool operator==(const ModelParallelStrategy &other) const;

  std::string to_string();
  std::string to_key();
};

class ModelDeviceMapping {};

#endif  // DEVICE_MESH_HPP
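An illustrative sketch (not from the commit) of the two serializations declared above. Note that to_string() reports the fields as pp/dp/mp while to_key() emits pp,mp,dp, which is the order the cost-table keys in rpc.cpp below are built from. The include path for device_mesh.hpp (the header removed by this commit) is assumed.

#include <iostream>
#include <device_mesh.hpp>

int main() {
  ModelParallelStrategy s(/*num_pp=*/2, /*num_dp=*/4, /*num_mp=*/1);
  std::cout << s.to_string() << "\n";  // prints: num_pp:2;num_dp:4;num_mp:1
  std::cout << s.to_key() << "\n";     // prints: 2,1,4  (pp, mp, dp)
  return 0;
}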
@@ -1,233 +0,0 @@
#include <rpc.hpp>
#include <device_mesh.hpp>
#include <numeric>
#include <algorithm>
#include <iostream>
#include <iomanip>

RPC::RPC(std::string model_name, std::string rpc_name, std::string interface_type)
    : model_name(model_name), rpc_name(rpc_name), interface_type(interface_type) {};

RPCExecution::RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh,
                           ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost,
                           uint64_t mem, uint64_t static_mem)
    : rpc_ptr(rpc_ptr),
      device_mesh(device_mesh),
      model_parallel_strategy(model_parallel_strategy),
      time_cost(time_cost),
      mem(mem),
      static_mem(static_mem) {};

bool OverlapGroup::maybe_add(RPCExecution *rpc_exe) {
  if (rpc_executions.empty()) {
    rpc_executions.insert(rpc_exe);
    device_meshes.insert(&rpc_exe->device_mesh);
    mem_static = rpc_exe->static_mem;
    mem_active = rpc_exe->mem - rpc_exe->static_mem;
    return true;
  }
  if (is_all_overlap(device_meshes, rpc_exe->device_mesh)) {
    rpc_executions.insert(rpc_exe);
    // bool dm_in_group = device_meshes.find(&rpc_exe -> device_mesh) != device_meshes.end();
    device_meshes.insert(&rpc_exe->device_mesh);
    mem_static += rpc_exe->static_mem;
    mem_active = std::max(mem_active, rpc_exe->mem - rpc_exe->static_mem);
    return true;
  }
  return false;
};

void DeviceMeshGroup::add_to_groups(RPCExecution *rpc_exe) {
  if (overlap_groups.empty()) {
    OverlapGroup *og = new OverlapGroup();
    og->maybe_add(rpc_exe);
    overlap_groups.push_back(og);
    return;
  }

  std::vector<OverlapGroup *> tmp_new_ogs;
  for (OverlapGroup *og : overlap_groups) {
    // OverlapGroup og_copy = *og;
    bool update = og->maybe_add(rpc_exe);
    if (!update) {
      tmp_new_ogs.push_back(new OverlapGroup());
      tmp_new_ogs.back()->maybe_add(rpc_exe);
      for (RPCExecution *rpc_exe : og->rpc_executions) { tmp_new_ogs.back()->maybe_add(rpc_exe); }
    }
    tmp_new_ogs.push_back(og);
  }
  overlap_groups.clear();
  for (OverlapGroup *og : tmp_new_ogs) { overlap_groups.push_back(og); }
};

void GroupedRPCExecutions::resolve(RPCExecution *rpc_exe) {}

void GroupedRPCExecutions::add(RPCExecution *rpc_exe) { group.add_to_groups(rpc_exe); };

void GroupedRPCExecutions::offload(std::string model_name) {};

uint64_t GroupedRPCExecutions::total_mem_cost() {
  uint64_t max_mem = 0;
  for (auto &og : group.overlap_groups) {
    // double og_ma = (og -> mem_active/(1024*1024))/1024.0;
    // double og_ms = (og -> mem_static/(1024*1024))/1024.0;
    // std::cout << "og size " << og -> rpc_executions.size()
    //           << " mem active " << og_ma << " GB"
    //           << " mem static " << og_ms << " GB" << std::endl;
    if (og->mem_active + og->mem_static > max_mem) { max_mem = og->mem_active + og->mem_static; }
  }
  return max_mem;
};

RPCInstance::RPCInstance(RPC *rpc_ptr, int id, std::string name)
    : rpc_ptr(rpc_ptr), id(id), name(name) {};

void RPCInstance::remove_parent(RPCInstance *parent) {
  auto it = std::find(parents.begin(), parents.end(), parent);
  if (it != parents.end()) { parents.erase(it); }
};

void RPCInstance::remove_child(RPCInstance *child) {
  auto it = std::find(children.begin(), children.end(), child);
  if (it != children.end()) { children.erase(it); }
};

void RPCInstance::add_parent(RPCInstance *parent) { parents.push_back(parent); };

void RPCInstance::add_child(RPCInstance *child) { children.push_back(child); };

void RPCInstance::remove_tmp_child(RPCInstance *child) {
  auto it = std::find(tmp_children.begin(), tmp_children.end(), child);
  if (it != tmp_children.end()) { tmp_children.erase(it); }
};

void RPCInstance::remove_tmp_parent(RPCInstance *parent) {
  auto it = std::find(tmp_parents.begin(), tmp_parents.end(), parent);
  if (it != tmp_parents.end()) { tmp_parents.erase(it); }
};

void RPCInstance::add_tmp_parent(RPCInstance *parent) { tmp_parents.push_back(parent); };

void RPCInstance::add_tmp_child(RPCInstance *child) { tmp_children.push_back(child); };

uint64_t parameter_sync_cost(uint64_t model_size, RPCExecution *src, RPCExecution *dst,
                             std::unordered_map<std::string, uint64_t> &cost_table) {
  // 7b size 13738442752 Bytes
  // double size_multiplier = double(model_size) / 13738442752.0;
  std::string model_key = std::to_string(model_size);
  std::string src_key = src->model_parallel_strategy.to_key();
  std::string dst_key = dst->model_parallel_strategy.to_key();
  if (src_key == dst_key) return 0;
  std::string key = model_key + "," + src_key + "," + dst_key;
  if (cost_table.find(key) == cost_table.end()) {
    // std::cout << "key " << key << " not found" << std::endl;
    return 0;
  }
  return cost_table[key];
}

void RPCInstance::resolve_parameter_sync(std::vector<RPCInstance *> tmp_graph,
                                         std::unordered_map<std::string, uint64_t> &cost_table) {
  // add parameter synchronization edges
  if (!param_sync) return;

  // dst to train
  uint64_t from_cost =
      parameter_sync_cost(param_sync_size, param_sync_rpc_exe_ptr, rpc_exe_ptr, cost_table);
  uint64_t to_cost =
      parameter_sync_cost(param_sync_size, rpc_exe_ptr, param_sync_rpc_exe_ptr, cost_table);
  // if (param_sync_cost > 0)
  //   std::cout << "Param sync cost " << param_sync_cost << " from "
  //             << param_sync_rpc_exe_ptr -> rpc_ptr -> rpc_name << " to "
  //             << rpc_exe_ptr -> rpc_ptr -> rpc_name << std::endl;

  // add param sync from src to dst
  RPCExecution *from_src_exe =
      new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh,
                       param_sync_rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0);
  RPCInstance *from_src = new RPCInstance(rpc_ptr, id, name + ":from_src");
  from_src->rpc_exe_ptr = from_src_exe;

  RPCExecution *from_dst_exe = new RPCExecution(
      rpc_ptr, rpc_exe_ptr->device_mesh, rpc_exe_ptr->model_parallel_strategy, from_cost, 0, 0);
  RPCInstance *from_dst = new RPCInstance(rpc_ptr, id, name + ":from_dst");
  from_dst->rpc_exe_ptr = from_dst_exe;

  // bool overlap = src -> rpc_exe_ptr -> device_mesh.overlap(
  //     dst -> rpc_exe_ptr -> device_mesh);

  for (RPCInstance *parent : parents) {
    parent->remove_tmp_child(this);
    from_src->add_tmp_parent(parent);
    from_dst->add_tmp_parent(parent);
    parent->add_tmp_child(from_src);
    parent->add_tmp_child(from_dst);
  }
  this->tmp_parents.clear();

  from_src->add_tmp_child(this);
  from_dst->add_tmp_child(this);
  this->add_tmp_parent(from_src);
  this->add_tmp_parent(from_dst);

  tmp_graph.push_back(from_src);
  tmp_graph.push_back(from_dst);

  tmp_ris.push_back(from_src);
  tmp_ris.push_back(from_dst);
  tmp_exes.push_back(from_src_exe);
  tmp_exes.push_back(from_dst_exe);

  // add param sync from dst to src
  RPCExecution *to_src_exe = new RPCExecution(rpc_ptr, rpc_exe_ptr->device_mesh,
                                              rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0);
  RPCInstance *to_src = new RPCInstance(rpc_ptr, id, name + ":to_src");
  to_src->rpc_exe_ptr = to_src_exe;

  RPCExecution *to_dst_exe =
      new RPCExecution(rpc_ptr, param_sync_rpc_exe_ptr->device_mesh,
                       param_sync_rpc_exe_ptr->model_parallel_strategy, to_cost, 0, 0);
  RPCInstance *to_dst = new RPCInstance(rpc_ptr, id, name + ":to_dst");
  to_dst->rpc_exe_ptr = to_dst_exe;

  for (RPCInstance *child : children) {
    child->remove_tmp_parent(this);
    to_src->add_tmp_child(child);
    to_dst->add_tmp_child(child);
    child->add_tmp_parent(to_src);
    child->add_tmp_parent(to_dst);
  }
  this->tmp_children.clear();

  to_src->add_tmp_parent(this);
  to_dst->add_tmp_parent(this);
  this->add_tmp_child(to_src);
  this->add_tmp_child(to_dst);

  tmp_graph.push_back(to_src);
  tmp_graph.push_back(to_dst);

  tmp_ris.push_back(to_src);
  tmp_ris.push_back(to_dst);
  tmp_exes.push_back(to_src_exe);
  tmp_exes.push_back(to_dst_exe);
}

CommStats::CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send,
                     uint64_t remote_recv, uint64_t offload_store, uint64_t offload_load)
    : local_send(local_send),
      local_recv(local_recv),
      remote_send(remote_send),
      remote_recv(remote_recv),
      offload_store(offload_store),
      offload_load(offload_load) {};

// ModelConfig::ModelConfig(std::string model_name,
//                          uint64_t param_size_bytes)
//     : model_name(model_name), param_size_bytes(param_size_bytes) {
// };

std::string RPCExecution::to_string() {
  return rpc_ptr->rpc_name + " on " + device_mesh.name
         + ", parallel strategy: " + model_parallel_strategy.to_string();
};
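A hedged sketch of the cost-table convention used by parameter_sync_cost() above: the lookup key concatenates the model size in bytes with the source and destination ModelParallelStrategy::to_key() strings, and both unknown keys and identical source/destination strategies cost zero. The table entry and cost value below are invented for illustration; only the key layout and the 7B byte count come from the code.

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, uint64_t> cost_table;
  // key layout: "model_size,src_pp,src_mp,src_dp,dst_pp,dst_mp,dst_dp"
  cost_table["13738442752,2,1,4,1,4,2"] = 1500;  // 7B model resharded from pp2/mp1/dp4 to pp1/mp4/dp2
  std::string key = "13738442752,2,1,4,1,4,2";
  uint64_t cost = cost_table.count(key) ? cost_table.at(key) : 0;  // missing keys fall back to 0
  std::cout << cost << std::endl;
  return 0;
}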
@@ -1,121 +0,0 @@
#ifndef RPC_HPP
#define RPC_HPP

#include <string>
#include <vector>
#include <unordered_map>
#include <device_mesh.hpp>

class CommStats {
 public:
  uint64_t local_send, local_recv, remote_send, remote_recv, offload_store, offload_load;

  CommStats(uint64_t local_send, uint64_t local_recv, uint64_t remote_send, uint64_t remote_recv,
            uint64_t offload_store, uint64_t offload_load);
};

class RPC {
 public:
  std::string model_name;
  std::string rpc_name;
  // interface_type: 0=generate, 1=train_step, 2=inference
  std::string interface_type;

  RPC(std::string model_name, std::string rpc_name, std::string interface_type);
};

class RPCExecution {
 public:
  RPC *rpc_ptr;
  DeviceMesh &device_mesh;
  ModelParallelStrategy &model_parallel_strategy;
  uint64_t time_cost, mem, static_mem;

  RPCExecution(RPC *rpc_ptr, DeviceMesh &device_mesh,
               ModelParallelStrategy &model_parallel_strategy, uint64_t time_cost, uint64_t mem,
               uint64_t static_mem);

  std::string to_string();
};

class OverlapGroup {
 public:
  std::unordered_set<RPCExecution *> rpc_executions;
  std::unordered_set<DeviceMesh *> device_meshes;
  uint64_t mem_static;
  uint64_t mem_active;

  bool maybe_add(RPCExecution *rpc_exe);
};

class DeviceMeshGroup {
 public:
  // std::string device_mesh_name;
  std::vector<OverlapGroup *> overlap_groups;

  void add_to_groups(RPCExecution *rpc_exe);
};

class GroupedRPCExecutions {
 public:
  // std::unordered_map<std::string, DeviceMeshGroup> dn_to_group;
  DeviceMeshGroup group;

  void add(RPCExecution *rpc_exe);
  void resolve(RPCExecution *rpc_exe);
  void offload(std::string model_name);
  uint64_t total_mem_cost();
};

class RPCInstance {
 public:
  RPC *rpc_ptr;
  int id;
  std::string name;
  std::vector<RPCInstance *> children;
  std::vector<RPCInstance *> parents;
  std::vector<RPCInstance *> tmp_children;
  std::vector<RPCInstance *> tmp_parents;
  std::vector<RPCInstance *> tmp_ris;    // pointers to tmp rpc instances
  std::vector<RPCExecution *> tmp_exes;  // pointers to tmp rpc executions

  RPCExecution *rpc_exe_ptr = nullptr;
  RPCExecution *param_sync_rpc_exe_ptr = nullptr;
  bool param_sync = false;
  uint64_t param_sync_size = 0;
  bool offload = false;
  uint64_t offload_size = 0;

  RPCInstance(RPC *rpc_ptr, int id, std::string name);

  uint64_t ready_time = 0, start_time = 0, end_time = 0;

  void remove_parent(RPCInstance *parent);
  void remove_child(RPCInstance *child);
  void add_parent(RPCInstance *parent);
  void add_child(RPCInstance *child);

  void add_tmp_parent(RPCInstance *parent);
  void add_tmp_child(RPCInstance *child);
  void remove_tmp_parent(RPCInstance *parent);
  void remove_tmp_child(RPCInstance *child);

  void resolve_parameter_sync(std::vector<RPCInstance *> tmp_graph,
                              std::unordered_map<std::string, uint64_t> &cost_table);
  // void resolve_offload(std::vector<RPCInstance*> tmp_graph,
  //                      CommStats& comm_stats);
};

uint64_t parameter_sync_cost(uint64_t param_size_bytes, RPCExecution *src, RPCExecution *dst,
                             std::unordered_map<std::string, uint64_t> &cost_table);

uint64_t remote_param_sync_size(uint64_t size, RPCExecution *src, RPCExecution *dst);

// class ModelConfig {
//   std::string model_name;
//   uint64_t param_size_bytes;

//   ModelConfig(std::string model_name, uint64_t param_size_bytes);
// };

#endif
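A minimal sketch (not part of the original sources) of how the classes above account for memory: within one OverlapGroup, static memory accumulates across co-located executions while activation memory takes the maximum, and GroupedRPCExecutions::total_mem_cost() reports the heaviest group. It assumes the rpc.hpp and device_mesh.hpp headers shown in this diff are on the include path; the model names and numbers are made up.

#include <iostream>
#include <device_mesh.hpp>
#include <rpc.hpp>

int main() {
  DeviceMesh mesh(1, 8, {{1, 1, 1, 1, 1, 1, 1, 1}}, "cluster", "node0");
  ModelParallelStrategy ps(/*num_pp=*/1, /*num_dp=*/8, /*num_mp=*/1);
  RPC actor("actor", "actor_train", "train_step");
  RPC ref("ref", "ref_inf", "inference");
  // mem and static_mem are opaque units; static_mem persists, the remainder is activation memory
  RPCExecution train(&actor, mesh, ps, /*time_cost=*/100, /*mem=*/50, /*static_mem=*/20);
  RPCExecution infer(&ref, mesh, ps, /*time_cost=*/40, /*mem=*/60, /*static_mem=*/10);
  GroupedRPCExecutions grouped;
  grouped.add(&train);
  grouped.add(&infer);  // same mesh, so both land in one overlap group
  std::cout << grouped.total_mem_cost() << std::endl;  // (20 + 10) + max(30, 50) = 80
  return 0;
}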
@ -1,827 +0,0 @@
|
|||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <algorithm>
|
||||
#include <rpc.hpp>
|
||||
#include <device_mesh.hpp>
|
||||
#include <simulate.hpp>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/numpy.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <chrono>
|
||||
#include <limits>
|
||||
#include <fstream>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <functional>
|
||||
#include <thread>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
uint64_t VALID_COUNT_CAP = 25000000; // 25000000
|
||||
size_t MAX_EXE_PER_RPC = 1000;
|
||||
// std::unordered_map<std::string, DeviceMesh*> device_mesh_map;
|
||||
|
||||
void print_int_vector(std::vector<int> &vec) {
|
||||
std::cout << "[";
|
||||
for (int i = 0; i < static_cast<int>(vec.size()); i++) { std::cout << vec[i] << ", "; }
|
||||
std::cout << "] ";
|
||||
};
|
||||
|
||||
std::size_t vector_hash(const std::vector<int> &vec) {
|
||||
std::size_t seed = vec.size();
|
||||
for (const auto &i : vec) {
|
||||
seed ^= std::hash<int>{}(i) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
|
||||
long check_memory_bytes() {
|
||||
std::ifstream statm_file("/proc/self/statm");
|
||||
long memory_usage_bytes = -1;
|
||||
if (!statm_file) {
|
||||
std::cerr << "Failed to open /proc/self/statm\n";
|
||||
} else {
|
||||
long size, resident, share, text, lib, data, dt;
|
||||
statm_file >> size >> resident >> share >> text >> lib >> data >> dt;
|
||||
|
||||
// size is in pages, to convert to bytes, multiply by the page size
|
||||
long page_size = sysconf(_SC_PAGESIZE);
|
||||
memory_usage_bytes = size * page_size;
|
||||
}
|
||||
return memory_usage_bytes;
|
||||
};
|
||||
|
||||
void make_rpc_exe_table(std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
|
||||
std::vector<RPCExecution *> &rpc_exes) {
|
||||
for (auto &rpc_exe : rpc_exes) {
|
||||
if (rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].size() < MAX_EXE_PER_RPC)
|
||||
rpc_exe_table[rpc_exe->rpc_ptr->rpc_name].push_back(rpc_exe);
|
||||
}
|
||||
|
||||
for (auto &x : rpc_exe_table) {
|
||||
std::string rpc_name = x.first;
|
||||
std::vector<RPCExecution *> &rpc_exe_list = x.second;
|
||||
// sort first
|
||||
std::sort(rpc_exe_list.begin(), rpc_exe_list.end(),
|
||||
[](const RPCExecution *a, const RPCExecution *b) {
|
||||
if (a->time_cost == b->time_cost)
|
||||
return a->device_mesh.name < b->device_mesh.name;
|
||||
else
|
||||
return a->time_cost < b->time_cost;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void make_sorted_rpc_names(
|
||||
std::vector<std::string> &sorted_rpc_names,
|
||||
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table) {
|
||||
std::vector<std::pair<std::string, uint64_t>> average_time_cost;
|
||||
for (auto &x : rpc_exe_table) {
|
||||
std::string rpc_name = x.first;
|
||||
std::vector<RPCExecution *> &rpc_exe_list = x.second;
|
||||
uint64_t total_time_cost = 0;
|
||||
int c = 0;
|
||||
for (auto &rpc_exe : rpc_exe_list) {
|
||||
total_time_cost += rpc_exe->time_cost;
|
||||
c += 1;
|
||||
if (c > 10) break;
|
||||
}
|
||||
average_time_cost.push_back(std::make_pair(rpc_name, total_time_cost / 10));
|
||||
}
|
||||
|
||||
std::sort(average_time_cost.begin(), average_time_cost.end(),
|
||||
[](const std::pair<std::string, float> &a, const std::pair<std::string, float> &b) {
|
||||
return a.second > b.second;
|
||||
});
|
||||
|
||||
for (auto &x : average_time_cost) { sorted_rpc_names.push_back(x.first); }
|
||||
}
|
||||
|
||||
void prepare(std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
|
||||
std::unordered_map<std::string, RPC *> &rpc_table,
|
||||
std::vector<std::string> &sorted_rpc_names,
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
|
||||
std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
|
||||
std::vector<RPCInstance *> graph) {
|
||||
std::vector<std::pair<std::string, uint64_t>> average_time_cost;
|
||||
for (auto &rpc : rpcs) { rpc_table[rpc->rpc_name] = rpc; }
|
||||
|
||||
make_rpc_exe_table(rpc_exe_table, rpc_exes);
|
||||
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
|
||||
|
||||
for (auto &rpc_instance : graph) {
|
||||
ri_table[rpc_instance->rpc_ptr->rpc_name].push_back(rpc_instance);
|
||||
}
|
||||
|
||||
for (auto &rpc_instance : graph) {
|
||||
model_name_ri_table[rpc_instance->rpc_ptr->model_name].push_back(rpc_instance);
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<SimulateResult> mcmc_search(std::vector<RPC *> rpcs,
|
||||
std::vector<RPCExecution *> rpc_exes,
|
||||
std::vector<RPCInstance *> graph,
|
||||
std::unordered_map<std::string, uint64_t> &cost_table,
|
||||
std::unordered_map<std::string, uint64_t> model_sizes,
|
||||
double beta, double time_limit,
|
||||
MinEndTimeQueue &top_k_queue) {
|
||||
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
|
||||
std::unordered_map<std::string, RPC *> rpc_table;
|
||||
std::vector<std::string> sorted_rpc_names;
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> ri_table;
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> model_name_ri_table;
|
||||
std::chrono::duration<double> time_limit_duration(time_limit);
|
||||
std::vector<SimulateResult> time_cost_cache;
|
||||
|
||||
prepare(rpc_exe_table, rpc_table, sorted_rpc_names, ri_table, model_name_ri_table, rpcs, rpc_exes,
|
||||
graph);
|
||||
|
||||
std::vector<int> index;
|
||||
std::vector<int> min_index;
|
||||
std::vector<int> max_index;
|
||||
uint64_t min_index_mem = 0;
|
||||
uint64_t valid_count = 0;
|
||||
uint64_t oom_count = 0;
|
||||
int num_rpcs = static_cast<int>(sorted_rpc_names.size());
|
||||
uint64_t min_time_cost = std::numeric_limits<uint64_t>::max();
|
||||
uint64_t max_time_cost = 0;
|
||||
double avg = 0;
|
||||
|
||||
// [6, 2, 3, 2, 12, 15, ]
|
||||
// [0, 0, 0, 7, 8, 13, ]
|
||||
// index = { 23, 7, 30, 154, 173, 265 };
|
||||
// SimulateResult sr1 = simulate(graph, cost_table, model_sizes,
|
||||
// rpc_table, rpc_exe_table, ri_table,
|
||||
// model_name_ri_table, sorted_rpc_names,
|
||||
// index);
|
||||
// std::cout << "index 1 end time " << sr1.end_time << " mem cost " << sr1.mem_cost << std::endl;
|
||||
// std::cout << "************************" << std::endl;
|
||||
// index = { 1, 0, 0, 11, 20, 3 };
|
||||
// SimulateResult sr2 = simulate(graph, cost_table, model_sizes,
|
||||
// rpc_table, rpc_exe_table, ri_table,
|
||||
// model_name_ri_table, sorted_rpc_names,
|
||||
// index);
|
||||
// std::cout << "index 2 end time " << sr2.end_time << " mem cost " << sr2.mem_cost << std::endl;
|
||||
// exit(0);
|
||||
|
||||
// initial value
|
||||
index.resize(sorted_rpc_names.size(), 0);
|
||||
|
||||
// index 1
|
||||
SimulateResult first_sr = simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table,
|
||||
ri_table, model_name_ri_table, sorted_rpc_names, index);
|
||||
SimulateResult final_sr = first_sr;
|
||||
uint64_t current_cost = first_sr.end_time;
|
||||
time_cost_cache.push_back(first_sr);
|
||||
|
||||
if (first_sr.oom)
|
||||
oom_count += 1;
|
||||
else
|
||||
top_k_queue.insert(first_sr);
|
||||
// std::cout << "initial cost " << current_cost << " oom "
|
||||
// << first_sr.oom << std::endl;
|
||||
// index = {0, 1, 3, 34, 4, 10};
|
||||
// std::unordered_map<std::size_t, uint64_t> time_cost_cache;
|
||||
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
// bool outer_loop_break_flag = false;
|
||||
while (valid_count < VALID_COUNT_CAP) {
|
||||
// only change one model execution in each iteration
|
||||
// std::vector<SimulateResult> sr_vector;
|
||||
std::unordered_map<int, std::pair<int, int>> flatten_to_pair;
|
||||
std::vector<double> weight;
|
||||
// double beta = 0.0075;
|
||||
int max_step_range = 10000;
|
||||
int current = 0;
|
||||
|
||||
std::vector<int> new_index(index);
|
||||
for (int i = 0; i < num_rpcs; i++) {
|
||||
std::string rpc_name = sorted_rpc_names[i];
|
||||
int c_i = index[i];
|
||||
int min_i = std::max(0, c_i - max_step_range);
|
||||
int max_i =
|
||||
std::min(static_cast<int>(rpc_exe_table[rpc_name].size()), c_i + max_step_range + 1);
|
||||
|
||||
for (int j = min_i; j < max_i; j++) {
|
||||
if (j == c_i) continue;
|
||||
// int tmp = new_index[i];
|
||||
// new_index[i] = j;
|
||||
// SimulateResult sr = simulate(graph, cost_table, model_sizes,
|
||||
// rpc_table, rpc_exe_table, ri_table,
|
||||
// model_name_ri_table, sorted_rpc_names,
|
||||
// new_index);
|
||||
// sr_vector.push_back(sr);
|
||||
// new_index[i] = tmp;
|
||||
flatten_to_pair[current] = std::make_pair(i, j);
|
||||
current++;
|
||||
}
|
||||
}
|
||||
|
||||
// if (time_cost_cache.size() > 10000000) {
|
||||
// time_cost_cache.clear();
|
||||
// }
|
||||
|
||||
// std::cout << "sr vector size " << sr_vector.size() << std::endl;
|
||||
// for (int i = 0; i < static_cast<int>(sr_vector.size()); i++) {
|
||||
// weight.push_back(std::exp(-beta * (sr_vector[i].end_time/100000)));
|
||||
// }
|
||||
// exit(0);
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
// std::discrete_distribution<int> d(weight.begin(), weight.end());
|
||||
std::uniform_int_distribution<int> d(0, static_cast<int>(flatten_to_pair.size() - 1));
|
||||
int selected = d(gen);
|
||||
|
||||
// assign new index
|
||||
int selected_i = flatten_to_pair[selected].first;
|
||||
int selected_j = flatten_to_pair[selected].second;
|
||||
new_index[selected_i] = selected_j;
|
||||
|
||||
SimulateResult selected_sr =
|
||||
simulate(graph, cost_table, model_sizes, rpc_table, rpc_exe_table, ri_table,
|
||||
model_name_ri_table, sorted_rpc_names, new_index);
|
||||
uint64_t selected_cost = selected_sr.end_time;
|
||||
// if (selected_sr.oom) {
|
||||
// std::cout << "oom max end time " << selected_cost << std::endl;
|
||||
// } else {
|
||||
// std::cout << "max end time " << selected_cost << std::endl;
|
||||
// }
|
||||
|
||||
if (selected_cost < std::numeric_limits<uint64_t>::max()) {
|
||||
bool accepted = true;
|
||||
if (current_cost < selected_cost) {
|
||||
double accept_prob = std::exp(-beta * ((selected_cost - current_cost) / 100000));
|
||||
// double accept_prob = ap * ap;
|
||||
std::bernoulli_distribution accept_dist(accept_prob);
|
||||
accepted = accept_dist(gen);
|
||||
}
|
||||
|
||||
if (accepted) {
|
||||
// std::cout << "accepted" << std::endl;
|
||||
index = new_index;
|
||||
current_cost = selected_cost;
|
||||
valid_count++;
|
||||
if (!selected_sr.oom) {
|
||||
// min_time_cost = std::min(min_time_cost, selected_cost);
|
||||
// max_time_cost = std::max(max_time_cost, selected_cost);
|
||||
avg = (selected_cost + avg * (valid_count - 1)) / valid_count;
|
||||
if (min_time_cost > selected_cost) {
|
||||
min_time_cost = selected_cost;
|
||||
min_index = index;
|
||||
min_index_mem = selected_sr.mem_cost;
|
||||
// final_sr = selected_sr;
|
||||
auto now = std::chrono::high_resolution_clock::now();
|
||||
double diff =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(now - start).count();
|
||||
selected_sr.used_time = diff;
|
||||
time_cost_cache.push_back(selected_sr);
|
||||
}
|
||||
if (min_time_cost == selected_cost) {
|
||||
if (min_index_mem > selected_sr.mem_cost) {
|
||||
min_index = index;
|
||||
min_index_mem = selected_sr.mem_cost;
|
||||
// final_sr = selected_sr;
|
||||
}
|
||||
}
|
||||
top_k_queue.insert(selected_sr);
|
||||
if (max_time_cost < selected_cost) {
|
||||
max_time_cost = selected_cost;
|
||||
max_index = index;
|
||||
}
|
||||
} else {
|
||||
oom_count += 1;
|
||||
}
|
||||
// if (min_time_cost <= 100000000) break; // DEBUG
|
||||
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = end - start;
|
||||
|
||||
if (valid_count % 1000 == 0) {
|
||||
std::cout << " valid_count " << valid_count << " oom count " << oom_count << " time "
|
||||
<< diff.count() << " min time cost " << min_time_cost << " min index mem cost "
|
||||
<< min_index_mem << " max time cost " << max_time_cost << " current cost "
|
||||
<< current_cost << " avg " << avg << " mem usage " << check_memory_bytes()
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "min index : ";
|
||||
print_int_vector(min_index);
|
||||
std::cout << std::endl;
|
||||
std::cout << "max index : ";
|
||||
print_int_vector(max_index);
|
||||
std::cout << std::endl;
|
||||
std::cout << "current index : ";
|
||||
print_int_vector(index);
|
||||
std::cout << std::endl;
|
||||
}
|
||||
if (diff > time_limit_duration) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// std::cout << "break" << std::endl;
|
||||
|
||||
// int rpc_index = 0;
|
||||
// for (int index : final_sr.index) {
|
||||
// // std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
|
||||
// // std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
|
||||
// RPCExecution* re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
|
||||
// final_sr.rpc_exe_list.push_back(re_ptr);
|
||||
// rpc_index ++;
|
||||
// }
|
||||
// std::cout << "final_sr.rpc_exe_list size " << final_sr.rpc_exe_list.size() << std::endl;
|
||||
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
std::chrono::duration<double> diff = end - start;
|
||||
std::cout << "MCMC Search finished, beta " << beta << " time limit " << time_limit << " seconds"
|
||||
<< " valid_count " << valid_count << " oom_count " << oom_count << " time "
|
||||
<< diff.count() << " min time cost " << min_time_cost << " max time cost "
|
||||
<< max_time_cost << " avg " << avg << std::endl;
|
||||
std::cout << "RPCExecutions: " << std::endl;
|
||||
for (auto &re_ptr : final_sr.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
|
||||
return time_cost_cache;
|
||||
// return final_sr;
|
||||
}
|
||||
|
||||
void multi_mcmc_search(std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
|
||||
std::vector<RPCInstance *> graph,
|
||||
std::unordered_map<std::string, uint64_t> &cost_table,
|
||||
std::unordered_map<std::string, uint64_t> model_sizes, double beta_min,
|
||||
double beta_max, double beta_step, MinEndTimeQueue &res_queue,
|
||||
int top_k = 10, double time_limit = 60.0,
|
||||
int repeat = 1 // Remove the trailing comma here
|
||||
) {
|
||||
SimulateResult sr;
|
||||
std::vector<MinEndTimeQueue *> queues;
|
||||
// std::vector<std::thread> ts;
|
||||
for (int i = 0; i < repeat; i++) {
|
||||
for (double beta = beta_min; beta < beta_max; beta += beta_step) {
|
||||
MinEndTimeQueue *q = new MinEndTimeQueue(10);
|
||||
queues.push_back(q);
|
||||
|
||||
// Create a new thread to run mcmc_search
|
||||
std::vector<SimulateResult> r =
|
||||
mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta, time_limit, *q);
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &q : queues) {
|
||||
mergeMinEndTimeQueues(res_queue, *q);
|
||||
// delete q;
|
||||
}
|
||||
|
||||
// std::cout << "Best result: " << sr.end_time << std::endl;
|
||||
// for (auto& re_ptr : sr.rpc_exe_list) {
|
||||
// std::cout << re_ptr -> to_string() << std::endl;
|
||||
// }
|
||||
// std::cout << "Index: ";
|
||||
// for (int i : sr.index) {
|
||||
// std::cout << i << " ";
|
||||
// }
|
||||
// std::cout << std::endl;
|
||||
// std::cout << "Time cost: " << sr.end_time
|
||||
// << ", mem cost: " << sr.mem_cost << std::endl;
|
||||
// std::cout << "sr.rpc_exe_list size " << sr.rpc_exe_list.size() << std::endl;
|
||||
// return sr;
|
||||
}
|
||||
|
||||
void input_check(std::vector<RPC *> rpcs, std::vector<RPCExecution *> rpc_exes,
|
||||
std::vector<RPCInstance *> graph, CommStats &comm_stats) {
|
||||
std::cout << "rpcs: " << rpcs.size() << std::endl;
|
||||
std::cout << "rpc_exes: " << rpc_exes.size() << std::endl;
|
||||
std::cout << "graph: " << graph.size() << std::endl;
|
||||
|
||||
for (auto &rpc_instance : graph) {
|
||||
std::cout << "==================" << std::endl;
|
||||
std::cout << "rpc instance: " << rpc_instance->name << std::endl;
|
||||
std::cout << "parents" << std::endl;
|
||||
for (auto &parent : rpc_instance->parents) { std::cout << parent->name << " "; }
|
||||
std::cout << std::endl;
|
||||
std::cout << "children" << std::endl;
|
||||
for (auto &child : rpc_instance->children) { std::cout << child->name << " "; }
|
||||
std::cout << std::endl;
|
||||
// std::cout << "parents: " << rpc_instance -> parents.size()
|
||||
// << " children: " << rpc_instance -> children.size() << std::endl;
|
||||
}
|
||||
|
||||
std::cout << "comm_stats: " << std::endl;
|
||||
std::cout << "local_send: " << comm_stats.local_send << std::endl;
|
||||
std::cout << "local_recv: " << comm_stats.local_recv << std::endl;
|
||||
std::cout << "remote_send: " << comm_stats.remote_send << std::endl;
|
||||
std::cout << "remote_recv: " << comm_stats.remote_recv << std::endl;
|
||||
std::cout << "offload_store: " << comm_stats.offload_store << std::endl;
|
||||
std::cout << "offload_load: " << comm_stats.offload_load << std::endl;
|
||||
}
|
||||
|
||||
RPC *cast_rpc(py::handle rpc_py) {
|
||||
return new RPC(py::str(rpc_py.attr("model_name")).cast<std::string>(),
|
||||
rpc_py.attr("name").cast<std::string>(),
|
||||
py::str(rpc_py.attr("interface_type")).cast<std::string>());
|
||||
}
|
||||
|
||||
DeviceMesh *cast_device_mesh(py::handle device_mesh_py,
|
||||
std::unordered_map<std::string, DeviceMesh *> &device_mesh_map) {
|
||||
std::string name = device_mesh_py.attr("name").cast<std::string>();
|
||||
if (device_mesh_map.find(name) == device_mesh_map.end()) {
|
||||
py::array_t<int32_t> mapping_array =
|
||||
device_mesh_py.attr("mapping").cast<pybind11::array_t<int32_t>>();
|
||||
py::buffer_info buf_info = mapping_array.request();
|
||||
|
||||
auto rows = buf_info.shape[0];
|
||||
auto cols = buf_info.shape[1];
|
||||
|
||||
std::vector<std::vector<int>> mapping(rows, std::vector<int>(cols));
|
||||
|
||||
// Get a pointer to the data
|
||||
int32_t *data = static_cast<int32_t *>(buf_info.ptr);
|
||||
|
||||
// Fill the 2D vector with data from the numpy array
|
||||
for (size_t i = 0; i < static_cast<size_t>(rows); ++i) {
|
||||
for (size_t j = 0; j < static_cast<size_t>(cols); ++j) { mapping[i][j] = data[i * cols + j]; }
|
||||
}
|
||||
|
||||
DeviceMesh *device_mesh =
|
||||
new DeviceMesh(device_mesh_py.attr("n_nodes").cast<int>(),
|
||||
device_mesh_py.attr("n_gpus_per_node").cast<int>(), mapping,
|
||||
device_mesh_py.attr("global_mesh_name").cast<std::string>(),
|
||||
device_mesh_py.attr("name").cast<std::string>());
|
||||
|
||||
device_mesh_map[name] = device_mesh;
|
||||
return device_mesh;
|
||||
} else {
|
||||
return device_mesh_map[name];
|
||||
}
|
||||
}
|
||||
|
||||
ModelParallelStrategy *cast_model_parallel_strategy(py::handle model_parallel_strategy_py) {
|
||||
return new ModelParallelStrategy(
|
||||
model_parallel_strategy_py.attr("pipeline_parallel_size").cast<int>(),
|
||||
model_parallel_strategy_py.attr("data_parallel_size").cast<int>(),
|
||||
model_parallel_strategy_py.attr("model_parallel_size").cast<int>());
|
||||
}
|
||||
|
||||
RPCExecution *cast_rpc_execution(py::handle rpc_exe_py, std::unordered_map<std::string, RPC *> &tmp,
|
||||
std::unordered_map<std::string, DeviceMesh *> &device_mesh_map) {
|
||||
DeviceMesh *device_mesh = cast_device_mesh(rpc_exe_py.attr("device_mesh"), device_mesh_map);
|
||||
ModelParallelStrategy *model_parallel_strategy =
|
||||
cast_model_parallel_strategy(rpc_exe_py.attr("parallel_strategy"));
|
||||
// RPC* rpc = cast_rpc(rpc_exe_py.attr("rpc"));
|
||||
|
||||
return new RPCExecution(
|
||||
tmp[rpc_exe_py.attr("rpc").attr("name").cast<std::string>()], *device_mesh,
|
||||
*model_parallel_strategy, rpc_exe_py.attr("time_cost").cast<uint64_t>(),
|
||||
rpc_exe_py.attr("mem").cast<uint64_t>(), rpc_exe_py.attr("static_mem").cast<uint64_t>());
|
||||
}
|
||||
|
||||
RPCInstance *cast_rpc_instance_wo_dependency(py::handle rpc_instance_py,
|
||||
std::unordered_map<std::string, RPC *> &tmp) {
|
||||
return new RPCInstance(tmp[rpc_instance_py.attr("rpc").attr("name").cast<std::string>()],
|
||||
rpc_instance_py.attr("iteration_id").cast<int>(),
|
||||
rpc_instance_py.attr("name").cast<std::string>());
|
||||
}
|
||||
|
||||
void cast_rpc_instance_dependency(py::handle rpc_instance_py, RPCInstance *ri_ptr,
|
||||
std::unordered_map<std::string, RPCInstance *> &tmp_graph) {
|
||||
for (py::handle parent_py : rpc_instance_py.attr("parents"))
|
||||
ri_ptr->parents.push_back(tmp_graph[parent_py.attr("name").cast<std::string>()]);
|
||||
for (py::handle child_py : rpc_instance_py.attr("children"))
|
||||
ri_ptr->children.push_back(tmp_graph[child_py.attr("name").cast<std::string>()]);
|
||||
}
|
||||
|
||||
py::list py_single_mcmc_search_time_profile(py::list rpcs_py, py::list rpc_exes_py,
|
||||
py::list graph_py, py::dict cost_table_py,
|
||||
py::dict model_sizes_py, py::object beta,
|
||||
py::object time_limit) {
|
||||
std::vector<RPC *> rpcs;
|
||||
std::unordered_map<std::string, RPC *> tmp;
|
||||
for (py::handle rpc_py : rpcs_py) {
|
||||
RPC *rpc_ptr = cast_rpc(rpc_py);
|
||||
rpcs.push_back(rpc_ptr);
|
||||
tmp[rpc_ptr->rpc_name] = rpc_ptr;
|
||||
}
|
||||
|
||||
std::vector<RPCExecution *> rpc_exes;
|
||||
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
|
||||
for (py::handle rpc_exe_py : rpc_exes_py) {
|
||||
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
|
||||
rpc_exes.push_back(rpc_exe_ptr);
|
||||
}
|
||||
|
||||
std::vector<RPCInstance *> graph;
|
||||
std::unordered_map<std::string, RPCInstance *> tmp_graph;
|
||||
for (py::handle ri_py : graph_py) {
|
||||
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
|
||||
// std::cout << "cast " << ri_ptr -> name << std::endl;
|
||||
tmp_graph[ri_ptr->name] = ri_ptr;
|
||||
}
|
||||
// build dependecny
|
||||
for (py::handle ri_py : graph_py) {
|
||||
std::string ri_name = ri_py.attr("name").cast<std::string>();
|
||||
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
|
||||
graph.push_back(tmp_graph[ri_name]);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, uint64_t> cost_table =
|
||||
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
|
||||
|
||||
std::unordered_map<std::string, uint64_t> model_sizes =
|
||||
model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
|
||||
MinEndTimeQueue res_queue(10);
|
||||
std::vector<SimulateResult> rlist =
|
||||
mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta.cast<double>(),
|
||||
time_limit.cast<double>(), res_queue);
|
||||
|
||||
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
|
||||
std::vector<std::string> sorted_rpc_names;
|
||||
make_rpc_exe_table(rpc_exe_table, rpc_exes);
|
||||
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
|
||||
py::list result;
|
||||
std::cout << "rlist.size " << rlist.size() << std::endl;
|
||||
for (auto &r : rlist) {
|
||||
// SimulateResult r = res_queue.getQueue().top();
|
||||
// res_queue.getQueue().pop();
|
||||
|
||||
std::cout << "End time: " << r.end_time << std::endl;
|
||||
for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
|
||||
std::cout << "Index: ";
|
||||
for (int i : r.index) { std::cout << i << " "; }
|
||||
std::cout << std::endl;
|
||||
std::cout << "Time cost: " << r.end_time << ", mem cost: " << r.mem_cost << std::endl;
|
||||
|
||||
int rpc_index = 0;
|
||||
for (int index : r.index) {
|
||||
// std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
|
||||
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
|
||||
RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
|
||||
r.rpc_exe_list.push_back(re_ptr);
|
||||
rpc_index++;
|
||||
}
|
||||
|
||||
py::dict rdict;
|
||||
for (auto &re_ptr : r.rpc_exe_list) {
|
||||
py::dict rpc_exe_info;
|
||||
std::string rpc_name = re_ptr->rpc_ptr->rpc_name;
|
||||
py::object rpc_name_obj = py::str(rpc_name);
|
||||
// rpc_exe_info.append(re_ptr -> device_mesh.device_mesh_name);
|
||||
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_dp);
|
||||
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_mp);
|
||||
// rpc_exe_info.append(re_ptr -> model_parallel_strategy.num_pp);
|
||||
rpc_exe_info["device_mesh"] = re_ptr->device_mesh.name;
|
||||
rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp;
|
||||
rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp;
|
||||
rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp;
|
||||
rdict[rpc_name_obj] = rpc_exe_info;
|
||||
// std::cout << "append key " << rpc_name_obj << std::endl;
|
||||
}
|
||||
rdict["end_time"] = r.end_time;
|
||||
rdict["mem_cost"] = r.mem_cost;
|
||||
rdict["used_time"] = r.used_time;
|
||||
result.append(rdict);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
py::list py_multi_mcmc_search(py::list rpcs_py, py::list rpc_exes_py, py::list graph_py,
|
||||
py::dict cost_table_py, py::dict model_sizes_py,
|
||||
py::object beta_min_py, py::object beta_max_py,
|
||||
py::object beta_step_py, py::object time_limit_py,
|
||||
py::object repeat) {
|
||||
std::vector<RPC *> rpcs;
|
||||
std::unordered_map<std::string, RPC *> tmp;
|
||||
for (py::handle rpc_py : rpcs_py) {
|
||||
RPC *rpc_ptr = cast_rpc(rpc_py);
|
||||
rpcs.push_back(rpc_ptr);
|
||||
tmp[rpc_ptr->rpc_name] = rpc_ptr;
|
||||
}
|
||||
|
||||
std::vector<RPCExecution *> rpc_exes;
|
||||
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
|
||||
for (py::handle rpc_exe_py : rpc_exes_py) {
|
||||
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
|
||||
rpc_exes.push_back(rpc_exe_ptr);
|
||||
}
|
||||
|
||||
std::vector<RPCInstance *> graph;
|
||||
std::unordered_map<std::string, RPCInstance *> tmp_graph;
|
||||
for (py::handle ri_py : graph_py) {
|
||||
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
|
||||
// std::cout << "cast " << ri_ptr -> name << std::endl;
|
||||
tmp_graph[ri_ptr->name] = ri_ptr;
|
||||
}
|
||||
// build dependecny
|
||||
for (py::handle ri_py : graph_py) {
|
||||
std::string ri_name = ri_py.attr("name").cast<std::string>();
|
||||
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
|
||||
graph.push_back(tmp_graph[ri_name]);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, uint64_t> cost_table =
|
||||
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
|
||||
|
||||
std::unordered_map<std::string, uint64_t> model_sizes =
|
||||
model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
|
||||
|
||||
double beta_min = beta_min_py.cast<double>();
|
||||
double beta_max = beta_max_py.cast<double>();
|
||||
double beta_step = beta_step_py.cast<double>();
|
||||
double time_limit = time_limit_py.cast<double>();
|
||||
int rp = repeat.cast<int>();
|
||||
|
||||
MinEndTimeQueue res_queue(10);
|
||||
multi_mcmc_search(rpcs, rpc_exes, graph, cost_table, model_sizes, beta_min, beta_max, beta_step,
|
||||
res_queue, 10, time_limit, rp);
|
||||
|
||||
std::unordered_map<std::string, std::vector<RPCExecution *>> rpc_exe_table;
|
||||
std::vector<std::string> sorted_rpc_names;
|
||||
make_rpc_exe_table(rpc_exe_table, rpc_exes);
|
||||
make_sorted_rpc_names(sorted_rpc_names, rpc_exe_table);
|
||||
|
||||
// std::cout << "r.rpc_exe_list size " << r.rpc_exe_list.size() << std::endl;
|
||||
// for (int rpc_index = 0; rpc_index < 6; rpc_index++) {
|
||||
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
|
||||
// int index = 0;
|
||||
// for (auto& re_ptr : rpc_exe_table[sorted_rpc_names[rpc_index]]) {
|
||||
// std::cout << "index " << index << " " << re_ptr -> to_string() << std::endl;
|
||||
// index ++;
|
||||
// }
|
||||
// }
|
||||
|
||||
py::list result;
|
||||
std::cout << "res_queue.getQueue().size " << res_queue.getQueue().size() << std::endl;
|
||||
while (!res_queue.getQueue().empty()) {
|
||||
SimulateResult r = res_queue.getQueue().top();
|
||||
res_queue.getQueue().pop();
|
||||
|
||||
std::cout << "End time: " << r.end_time << std::endl;
|
||||
for (auto &re_ptr : r.rpc_exe_list) { std::cout << re_ptr->to_string() << std::endl; }
|
||||
std::cout << "Index: ";
|
||||
for (int i : r.index) { std::cout << i << " "; }
|
||||
std::cout << std::endl;
|
||||
std::cout << "Time cost: " << r.end_time << ", mem cost: " << r.mem_cost << std::endl;
|
||||
|
||||
int rpc_index = 0;
|
||||
for (int index : r.index) {
|
||||
// std::cout << "index " << index << " rpc_index " << rpc_index << std::endl;
|
||||
// std::cout << "rpc name " << sorted_rpc_names[rpc_index] << std::endl;
|
||||
RPCExecution *re_ptr = rpc_exe_table[sorted_rpc_names[rpc_index]][index];
|
||||
r.rpc_exe_list.push_back(re_ptr);
|
||||
rpc_index++;
|
||||
}
|
||||
|
||||
py::dict rdict;
|
||||
for (auto &re_ptr : r.rpc_exe_list) {
|
||||
py::dict rpc_exe_info;
|
||||
std::string rpc_name = re_ptr->rpc_ptr->rpc_name;
|
||||
py::object rpc_name_obj = py::str(rpc_name);
|
||||
// convert device mesh mapping into py::array_t<int>
|
||||
std::vector<std::vector<int>> mapping = re_ptr->device_mesh.mapping;
|
||||
int rows = mapping.size();
|
||||
int cols = mapping[0].size();
|
||||
|
||||
py::array_t<int> numpy_array({rows, cols});
|
||||
for (int i = 0; i < rows; ++i) {
|
||||
for (int j = 0; j < cols; ++j) {
|
||||
*numpy_array.mutable_data(i, j) = mapping[i][j];
|
||||
// std::cout << i << j << mapping[i][j] << std::endl;
|
||||
}
|
||||
}
|
||||
// store in py::dict
|
||||
rpc_exe_info["device_mesh_mapping"] = numpy_array;
|
||||
rpc_exe_info["device_mesh_name"] = re_ptr->device_mesh.name;
|
||||
rpc_exe_info["num_dp"] = re_ptr->model_parallel_strategy.num_dp;
|
||||
rpc_exe_info["num_mp"] = re_ptr->model_parallel_strategy.num_mp;
|
||||
rpc_exe_info["num_pp"] = re_ptr->model_parallel_strategy.num_pp;
|
||||
rdict[rpc_name_obj] = rpc_exe_info;
|
||||
// std::cout << "append key " << rpc_name_obj << std::endl;
|
||||
}
|
||||
rdict["end_time"] = r.end_time;
|
||||
rdict["mem_cost"] = r.mem_cost;
|
||||
result.append(rdict);
|
||||
}
|
||||
|
||||
return result;
|
||||
};
|
||||
|
||||
PYBIND11_MODULE(mdm_search, m) {
|
||||
m.doc() = "model device mapping search module";
|
||||
|
||||
// for debug
|
||||
// m.def("mcmc_search", [](py::list rpcs_py, py::list rpc_exes_py,
|
||||
// py::list graph_py, py::object comm_stats_py,
|
||||
// py::dict model_sizes_py, py::object beta_py,
|
||||
// py::object time_limit_py) {
|
||||
// std::vector<RPC*> rpcs;
|
||||
// std::unordered_map<std::string, RPC*> tmp;
|
||||
// for (py::handle rpc_py : rpcs_py) {
|
||||
// RPC* rpc_ptr = cast_rpc(rpc_py);
|
||||
// rpcs.push_back(rpc_ptr);
|
||||
// tmp[rpc_ptr -> rpc_name] = rpc_ptr;
|
||||
// }
|
||||
|
||||
// std::vector<RPCExecution*> rpc_exes;
|
||||
// for (py::handle rpc_exe_py : rpc_exes_py) {
|
||||
// RPCExecution* rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp);
|
||||
// rpc_exes.push_back(rpc_exe_ptr);
|
||||
// }
|
||||
|
||||
// std::vector<RPCInstance*> graph;
|
||||
// std::unordered_map<std::string, RPCInstance*> tmp_graph;
|
||||
// for (py::handle ri_py : graph_py) {
|
||||
// RPCInstance* ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
|
||||
// std::cout << "cast " << ri_ptr -> name << std::endl;
|
||||
// tmp_graph[ri_ptr -> name] = ri_ptr;
|
||||
// }
|
||||
// // build dependency
|
||||
// for (py::handle ri_py : graph_py) {
|
||||
// std::string ri_name = ri_py.attr("name").cast<std::string>();
|
||||
// cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
|
||||
// graph.push_back(tmp_graph[ri_name]);
|
||||
// }
|
||||
|
||||
// CommStats comm_stats(
|
||||
// comm_stats_py.attr("local_send").cast<uint64_t>(),
|
||||
// comm_stats_py.attr("local_recv").cast<uint64_t>(),
|
||||
// comm_stats_py.attr("remote_send").cast<uint64_t>(),
|
||||
// comm_stats_py.attr("remote_recv").cast<uint64_t>(),
|
||||
// comm_stats_py.attr("offload_load").cast<uint64_t>(),
|
||||
// comm_stats_py.attr("offload_store").cast<uint64_t>()
|
||||
// );
|
||||
|
||||
// std::unordered_map<std::string, uint64_t> model_sizes
|
||||
// = model_sizes_py.cast<std::unordered_map<std::string, uint64_t>>();
|
||||
|
||||
// double beta = beta_py.cast<double>();
|
||||
// double time_limit = time_limit_py.cast<double>();
|
||||
|
||||
// mcmc_search(rpcs, rpc_exes, graph,
|
||||
// comm_stats, model_sizes, beta,
|
||||
// time_limit);
|
||||
// });
|
||||
|
||||
m.def("input_check",
|
||||
[](py::list rpcs_py, py::list rpc_exes_py, py::list graph_py, py::object comm_stats_py) {
|
||||
std::vector<RPC *> rpcs;
|
||||
std::unordered_map<std::string, RPC *> tmp;
|
||||
for (py::handle rpc_py : rpcs_py) {
|
||||
RPC *rpc_ptr = cast_rpc(rpc_py);
|
||||
rpcs.push_back(rpc_ptr);
|
||||
tmp[rpc_ptr->rpc_name] = rpc_ptr;
|
||||
}
|
||||
|
||||
std::vector<RPCExecution *> rpc_exes;
|
||||
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
|
||||
for (py::handle rpc_exe_py : rpc_exes_py) {
|
||||
RPCExecution *rpc_exe_ptr = cast_rpc_execution(rpc_exe_py, tmp, tmp_device_mesh);
|
||||
rpc_exes.push_back(rpc_exe_ptr);
|
||||
}
|
||||
|
||||
std::vector<RPCInstance *> graph;
|
||||
std::unordered_map<std::string, RPCInstance *> tmp_graph;
|
||||
for (py::handle ri_py : graph_py) {
|
||||
RPCInstance *ri_ptr = cast_rpc_instance_wo_dependency(ri_py, tmp);
|
||||
std::cout << "cast " << ri_ptr->name << std::endl;
|
||||
tmp_graph[ri_ptr->name] = ri_ptr;
|
||||
}
|
||||
// build dependency
|
||||
for (py::handle ri_py : graph_py) {
|
||||
std::string ri_name = ri_py.attr("name").cast<std::string>();
|
||||
cast_rpc_instance_dependency(ri_py, tmp_graph[ri_name], tmp_graph);
|
||||
graph.push_back(tmp_graph[ri_name]);
|
||||
}
|
||||
|
||||
CommStats comm_stats(comm_stats_py.attr("local_send").cast<uint64_t>(),
|
||||
comm_stats_py.attr("local_recv").cast<uint64_t>(),
|
||||
comm_stats_py.attr("remote_send").cast<uint64_t>(),
|
||||
comm_stats_py.attr("remote_recv").cast<uint64_t>(),
|
||||
comm_stats_py.attr("offload_load").cast<uint64_t>(),
|
||||
comm_stats_py.attr("offload_store").cast<uint64_t>());
|
||||
|
||||
input_check(rpcs, rpc_exes, graph, comm_stats);
|
||||
});
|
||||
|
||||
// mcmc search to py result
|
||||
m.def("multi_mcmc_search", &py_multi_mcmc_search);
|
||||
|
||||
m.def("parameter_sync_cost", [](py::object rpcs_py, py::object param_size_bytes_py,
|
||||
py::dict cost_table_py, py::object src_py, py::object dst_py) {
|
||||
uint64_t param_size_bytes = param_size_bytes_py.cast<uint64_t>();
|
||||
std::unordered_map<std::string, uint64_t> cost_table =
|
||||
cost_table_py.cast<std::unordered_map<std::string, uint64_t>>();
|
||||
std::vector<RPC *> rpcs;
|
||||
std::unordered_map<std::string, RPC *> tmp;
|
||||
for (py::handle rpc_py : rpcs_py) {
|
||||
RPC *rpc_ptr = cast_rpc(rpc_py);
|
||||
rpcs.push_back(rpc_ptr);
|
||||
tmp[rpc_ptr->rpc_name] = rpc_ptr;
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, DeviceMesh *> tmp_device_mesh;
|
||||
RPCExecution *src = cast_rpc_execution(src_py, tmp, tmp_device_mesh);
|
||||
RPCExecution *dst = cast_rpc_execution(dst_py, tmp, tmp_device_mesh);
|
||||
|
||||
return parameter_sync_cost(param_size_bytes, src, dst, cost_table);
|
||||
});
|
||||
|
||||
m.def("mcmc_search_time_profile", &py_single_mcmc_search_time_profile);
|
||||
};
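
The bindings above return each candidate plan as a Python dict keyed by RPC name (with "device_mesh_mapping", "device_mesh_name", "num_dp", "num_mp", "num_pp"), plus top-level "end_time" and "mem_cost". As a minimal sketch of the caller side, the helper below picks the fastest plan from that list; select_fastest_plan itself is a hypothetical consumer, not part of mdm_search.

# Sketch of consuming the result list produced by the bindings above. The dict
# keys mirror what the C++ code stores; this helper is illustrative only.
from typing import Any, Dict, List


def select_fastest_plan(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return the candidate plan with the smallest simulated end time."""
    if not results:
        raise ValueError("mdm_search returned no candidate plans")
    best = min(results, key=lambda r: r["end_time"])
    for rpc_name, info in best.items():
        if rpc_name in ("end_time", "mem_cost"):
            continue
        print(
            f"{rpc_name}: mesh={info['device_mesh_name']} "
            f"dp={info['num_dp']} mp={info['num_mp']} pp={info['num_pp']}"
        )
    print(f"end_time={best['end_time']}, mem_cost={best['mem_cost']}")
    return best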
|
|
@ -1,244 +0,0 @@
|
|||
#include <rpc.hpp>
|
||||
#include <device_mesh.hpp>
|
||||
#include <simulate.hpp>
|
||||
#include <iostream>
|
||||
#include <queue>
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <chrono>
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
|
||||
uint64_t SOFT_GPU_MEM_CAP = 85899345920; // 80G
|
||||
|
||||
SimulateResult::SimulateResult()
|
||||
: end_time(std::numeric_limits<uint64_t>::max()), oom(true), mem_cost(0) {}
|
||||
|
||||
SimulateResult::SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost,
|
||||
std::vector<int> &index)
|
||||
: end_time(end_time), oom(oom), mem_cost(mem_cost), index(index) {}
|
||||
|
||||
SimulateResult simulate(
|
||||
std::vector<RPCInstance *> &graph, std::unordered_map<std::string, uint64_t> &cost_table,
|
||||
std::unordered_map<std::string, uint64_t> &model_sizes,
|
||||
std::unordered_map<std::string, RPC *> &rpc_table,
|
||||
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
|
||||
std::vector<std::string> &sorted_rpc_names, std::vector<int> &index) {
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
GroupedRPCExecutions grouped_rpc_exe;
|
||||
// std::unordered_map<std::string, RPCExecution*> param_dst; // model_name -> rpc_exe_ptr
|
||||
std::unordered_set<std::string> offloaded;
|
||||
uint64_t oom_penalty = 3;
|
||||
int num_rpcs = static_cast<int>(sorted_rpc_names.size());
|
||||
|
||||
for (int i = 0; i < num_rpcs; i++) {
|
||||
std::string rpc_name = sorted_rpc_names[i];
|
||||
RPC *rpc = rpc_table[rpc_name];
|
||||
RPCExecution *rpc_exe = rpc_exe_table[rpc_name][index[i]];
|
||||
|
||||
for (auto &ri : ri_table[rpc_name]) {
|
||||
ri->rpc_exe_ptr = rpc_exe; // assign rpc_exe to rpc_instance
|
||||
}
|
||||
|
||||
// dirty implementation, check whether instance is hooked with param syncs
|
||||
for (auto &rpc_instance : model_name_ri_table[rpc->model_name]) {
|
||||
if (rpc_instance->rpc_ptr->interface_type == "ModelInterfaceType.TRAIN_STEP"
|
||||
&& rpc->interface_type != "ModelInterfaceType.TRAIN_STEP") {
|
||||
// param_dst[rpc -> model_name] = rpc_exe;
|
||||
rpc_instance->param_sync = true;
|
||||
rpc_instance->param_sync_size = model_sizes[rpc_instance->rpc_ptr->model_name];
|
||||
rpc_instance->param_sync_rpc_exe_ptr = rpc_exe;
|
||||
}
|
||||
}
|
||||
grouped_rpc_exe.add(rpc_exe);
|
||||
std::string model_name = rpc->model_name;
|
||||
}
|
||||
|
||||
// rpc_instances: list of rpc instances, graph
|
||||
std::priority_queue<RPCInstance *, std::vector<RPCInstance *>, CompareReadyTime> ready_queue;
|
||||
// std::vector<RPCInstance*> executed; // for debug, remove later
|
||||
std::unordered_map<std::string, size_t> parent_executed;
|
||||
std::unordered_set<DeviceMesh *> device_meshes;
|
||||
|
||||
// for offload and parameter sync RPC instances
|
||||
std::vector<RPCInstance *> tmp_graph;
|
||||
|
||||
// resolve parameter sync
|
||||
for (RPCInstance *node : graph) {
|
||||
tmp_graph.push_back(node);
|
||||
node->tmp_children = node->children;
|
||||
node->tmp_parents = node->parents;
|
||||
// std::cout << "Resolve parameter sync: " << node -> name
|
||||
// << " " << node -> param_sync << std::endl;
|
||||
node->resolve_parameter_sync(tmp_graph, cost_table);
|
||||
// node -> resolve_offload(tmp_graph, comm_stats);
|
||||
}
|
||||
|
||||
uint64_t max_end_time = 0;
|
||||
for (RPCInstance *node : tmp_graph) {
|
||||
// std::cout << "Node: " << node -> name << " parents: "
|
||||
// << node -> parents.size() << std::endl;
|
||||
if (node->parents.size() == 0) ready_queue.push(node);
|
||||
|
||||
// init device meshes
|
||||
RPCExecution *rpc_exe = node->rpc_exe_ptr;
|
||||
device_meshes.insert(&rpc_exe->device_mesh);
|
||||
}
|
||||
|
||||
std::vector<RPCInstance *> executed;
|
||||
|
||||
// simulate
|
||||
while (!ready_queue.empty()) {
|
||||
RPCInstance *t = ready_queue.top();
|
||||
RPCExecution *rpc_exe = t->rpc_exe_ptr;
|
||||
uint64_t exec_time = rpc_exe->time_cost;
|
||||
DeviceMesh *device_mesh = &rpc_exe->device_mesh;
|
||||
ready_queue.pop();
|
||||
|
||||
if (device_mesh->pre_task == nullptr) {
|
||||
t->start_time = t->ready_time;
|
||||
} else {
|
||||
t->start_time = MAX(t->ready_time, device_mesh->pre_task->end_time);
|
||||
}
|
||||
t->end_time = t->start_time + exec_time;
|
||||
max_end_time = MAX(t->end_time, max_end_time);
|
||||
|
||||
for (DeviceMesh *mesh : device_meshes) {
|
||||
if (device_mesh->overlap(*mesh)) {
|
||||
if (mesh->pre_task == nullptr || mesh->pre_task->end_time <= t->end_time) {
|
||||
mesh->pre_task = t;
|
||||
}
|
||||
// mesh -> pre_task = t;
|
||||
}
|
||||
}
|
||||
executed.push_back(t);
|
||||
|
||||
for (RPCInstance *child : t->tmp_children) {
|
||||
child->ready_time = MAX(t->end_time, child->ready_time);
|
||||
// std::cout << "parent: " << t -> name
|
||||
// << " child: " << child -> name << std::endl;
|
||||
parent_executed[child->name] += 1;
|
||||
// child -> remove_parent(t);
|
||||
if (child->tmp_parents.size() == parent_executed[child->name]) {
|
||||
ready_queue.push(child);
|
||||
// std::cout << "Ready: " << child -> name
|
||||
// << " ready time " << child -> ready_time << std::endl;
|
||||
}
|
||||
}
|
||||
// std::cout << "ready_queue size " << ready_queue.size()
|
||||
// << " executed size " << executed.size() << std::endl;
|
||||
}
|
||||
|
||||
// 110045999
|
||||
// if (max_end_time < 100000000) { // DEBUG
|
||||
// std::cout << "INDEX: [";
|
||||
// for (int i : index) {
|
||||
// std::cout << i << ", ";
|
||||
// }
|
||||
// std::cout << "]" << std::endl;
|
||||
|
||||
// for (auto& x : rpc_exe_table){
|
||||
// std::string rpc_name = x.first;
|
||||
// std::vector<RPCExecution*>& rpc_exe_list = x.second;
|
||||
// int count = 0;
|
||||
// for (RPCExecution* rpc_exe : rpc_exe_list) {
|
||||
// std::cout << "RPC: " << rpc_name
|
||||
// << " device mesh " << rpc_exe -> device_mesh.device_mesh_name
|
||||
// << " time cost " << rpc_exe -> time_cost << std::endl;
|
||||
// count ++;
|
||||
// if (count > 10) break;
|
||||
// }
|
||||
// }
|
||||
|
||||
// for (RPCInstance* ri : executed) {
|
||||
// for (RPCInstance* parent : ri -> tmp_parents) {
|
||||
// std::cout << "Parent: " << parent -> name << " of " << ri -> name
|
||||
// << " start time " << parent -> start_time
|
||||
// << " end time " << parent -> end_time << std::endl;
|
||||
// }
|
||||
|
||||
// std::cout << "Executed: " << ri -> name
|
||||
// << " start time " << ri -> start_time
|
||||
// << " end time " << ri -> end_time
|
||||
// << " rpc name " << ri -> rpc_ptr -> rpc_name
|
||||
// << " device mesh "
|
||||
// << ri -> rpc_exe_ptr -> device_mesh.device_mesh_name
|
||||
// << " rpc exe time cost " << ri -> rpc_exe_ptr -> time_cost
|
||||
// << std::endl;
|
||||
// }
|
||||
// }
|
||||
|
||||
// clear device mesh pre tasks
|
||||
for (DeviceMesh *mesh : device_meshes) { mesh->pre_task = nullptr; }
|
||||
// clear rpc instance times
|
||||
for (RPCInstance *node : graph) {
|
||||
node->tmp_children.clear();
|
||||
node->tmp_parents.clear();
|
||||
node->ready_time = 0;
|
||||
node->start_time = 0;
|
||||
node->end_time = 0;
|
||||
tmp_graph.clear();
|
||||
|
||||
for (RPCInstance *ptr : node->tmp_ris) { delete ptr; }
|
||||
node->tmp_ris.clear();
|
||||
|
||||
for (RPCExecution *ptr : node->tmp_exes) { delete ptr; }
|
||||
node->tmp_exes.clear();
|
||||
}
|
||||
|
||||
uint64_t current_mem = grouped_rpc_exe.total_mem_cost();
|
||||
if (current_mem > SOFT_GPU_MEM_CAP) { max_end_time *= oom_penalty; }
|
||||
// std::cout << "Max end time: " << max_end_time
|
||||
// << " executed size " << executed.size() << std::endl;
|
||||
std::chrono::duration<double> elapsed = std::chrono::high_resolution_clock::now() - start;
|
||||
// std::cout << "Elapsed time (micro seconds): "
|
||||
// << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count()
|
||||
// << std::endl;
|
||||
|
||||
for (OverlapGroup *ptr : grouped_rpc_exe.group.overlap_groups) { delete ptr; }
|
||||
grouped_rpc_exe.group.overlap_groups.clear();
|
||||
|
||||
return SimulateResult(max_end_time, current_mem > SOFT_GPU_MEM_CAP, current_mem, index);
|
||||
};
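
The loop above is a list-scheduling simulation: an instance becomes ready once all of its parents have finished, and on any given device mesh it cannot start before the previous task assigned to an overlapping mesh has ended. Below is a simplified Python sketch of the same idea under stated assumptions: one flat device key per task instead of overlapping meshes, and no memory accounting or OOM penalty.

# Simplified sketch of the scheduling loop in `simulate` above.
import heapq
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Task:
    name: str
    device: str                     # stand-in for the RPCExecution's device mesh
    cost: int                       # stand-in for rpc_exe.time_cost
    children: List["Task"] = field(default_factory=list)
    n_parents: int = 0
    ready_time: int = 0


def simulate_tasks(tasks: List[Task]) -> int:
    # Start with all root tasks; the heap is ordered by ready time.
    ready = [(t.ready_time, i, t) for i, t in enumerate(tasks) if t.n_parents == 0]
    heapq.heapify(ready)
    device_free: Dict[str, int] = {}   # analogue of device_mesh.pre_task->end_time
    finished_parents: Dict[str, int] = {}
    makespan = 0
    while ready:
        _, _, t = heapq.heappop(ready)
        start = max(t.ready_time, device_free.get(t.device, 0))
        end = start + t.cost
        device_free[t.device] = end
        makespan = max(makespan, end)
        for child in t.children:
            child.ready_time = max(child.ready_time, end)
            finished_parents[child.name] = finished_parents.get(child.name, 0) + 1
            if finished_parents[child.name] == child.n_parents:
                heapq.heappush(ready, (child.ready_time, id(child), child))
    return makespan


# Usage: a rollout followed by a train step on the same mesh finishes at 5 + 7 = 12.
a = Task("rollout", "mesh0", 5)
b = Task("train", "mesh0", 7, n_parents=1)
a.children.append(b)
print(simulate_tasks([a, b]))  # 12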
|
||||
|
||||
SimulateResult &SimulateResult::operator=(const SimulateResult &other) {
|
||||
if (this != &other) {
|
||||
end_time = other.end_time;
|
||||
oom = other.oom;
|
||||
mem_cost = other.mem_cost;
|
||||
index = other.index;
|
||||
rpc_exe_list = other.rpc_exe_list;
|
||||
}
|
||||
return *this;
|
||||
};
|
||||
|
||||
bool isPresent(MinEndTimeQueue &q, SimulateResult element) {
|
||||
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> pq =
|
||||
q.getQueue();
|
||||
std::queue<SimulateResult> tmp;
|
||||
while (!pq.empty()) {
|
||||
if (pq.top().end_time == element.end_time) { return true; }
|
||||
tmp.push(pq.top());
|
||||
pq.pop();
|
||||
}
|
||||
while (!tmp.empty()) {
|
||||
pq.push(tmp.front());
|
||||
tmp.pop();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1) {
|
||||
// Get the underlying priority queues
|
||||
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> pq1 =
|
||||
q1.getQueue();
|
||||
|
||||
// Insert all elements from q1 into the merged queue
|
||||
while (!pq1.empty()) {
|
||||
if (!isPresent(target, pq1.top())) target.insert(pq1.top());
|
||||
pq1.pop();
|
||||
}
|
||||
}
|
|
@ -1,75 +0,0 @@
|
|||
#ifndef SIMULATE_HPP
|
||||
#define SIMULATE_HPP
|
||||
|
||||
#include <rpc.hpp>
|
||||
#include <device_mesh.hpp>
|
||||
#include <queue>
|
||||
#include <iostream>
|
||||
|
||||
class SimulateResult {
|
||||
public:
|
||||
uint64_t end_time;
|
||||
bool oom;
|
||||
uint64_t mem_cost;
|
||||
std::vector<int> index;
|
||||
std::vector<RPCExecution *> rpc_exe_list;
|
||||
double used_time = 0;
|
||||
|
||||
SimulateResult();
|
||||
|
||||
SimulateResult(uint64_t end_time, bool oom, uint64_t mem_cost, std::vector<int> &index);
|
||||
|
||||
SimulateResult &operator=(const SimulateResult &other);
|
||||
};
|
||||
|
||||
SimulateResult simulate(
|
||||
std::vector<RPCInstance *> &graph, std::unordered_map<std::string, uint64_t> &cost_table,
|
||||
std::unordered_map<std::string, uint64_t> &model_sizes,
|
||||
std::unordered_map<std::string, RPC *> &rpc_table,
|
||||
std::unordered_map<std::string, std::vector<RPCExecution *>> &rpc_exe_table,
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> &ri_table,
|
||||
std::unordered_map<std::string, std::vector<RPCInstance *>> &model_name_ri_table,
|
||||
std::vector<std::string> &sorted_rpc_names, std::vector<int> &index);
|
||||
|
||||
// Comparator for priority queue
|
||||
struct CompareEndTime {
|
||||
bool operator()(SimulateResult const &r1, SimulateResult const &r2) {
|
||||
// We want largest end_time at the top of the queue, so we reverse the comparison
|
||||
return r1.end_time < r2.end_time;
|
||||
}
|
||||
};
|
||||
|
||||
class MinEndTimeQueue {
|
||||
public:
|
||||
MinEndTimeQueue(int capacity) : k(capacity) {}
|
||||
|
||||
void insert(SimulateResult r) {
|
||||
if (queue.size() < k) {
|
||||
// std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() <<
|
||||
// std::endl;
|
||||
queue.push(r);
|
||||
} else if (r.end_time < queue.top().end_time) {
|
||||
// std::cout << "push " << "end_time: " << r.end_time << " qsize " << queue.size() <<
|
||||
// std::endl;
|
||||
queue.pop();
|
||||
queue.push(r);
|
||||
}
|
||||
}
|
||||
|
||||
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> &getQueue() {
|
||||
return queue;
|
||||
}
|
||||
|
||||
private:
|
||||
std::priority_queue<SimulateResult, std::vector<SimulateResult>, CompareEndTime> queue;
|
||||
int k;
|
||||
};
|
||||
|
||||
void mergeMinEndTimeQueues(MinEndTimeQueue &target, MinEndTimeQueue &q1);
|
||||
|
||||
class CompareReadyTime {
|
||||
public:
|
||||
bool operator()(RPCInstance *r1, RPCInstance *r2) { return r1->ready_time > r2->ready_time; }
|
||||
};
|
||||
|
||||
#endif // SIMULATE_HPP
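
MinEndTimeQueue above keeps only the k best results (smallest end_time) by holding the currently worst retained result at the top of the heap so it can be evicted. The sketch below shows the same top-k pattern in Python; heapq is a min-heap, so the key is negated to emulate the max-heap behavior.

# Python sketch of the MinEndTimeQueue idea: retain the k smallest end_times.
import heapq
from typing import List, Tuple


class TopKByEndTime:
    def __init__(self, k: int):
        self.k = k
        self._heap: List[Tuple[int, int, object]] = []  # (-end_time, tiebreak, result)
        self._count = 0

    def insert(self, end_time: int, result: object) -> None:
        self._count += 1
        item = (-end_time, self._count, result)
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, item)
        elif -self._heap[0][0] > end_time:
            # The new result beats the worst retained one: replace it.
            heapq.heapreplace(self._heap, item)

    def results(self) -> List[Tuple[int, object]]:
        """Retained results, best (smallest end_time) first."""
        return [(-neg, r) for neg, _, r in sorted(self._heap, reverse=True)]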
|
|
@ -83,7 +83,7 @@ async def async_invoke_function(
|
|||
url: str,
|
||||
timeout: aiohttp.ClientTimeout,
|
||||
payload: Dict[str, Any] = None,
|
||||
max_retries: int = 100,
|
||||
max_retries: int = 2,
|
||||
initial_retry_interval: float = 0.5,
|
||||
max_retry_interval: float = 10.0,
|
||||
):
|
||||
|
@ -137,7 +137,7 @@ async def async_invoke_function(
|
|||
)
|
||||
|
||||
retries += 1
|
||||
if retries > max_retries:
|
||||
if retries >= max_retries:
|
||||
return {
|
||||
"uid": payload.get("uid", ""),
|
||||
"success": False,
|
||||
|
@ -189,6 +189,7 @@ async def batch_function_call_async(payload_list, url, timeout, concurrency=1500
|
|||
data_list.append(data)
|
||||
elapsed_times.append(elapsed)
|
||||
|
||||
if len(elapsed_times) > 0:
|
||||
p50 = median(elapsed_times)
|
||||
p90 = calculate_percentile(elapsed_times, 90)
|
||||
p99 = calculate_percentile(elapsed_times, 99)
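
The hunk above reports p50/p90/p99 latencies through `median` and `calculate_percentile`; those helpers are not shown in this diff, so the sketch below is only one plausible shape for the percentile helper (nearest-rank on a sorted copy) and may differ from the real implementation.

# Hedged sketch of a nearest-rank percentile, assumed to approximate the
# `calculate_percentile` helper referenced above.
import math
from typing import Sequence


def calculate_percentile(values: Sequence[float], pct: float) -> float:
    if not values:
        raise ValueError("empty sample")
    ordered = sorted(values)
    rank = math.ceil(pct / 100 * len(ordered))
    return ordered[max(rank - 1, 0)]


assert calculate_percentile([3.0, 1.0, 2.0, 4.0], 50) == 2.0
assert calculate_percentile([3.0, 1.0, 2.0, 4.0], 99) == 4.0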
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import ast
|
||||
import faulthandler
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
|
||||
# to run the solution files we're using a timing based approach
|
||||
|
@ -38,6 +39,14 @@ def truncatefn(s, length=300):
|
|||
return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]
|
||||
|
||||
|
||||
def load_from_path(l):
|
||||
outputs = []
|
||||
for x in l:
|
||||
with open(x) as f:
|
||||
outputs.append(f.read())
|
||||
return outputs
|
||||
|
||||
|
||||
class CODE_TYPE(Enum):
|
||||
call_based = 0
|
||||
standard_input = 1
|
||||
|
@ -105,6 +114,13 @@ def run_test(sample, test=None, debug=False, timeout=6):
|
|||
|
||||
which_type = CODE_TYPE.call_based
|
||||
if in_outs:
|
||||
assert "inputs" in in_outs
|
||||
assert "outputs" in in_outs
|
||||
if os.path.isfile(in_outs["inputs"][0]):
|
||||
assert os.path.isfile(in_outs["outputs"][0])
|
||||
in_outs["inputs"] = load_from_path(in_outs["inputs"])
|
||||
in_outs["outputs"] = load_from_path(in_outs["outputs"])
|
||||
|
||||
if in_outs.get("fn_name", "") == "":
|
||||
which_type = CODE_TYPE.standard_input # Standard input
|
||||
method_name = None
|
||||
|
|
|
@ -59,9 +59,10 @@ def call_verify(problem, generation, debug, timeout=SINGLE_CASE_EXEC_TIMEOUT):
|
|||
shell=True,
|
||||
preexec_fn=os.setsid,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
)
|
||||
try:
|
||||
pro.wait(600)
|
||||
pro.wait(200)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
|
|
|
@ -9,7 +9,7 @@ from functioncall.base.utils import construct_uid, load_jsonl, logger
|
|||
|
||||
SINGLE_CASE_EXEC_TIMEOUT = 6
|
||||
TEST_CASE_BATCH_SIZE = 1
|
||||
FUNCTIONCALL_TIMEOUT = 1000
|
||||
FUNCTIONCALL_TIMEOUT = 100
|
||||
|
||||
|
||||
def round_up_memory(memory):
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
|
||||
index 174656b2..33fe0a5f 100644
|
||||
--- a/python/sglang/srt/managers/io_struct.py
|
||||
+++ b/python/sglang/srt/managers/io_struct.py
|
||||
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
|
||||
success: bool
|
||||
|
||||
|
||||
+@dataclass
|
||||
+class InterruptAllReqInput:
|
||||
+ pass
|
||||
+
|
||||
+
|
||||
+@dataclass
|
||||
+class InterruptAllReqOutput:
|
||||
+ num_interrupted_requests: int
|
||||
+
|
||||
+
|
||||
@dataclass
|
||||
class UpdateWeightFromDiskReqInput:
|
||||
# The model path with the new weights
|
||||
model_path: str
|
||||
+ allow_interrupt: bool = False
|
||||
# The format to load the weights
|
||||
load_format: Optional[str] = None
|
||||
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 8891115c..843a8a82 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -70,6 +70,8 @@ from sglang.srt.managers.io_struct import (
|
||||
HealthCheckOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
+ InterruptAllReqInput,
|
||||
+ InterruptAllReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -419,6 +421,7 @@ class Scheduler(
|
||||
# Init request dispatcher
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
+ (InterruptAllReqInput, self.interrupt_all_requests),
|
||||
(TokenizedGenerateReqInput, self.handle_generate_request),
|
||||
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
||||
(FlushCacheReqInput, self.flush_cache_wrapped),
|
||||
@@ -1938,6 +1941,15 @@ class Scheduler(
|
||||
def _pause_engine(self) -> Tuple[List[Req], int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
|
||||
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
|
||||
+ for req in self.waiting_queue:
|
||||
+ req.sampling_params.max_new_tokens = 0
|
||||
+ for req in self.running_batch.reqs:
|
||||
+ req.sampling_params.max_new_tokens = len(req.output_ids)
|
||||
+ logger.info(f"Interrupt {num} requests.")
|
||||
+ return InterruptAllReqOutput(num)
|
||||
+
|
||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
"""In-place update of the weights from disk."""
|
||||
success, message = self.tp_worker.update_weights_from_disk(recv_req)
|
||||
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
index 82709b09..bfab3ce7 100644
|
||||
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
||||
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
@@ -76,6 +76,8 @@ from sglang.srt.managers.io_struct import (
|
||||
HealthCheckOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
+ InterruptAllReqInput,
|
||||
+ InterruptAllReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -265,6 +267,9 @@ class TokenizerManager:
|
||||
self.resume_memory_occupation_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
+ self.interrupt_requests_communicator = _Communicator(
|
||||
+ self.send_to_scheduler, server_args.dp_size
|
||||
+ )
|
||||
self.flush_cache_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -294,6 +299,10 @@ class TokenizerManager:
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
self._handle_update_weights_from_disk_req_output,
|
||||
),
|
||||
+ (
|
||||
+ InterruptAllReqOutput,
|
||||
+ self.interrupt_requests_communicator.handle_recv,
|
||||
+ ),
|
||||
(
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
self.init_weights_update_group_communicator.handle_recv,
|
||||
@@ -767,6 +776,13 @@ class TokenizerManager:
|
||||
) -> Tuple[bool, str]:
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
+ if obj.allow_interrupt:
|
||||
+ num_interrupted_requests = await self.interrupt_all_requests(
|
||||
+ InterruptAllReqInput()
|
||||
+ )
|
||||
+ # Set a break point to wait for the interrupt to finish
|
||||
+ await asyncio.sleep(0.1)
|
||||
+
|
||||
# default the load format to the server_args
|
||||
if obj.load_format is None:
|
||||
obj.load_format = self.server_args.load_format
|
||||
@@ -776,7 +792,12 @@ class TokenizerManager:
|
||||
# Hold the lock if it is not async. This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
- return await self._wait_for_model_update_from_disk(obj)
|
||||
+ success, message, n_paused = (
|
||||
+ await self._wait_for_model_update_from_disk(obj)
|
||||
+ )
|
||||
+ if obj.allow_interrupt:
|
||||
+ return success, message, num_interrupted_requests
|
||||
+ return success, message, n_paused
|
||||
|
||||
async def _wait_for_model_update_from_disk(
|
||||
self, obj: UpdateWeightFromDiskReqInput
|
||||
@@ -849,6 +870,18 @@ class TokenizerManager:
|
||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
+ async def interrupt_all_requests(
|
||||
+ self,
|
||||
+ obj: InterruptAllReqInput,
|
||||
+ request: Optional[fastapi.Request] = None,
|
||||
+ ) -> Tuple[bool, str]:
|
||||
+ self.auto_create_handle_loop()
|
||||
+ result = await self.interrupt_requests_communicator(obj)
|
||||
+ if self.server_args.dp_size == 1:
|
||||
+ return result[0].num_interrupted_requests
|
||||
+ else:
|
||||
+ return [r.num_interrupted_requests for r in result]
|
||||
+
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
|
@ -0,0 +1,144 @@
|
|||
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
|
||||
index 5390668c..db370d19 100644
|
||||
--- a/python/sglang/srt/managers/io_struct.py
|
||||
+++ b/python/sglang/srt/managers/io_struct.py
|
||||
@@ -687,10 +687,21 @@ class FlushCacheReqOutput:
|
||||
success: bool
|
||||
|
||||
|
||||
+@dataclass
|
||||
+class InterruptAllReqInput:
|
||||
+ pass
|
||||
+
|
||||
+
|
||||
+@dataclass
|
||||
+class InterruptAllReqOutput:
|
||||
+ num_interrupted_requests: int
|
||||
+
|
||||
+
|
||||
@dataclass
|
||||
class UpdateWeightFromDiskReqInput:
|
||||
# The model path with the new weights
|
||||
model_path: str
|
||||
+ allow_interrupt: bool = False
|
||||
# The format to load the weights
|
||||
load_format: Optional[str] = None
|
||||
|
||||
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
|
||||
index 1178eec5..318dee33 100644
|
||||
--- a/python/sglang/srt/managers/scheduler.py
|
||||
+++ b/python/sglang/srt/managers/scheduler.py
|
||||
@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
|
||||
HealthCheckOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
+ InterruptAllReqInput,
|
||||
+ InterruptAllReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -427,6 +429,7 @@ class Scheduler(
|
||||
# Init request dispatcher
|
||||
self._request_dispatcher = TypeBasedDispatcher(
|
||||
[
|
||||
+ (InterruptAllReqInput, self.interrupt_all_requests),
|
||||
(TokenizedGenerateReqInput, self.handle_generate_request),
|
||||
(TokenizedEmbeddingReqInput, self.handle_embedding_request),
|
||||
(FlushCacheReqInput, self.flush_cache_wrapped),
|
||||
@@ -1971,6 +1974,15 @@ class Scheduler(
|
||||
def _pause_engine(self) -> Tuple[List[Req], int]:
|
||||
raise NotImplementedError()
|
||||
|
||||
+ def interrupt_all_requests(self, recv_req: InterruptAllReqInput):
|
||||
+ num = len(self.waiting_queue) + len(self.running_batch.reqs)
|
||||
+ for req in self.waiting_queue:
|
||||
+ req.sampling_params.max_new_tokens = 0
|
||||
+ for req in self.running_batch.reqs:
|
||||
+ req.sampling_params.max_new_tokens = len(req.output_ids)
|
||||
+ logger.info(f"Interrupt {num} requests.")
|
||||
+ return InterruptAllReqOutput(num)
|
||||
+
|
||||
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
|
||||
"""In-place update of the weights from disk."""
|
||||
success, message = self.tp_worker.update_weights_from_disk(recv_req)
|
||||
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
index b646fae1..c668728b 100644
|
||||
--- a/python/sglang/srt/managers/tokenizer_manager.py
|
||||
+++ b/python/sglang/srt/managers/tokenizer_manager.py
|
||||
@@ -80,6 +80,8 @@ from sglang.srt.managers.io_struct import (
|
||||
HealthCheckOutput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
+ InterruptAllReqInput,
|
||||
+ InterruptAllReqOutput,
|
||||
OpenSessionReqInput,
|
||||
OpenSessionReqOutput,
|
||||
ProfileReq,
|
||||
@@ -279,6 +281,9 @@ class TokenizerManager:
|
||||
self.slow_down_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
+ self.interrupt_requests_communicator = _Communicator(
|
||||
+ self.send_to_scheduler, server_args.dp_size
|
||||
+ )
|
||||
self.flush_cache_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -309,6 +314,10 @@ class TokenizerManager:
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
self._handle_update_weights_from_disk_req_output,
|
||||
),
|
||||
+ (
|
||||
+ InterruptAllReqOutput,
|
||||
+ self.interrupt_requests_communicator.handle_recv,
|
||||
+ ),
|
||||
(
|
||||
InitWeightsUpdateGroupReqOutput,
|
||||
self.init_weights_update_group_communicator.handle_recv,
|
||||
@@ -799,6 +808,13 @@ class TokenizerManager:
|
||||
) -> Tuple[bool, str]:
|
||||
self.auto_create_handle_loop()
|
||||
|
||||
+ if obj.allow_interrupt:
|
||||
+ num_interrupted_requests = await self.interrupt_all_requests(
|
||||
+ InterruptAllReqInput()
|
||||
+ )
|
||||
+ # Set a break point to wait for the interrupt to finish
|
||||
+ await asyncio.sleep(0.1)
|
||||
+
|
||||
# default the load format to the server_args
|
||||
if obj.load_format is None:
|
||||
obj.load_format = self.server_args.load_format
|
||||
@@ -808,7 +824,12 @@ class TokenizerManager:
|
||||
# Hold the lock if it is not async. This means that weight sync
|
||||
# cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
- return await self._wait_for_model_update_from_disk(obj)
|
||||
+ success, message, n_paused = (
|
||||
+ await self._wait_for_model_update_from_disk(obj)
|
||||
+ )
|
||||
+ if obj.allow_interrupt:
|
||||
+ return success, message, num_interrupted_requests
|
||||
+ return success, message, n_paused
|
||||
|
||||
async def _wait_for_model_update_from_disk(
|
||||
self, obj: UpdateWeightFromDiskReqInput
|
||||
@@ -881,6 +902,18 @@ class TokenizerManager:
|
||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
+ async def interrupt_all_requests(
|
||||
+ self,
|
||||
+ obj: InterruptAllReqInput,
|
||||
+ request: Optional[fastapi.Request] = None,
|
||||
+ ) -> Tuple[bool, str]:
|
||||
+ self.auto_create_handle_loop()
|
||||
+ result = await self.interrupt_requests_communicator(obj)
|
||||
+ if self.server_args.dp_size == 1:
|
||||
+ return result[0].num_interrupted_requests
|
||||
+ else:
|
||||
+ return [r.num_interrupted_requests for r in result]
|
||||
+
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
):
|
|
@ -97,8 +97,7 @@ class PromptOnlyDatasetConfig:
|
|||
class ModelFamily:
|
||||
"""Identifier for HuggingFace model types (e.g., llama, gpt2).
|
||||
|
||||
Used for model registration and allocation. The size parameter is specifically
|
||||
relevant for the 'search' allocation mode.
|
||||
Used for model registration and allocation.
|
||||
"""
|
||||
|
||||
_class: str = field(
|
||||
|
@ -107,12 +106,6 @@ class ModelFamily:
|
|||
"`realhf/api/from_hf` for supported models.",
|
||||
}
|
||||
)
|
||||
size: int = field(
|
||||
default=0,
|
||||
metadata={
|
||||
"help": "Model size parameter. Only used by 'search' allocation mode, ignored otherwise",
|
||||
},
|
||||
)
|
||||
is_critic: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
|
@ -121,8 +114,8 @@ class ModelFamily:
|
|||
)
|
||||
|
||||
def __repr__(self):
|
||||
"""Returns formatted string representation: '{class}-{size}[-critic]'."""
|
||||
s = f"{self._class}-{self.size}"
|
||||
"""Returns formatted string representation: '{class}[-critic]'."""
|
||||
s = f"{self._class}"
|
||||
if self.is_critic:
|
||||
s += "-critic"
|
||||
return s
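
After this change the identifier no longer embeds a size, so two checkpoints of the same family collapse to the same string. A quick illustrative check, assuming ModelFamily is constructed positionally as elsewhere in this diff (ModelFamily("llama", False)) and has no other required fields:

from realhf.api.cli_args import ModelFamily

# The size no longer appears; only the class name and the critic suffix remain.
assert repr(ModelFamily("qwen3")) == "qwen3"
assert repr(ModelFamily("qwen3", True)) == "qwen3-critic"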
|
||||
|
@ -136,7 +129,7 @@ class ParallelismConfig:
|
|||
Sequence parallelism is only used in combination with tensor-model parallelism.
|
||||
"""
|
||||
|
||||
model_parallel_size: int = field(
|
||||
tensor_parallel_size: int = field(
|
||||
default=1, metadata={"help": "Size of tensor-model parallelism"}
|
||||
)
|
||||
pipeline_parallel_size: int = field(
|
||||
|
@ -155,7 +148,7 @@ class ParallelismConfig:
|
|||
def __str__(self):
|
||||
"""Returns compact string representation: 'Parallel(mp=X,pp=Y,dp=Z)'."""
|
||||
return (
|
||||
f"Parallel(mp={self.model_parallel_size},"
|
||||
f"Parallel(mp={self.tensor_parallel_size},"
|
||||
f"pp={self.pipeline_parallel_size},"
|
||||
f"dp={self.data_parallel_size})"
|
||||
)
|
||||
|
@ -168,7 +161,7 @@ class ParallelismConfig:
|
|||
Implemented as static method to avoid OmegaConf compatibility issues.
|
||||
"""
|
||||
return (
|
||||
(this.model_parallel_size == other.model_parallel_size)
|
||||
(this.tensor_parallel_size == other.tensor_parallel_size)
|
||||
and (this.pipeline_parallel_size == other.pipeline_parallel_size)
|
||||
and (this.data_parallel_size == other.data_parallel_size)
|
||||
)
|
||||
|
@ -186,7 +179,7 @@ class OptimizerConfig:
|
|||
default="adam",
|
||||
metadata={"help": "Optimizer type", "choices": ["adam", "empty"]},
|
||||
)
|
||||
lr: float = field(default=1e-5, metadata={"help": "Learning rate"})
|
||||
lr: float = field(default=2e-5, metadata={"help": "Learning rate"})
|
||||
weight_decay: float = field(default=0.05, metadata={"help": "Weight decay"})
|
||||
beta1: float = field(default=0.9, metadata={"help": "Adam beta1 parameter"})
|
||||
beta2: float = field(default=0.95, metadata={"help": "Adam beta2 parameter"})
|
||||
|
@ -198,14 +191,14 @@ class OptimizerConfig:
|
|||
},
|
||||
)
|
||||
lr_scheduler_type: str = field(
|
||||
default="cosine",
|
||||
default="constant",
|
||||
metadata={
|
||||
"help": "Learning rate scheduler type",
|
||||
"choices": ["linear", "cosine", "constant"],
|
||||
},
|
||||
)
|
||||
warmup_steps_proportion: float = field(
|
||||
default=0.02,
|
||||
default=0.001,
|
||||
metadata={
|
||||
"help": "Proportion of training steps for warmup",
|
||||
},
|
||||
|
@ -237,6 +230,7 @@ class vLLMConfig:
|
|||
"""
|
||||
|
||||
max_num_seqs: int = 256
|
||||
dtype: str = "float16"
|
||||
kv_cache_type: str = "auto"
|
||||
num_scheduler_steps: int = 1
|
||||
multi_step_stream_outputs: bool = True
|
||||
|
@ -278,7 +272,6 @@ class SGLangConfig:
|
|||
enable_nccl_nvls: bool = False
|
||||
disable_outlines_disk_cache: bool = False
|
||||
disable_custom_all_reduce: bool = False
|
||||
disable_mla: bool = False
|
||||
disable_overlap_schedule: bool = False
|
||||
enable_mixed_chunk: bool = False
|
||||
enable_dp_attention: bool = False
|
||||
|
@ -296,7 +289,7 @@ class SGLangConfig:
|
|||
enable_memory_saver: bool = False
|
||||
allow_auto_truncate: bool = False
|
||||
# NOTE: to avoid the illegal memory access error
|
||||
attention_backend: Optional[str] = "triton"
|
||||
attention_backend: Optional[str] = "flashinfer"
|
||||
sampling_backend: Optional[str] = None
|
||||
context_length: Optional[int] = 32768
|
||||
mem_fraction_static: Optional[float] = 0.9
|
||||
|
@ -309,15 +302,19 @@ class SGLangConfig:
|
|||
schedule_conservativeness: float = 1.0
|
||||
cpu_offload_gb: int = 0
|
||||
hybrid_train: bool = False
|
||||
dtype: str = "float16"
|
||||
kv_cache_dtype: str = "auto"
|
||||
|
||||
# logging
|
||||
log_level: str = "info"
|
||||
log_level: str = "warning"
|
||||
log_level_http: Optional[str] = "warning"
|
||||
log_requests: bool = False
|
||||
log_requests_level: int = 0
|
||||
show_time_cost: bool = False
|
||||
enable_metrics: bool = True # Exports Prometheus-like metrics
|
||||
decode_log_interval: int = 1000 # How often (in tokens) to log decode progress.
|
||||
# The interval (in decoding iterations) to log throughput
|
||||
# and update prometheus metrics
|
||||
decode_log_interval: int = 1
|
||||
|
||||
# Use staticmethod to make OmegaConf happy.
|
||||
@staticmethod
|
||||
|
@ -327,6 +324,7 @@ class SGLangConfig:
|
|||
tp_size,
|
||||
server_index,
|
||||
base_gpu_id,
|
||||
dist_init_addr: Optional[str] = None,
|
||||
):
|
||||
from realhf.base import constants, network, pkg_version, seeding
|
||||
from realhf.experiments.common.utils import asdict as conf_as_dict
|
||||
|
@ -345,7 +343,6 @@ class SGLangConfig:
|
|||
tokenizer_mode="auto",
|
||||
load_format="auto",
|
||||
trust_remote_code=True,
|
||||
kv_cache_dtype="auto",
|
||||
device="cuda",
|
||||
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{model_path}",
|
||||
is_embedding=False,
|
||||
|
@ -365,6 +362,7 @@ class SGLangConfig:
|
|||
ep_size=1, # TODO: check
|
||||
nnodes=1,
|
||||
node_rank=0,
|
||||
dist_init_addr=dist_init_addr,
|
||||
**args,
|
||||
)
|
||||
|
||||
|
@ -385,6 +383,10 @@ class SGLangConfig:
|
|||
if v is True:
|
||||
flags.append(f"--{k.replace('_','-')} ")
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
values = " ".join(map(str, v))
|
||||
flags.append(f"--{k.replace('_','-')} {values}")
|
||||
continue
|
||||
flags.append(f"--{k.replace('_','-')} {v}")
|
||||
flags = " ".join(flags)
|
||||
return f"python3 -m sglang.launch_server {flags}"
|
||||
|
@ -444,7 +446,7 @@ class ModelTrainEvalConfig:
|
|||
|
||||
# Model Architecture Configuration
|
||||
type: ModelFamily = field(
|
||||
default=ModelFamily("llama", 7, False),
|
||||
default=ModelFamily("llama", False),
|
||||
metadata={"help": "Model family specification"},
|
||||
)
|
||||
path: str = field(default="", metadata={"help": "Path to HuggingFace checkpoint"})
|
||||
|
@ -679,13 +681,13 @@ class PPOHyperparameters:
|
|||
value_norm_eps: float = field(
|
||||
default=1e-5, metadata={"help": "Epsilon term for numerical stability"}
|
||||
)
|
||||
|
||||
# Experimental Features
|
||||
recompute_logprob: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Recompute log probabilities after generation. Used mainly for debugging purposes"
|
||||
},
|
||||
metadata={"help": "Recompute logp and replace the logp returned by inference."},
|
||||
)
|
||||
use_decoupled_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use the decoupled loss. recompute_logprob must be True."},
|
||||
)
|
||||
|
||||
|
||||
|
@ -772,6 +774,13 @@ class ExperimentSaveEvalControl:
|
|||
"For benchmarking purposes only. None indicates normal training."
|
||||
},
|
||||
)
|
||||
benchmark_n_seqs: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Terminate training after consuming this number of samples. "
|
||||
"For benchmarking purposes only. None indicates normal training."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -847,7 +856,7 @@ class BaseExperimentConfig:
|
|||
|
||||
Note:
|
||||
- Recovery modes: auto, fault, resume, disabled
|
||||
- Allocation modes: manual, search, heuristic, or pattern-based
|
||||
- Allocation modes: manual, heuristic, or pattern-based
|
||||
"""
|
||||
|
||||
experiment_name: str = field(
|
||||
|
@ -919,13 +928,9 @@ class BaseExperimentConfig:
|
|||
default="",
|
||||
metadata={
|
||||
"help": "GPU parallel strategy allocation mode. "
|
||||
"Options: manual/search/heuristic or pattern-based."
|
||||
"Options: manual/heuristic or pattern-based."
|
||||
},
|
||||
)
|
||||
allocation_use_cache: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use allocation search cache (search mode only)."},
|
||||
)
|
||||
n_nodes: int = field(
|
||||
default=1, metadata={"help": "Number of nodes for experiment."}
|
||||
)
|
||||
|
@ -998,9 +1003,17 @@ class BaseExperimentConfig:
|
|||
|
||||
@dataclass
|
||||
class AsyncRLOptions:
|
||||
schedule_policy: str = field(
|
||||
default="round_robin",
|
||||
metadata={
|
||||
"help": "The request schedule policy during generation. Available options: [round_robin]."
|
||||
},
|
||||
)
|
||||
new_tokens_per_chunk: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "The lenght of chunked generation."},
|
||||
default=int(1e10),
|
||||
metadata={
|
||||
"help": "The length of chunked generation. Only valid if inference can't be interrupted."
|
||||
},
|
||||
)
|
||||
max_head_offpolicyness: int = field(
|
||||
default=0,
|
||||
|
@ -1013,9 +1026,11 @@ class AsyncRLOptions:
|
|||
"help": "Number of rollout workers. None defaults to train world size."
|
||||
},
|
||||
)
|
||||
max_concurrent_rollouts: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "Max concurrent rollout jobs in each worker."},
|
||||
max_concurrent_rollouts: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Max concurrent rollouts globally. Defaults to train batch size."
|
||||
},
|
||||
)
|
||||
flush_request_timeout: int = field(
|
||||
default=120,
|
||||
|
@ -1225,6 +1240,12 @@ class PPOMATHExperimentOptions:
|
|||
},
|
||||
)
|
||||
|
||||
# testing only
|
||||
no_training: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Run without training. Test-only."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MathCodeEvalOptions:
|
||||
|
|
|
@ -100,8 +100,8 @@ class ModelShardID:
|
|||
:type model_name: ModelName
|
||||
:param dp_rank: The data parallel rank.
|
||||
:type dp_rank: int
|
||||
:param mp_rank: The tensor-model parallel rank.
|
||||
:type mp_rank: int
|
||||
:param tp_rank: The tensor-model parallel rank.
|
||||
:type tp_rank: int
|
||||
:param pp_rank: The pipeline-model parallel rank.
|
||||
:type pp_rank: int
|
||||
:param topo: The 3D parallelism topology of this model.
|
||||
|
@ -110,22 +110,22 @@ class ModelShardID:
|
|||
|
||||
model_name: ModelName
|
||||
dp_rank: int
|
||||
mp_rank: int
|
||||
tp_rank: int
|
||||
pp_rank: int
|
||||
topo: topology.ProcessTopology
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.dp_rank >= 0 and self.mp_rank >= 0 and self.pp_rank >= 0
|
||||
assert self.dp_rank >= 0 and self.tp_rank >= 0 and self.pp_rank >= 0
|
||||
if "@" in self.model_name.role:
|
||||
raise ValueError("model_name cannot contain @")
|
||||
assert self.dp_rank < self.topo.get_dim("data")
|
||||
assert self.mp_rank < self.topo.get_dim("model")
|
||||
assert self.tp_rank < self.topo.get_dim("tensor")
|
||||
assert self.pp_rank < self.topo.get_dim("pipe")
|
||||
|
||||
@property
|
||||
def parallelism_rank(self):
|
||||
return self.topo.get_rank(
|
||||
data=self.dp_rank, model=self.mp_rank, pipe=self.pp_rank
|
||||
data=self.dp_rank, tensor=self.tp_rank, pipe=self.pp_rank
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -134,14 +134,14 @@ class ModelShardID:
|
|||
return cls(
|
||||
model_name=model_name,
|
||||
dp_rank=c.data,
|
||||
mp_rank=c.model,
|
||||
tp_rank=c.tensor,
|
||||
pp_rank=c.pipe,
|
||||
topo=topo,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
n = cluster.spec.suffix_n_digits
|
||||
return f"{self.model_name}@pp{self.pp_rank:0{n}d}@mp{self.mp_rank:0{n}d}@dp{self.dp_rank:0{n}d}"
|
||||
return f"{self.model_name}@pp{self.pp_rank:0{n}d}@tp{self.tp_rank:0{n}d}@dp{self.dp_rank:0{n}d}"
|
||||
|
||||
def __hash__(self):
|
||||
return hash(str(self))
|
||||
|
@ -152,7 +152,7 @@ class ModelShardID:
|
|||
return (
|
||||
self.model_name == other.model_name
|
||||
and self.dp_rank == other.dp_rank
|
||||
and self.mp_rank == other.mp_rank
|
||||
and self.tp_rank == other.tp_rank
|
||||
and self.pp_rank == other.pp_rank
|
||||
)
|
||||
return False
|
||||
|
|
|
@ -547,6 +547,7 @@ class SequenceSample:
|
|||
return [[seqlen] for seqlen in seqlens]
|
||||
elif key in [
|
||||
"packed_logprobs",
|
||||
"prox_logp",
|
||||
"logprobs",
|
||||
"packed_ref_logprobs",
|
||||
"ref_logprobs",
|
||||
|
|
|
@ -10,7 +10,6 @@ import matplotlib.pyplot as plt
|
|||
import networkx as nx
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import ModelFamily
|
||||
from realhf.api.core.config import (
|
||||
ModelInterfaceAbstraction,
|
||||
ModelInterfaceType,
|
||||
|
@ -94,13 +93,6 @@ class MFCDef:
|
|||
:type min_n_seqs_per_pass: int
|
||||
:param log_return_value: Whether to log the return value of the interface implementation.
|
||||
:type log_return_value: bool
|
||||
:param model_type: The specification of the LLM, e.g., LLaMA-7B. Used by the profiler and
|
||||
search engine to produce an optimal execution plan. Can be omitted if the search engine
|
||||
is not used.
|
||||
:type model_type: Optional[ModelFamily]
|
||||
:param model_path: The path to the model file. Used to get the config for the search engine.
|
||||
Can be omitted if the search engine is not used.
|
||||
:type model_path: Optional[str]
|
||||
"""
|
||||
|
||||
# The unique identifier of this model function call.
|
||||
|
@ -126,10 +118,6 @@ class MFCDef:
|
|||
min_n_seqs_per_pass: int | float = 1
|
||||
log_return_value: bool = False
|
||||
|
||||
# Only used by search.
|
||||
model_type: Optional[Any | ModelFamily] = None
|
||||
model_path: Optional[str] = None
|
||||
|
||||
# Reserved dataclasses.fields. Should not be set by the user.
|
||||
_G: nx.DiGraph = None
|
||||
_pre_hooks: List[RPCHook] = dataclasses.field(default_factory=lambda: [])
|
||||
|
|
|
@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Hashable, List, Literal, Optional, Tuple
|
|||
import aiohttp
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.utils.data
|
||||
import transformers
|
||||
|
||||
|
@ -24,6 +25,7 @@ from realhf.api.core.config import (
|
|||
ModelWrapperAbstraction,
|
||||
)
|
||||
from realhf.api.core.data_api import MicroBatchSpec, SequenceSample, load_hf_tokenizer
|
||||
from realhf.base.datapack import flat2d
|
||||
from realhf.base.recover import StepInfo
|
||||
|
||||
logger = logging.getLogger("model_api")
|
||||
|
@ -37,15 +39,19 @@ class ZeroTotalLossWeightException(Exception):
|
|||
class GenRespMeta:
|
||||
qid: str
|
||||
accepted: bool
|
||||
n_tokens: int
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class GenReqMeta:
|
||||
## Meta info used to schedule the request. ##
|
||||
qid: Hashable
|
||||
prompt_len: int
|
||||
group_size: int
|
||||
new_token_budget: int
|
||||
predicted_new_tokens: int | None
|
||||
previous_server_url: str = ""
|
||||
previous_version: int = -1
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -120,6 +126,7 @@ class APIGenerateOutput:
|
|||
|
||||
@staticmethod
|
||||
def concat(outputs: List["APIGenerateOutput"]):
|
||||
assert len(set([o.qid for o in outputs])) == 1
|
||||
return APIGenerateOutput(
|
||||
qid=outputs[0].qid,
|
||||
prompt_ids=outputs[0].prompt_ids,
|
||||
|
@ -436,6 +443,8 @@ class ReaLModelConfig:
|
|||
rotary_special_impl: Optional[str] = None
|
||||
# for gemma
|
||||
normalize_embed: bool = False
|
||||
# for qwen3
|
||||
qk_layernorm: bool = False
|
||||
# for opt, it's 2
|
||||
abs_position_embedding_offset: int = 0
|
||||
do_layernorm_before: bool = True
|
||||
|
@ -798,7 +807,7 @@ class ModelInterface(abc.ABC):
|
|||
model: Model,
|
||||
data: SequenceSample,
|
||||
mb_spec: MicroBatchSpec,
|
||||
) -> Dict:
|
||||
) -> Dict | List[Dict]:
|
||||
raise NotImplementedError()
|
||||
|
||||
# Mock methods for creating data and profiling an individual MFC.
|
||||
|
@ -860,7 +869,17 @@ class NullInterface(ModelInterface):
|
|||
|
||||
def train_step(
|
||||
self, model: Model, data: SequenceSample, mb_spec: MicroBatchSpec
|
||||
) -> Dict:
|
||||
) -> Dict | List[Dict]:
|
||||
from realhf.base import constants
|
||||
|
||||
n_tokens = sum(flat2d(data.seqlens[data._get_split_key()]))
|
||||
n_tokens = torch.tensor(
|
||||
n_tokens, dtype=torch.long, device=constants.current_device()
|
||||
)
|
||||
dist.all_reduce(n_tokens, group=constants.data_parallel_group())
|
||||
if constants.parallelism_rank() == 0:
|
||||
logger.info(f"Number of tokens in NullInterface training: {int(n_tokens)}")
|
||||
model.inc_version()
|
||||
return {}
|
||||
|
||||
def save(self, model: Model, save_dir: str):
|
||||
|
|
|
@ -462,8 +462,8 @@ class ExperimentConfig:
|
|||
)
|
||||
self_topo = model_topos[rpc.model_name]
|
||||
if (
|
||||
self_topo.get_dim("model") % other_topo.get_dim("model") != 0
|
||||
and other_topo.get_dim("model") % self_topo.get_dim("model") != 0
|
||||
self_topo.get_dim("tensor") % other_topo.get_dim("tensor") != 0
|
||||
and other_topo.get_dim("tensor") % self_topo.get_dim("tensor") != 0
|
||||
):
|
||||
raise ValueError(
|
||||
"To synchronize parameters between two models, "
|
||||
|
|
|
@ -0,0 +1,252 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
from typing import *
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from realhf.api.core.model_api import ReaLModelConfig, register_hf_family
|
||||
from realhf.base.testing import (
|
||||
TESTING_MODEL_HEAD_DIM,
|
||||
TESTING_MODEL_HIDDEN_SIZE,
|
||||
TESTING_MODEL_INTERMEDIATE_SIZE,
|
||||
TESTING_MODEL_N_HEADS,
|
||||
TESTING_MODEL_N_LAYERS,
|
||||
TESTING_MODEL_N_POSITIONS,
|
||||
TESTING_MODEL_VOCAB_SIZE,
|
||||
)
|
||||
|
||||
from .llama import (
|
||||
convert_state_dict_llama,
|
||||
llama_embedding_layer_names,
|
||||
llama_output_head_param_name,
|
||||
to_llama_state_dict,
|
||||
)
|
||||
|
||||
|
||||
class Qwen3Config(PretrainedConfig):
|
||||
|
||||
model_type = "qwen3"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=151936,
|
||||
hidden_size=4096,
|
||||
intermediate_size=22016,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=32,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = (
|
||||
sliding_window # we check `use_sliding_window` in the modeling code
|
||||
)
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def convert_config_qwen3(
|
||||
hf_config: Qwen3Config,
|
||||
) -> ReaLModelConfig:
|
||||
return ReaLModelConfig(
|
||||
n_layers=hf_config.num_hidden_layers,
|
||||
n_kv_heads=hf_config.num_key_value_heads,
|
||||
hidden_dim=hf_config.hidden_size,
|
||||
n_q_heads=hf_config.num_attention_heads,
|
||||
head_dim=getattr(
|
||||
hf_config,
|
||||
"head_dim",
|
||||
hf_config.hidden_size // hf_config.num_attention_heads,
|
||||
),
|
||||
intermediate_dim=hf_config.intermediate_size,
|
||||
vocab_size=hf_config.vocab_size,
|
||||
n_positions=hf_config.max_position_embeddings,
|
||||
embd_pdrop=0.0,
|
||||
attn_pdrop=(
|
||||
hf_config.attention_dropout
|
||||
if hasattr(hf_config, "attention_dropout")
|
||||
else 0.1
|
||||
),
|
||||
layer_norm_epsilon=hf_config.rms_norm_eps,
|
||||
activation_function=hf_config.hidden_act,
|
||||
use_attention_bias=False,
|
||||
use_attn_proj_bias=False,
|
||||
scale_attn_by_inverse_layer_idx=False,
|
||||
layer_norm_type="rms",
|
||||
qk_layernorm=True,
|
||||
mlp_type="llama",
|
||||
apply_rotary=True,
|
||||
rotary_base=hf_config.rope_theta,
|
||||
rotary_interleaved=False,
|
||||
tied_embedding=hf_config.tie_word_embeddings,
|
||||
)
|
||||
|
||||
|
||||
def convert_config_back_qwen3(
|
||||
config: ReaLModelConfig,
|
||||
) -> Qwen3Config:
|
||||
return Qwen3Config(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_dim,
|
||||
intermediate_size=config.intermediate_dim,
|
||||
num_hidden_layers=config.n_layers,
|
||||
num_key_value_heads=config.n_kv_heads,
|
||||
num_attention_heads=config.n_q_heads,
|
||||
head_dim=config.head_dim,
|
||||
max_position_embeddings=config.n_positions,
|
||||
rms_norm_eps=config.layer_norm_epsilon,
|
||||
hidden_act=config.activation_function,
|
||||
attention_dropout=config.attn_pdrop,
|
||||
rope_theta=config.rotary_base,
|
||||
architectures=["Qwen3ForCausalLM"],
|
||||
tie_word_embeddings=config.tied_embedding,
|
||||
)
|
||||
|
||||
|
||||
def qwen3_config_maker():
|
||||
hf_config = Qwen3Config(
|
||||
vocab_size=TESTING_MODEL_VOCAB_SIZE,
|
||||
max_position_embeddings=TESTING_MODEL_N_POSITIONS,
|
||||
hidden_size=TESTING_MODEL_HIDDEN_SIZE,
|
||||
intermediate_size=TESTING_MODEL_INTERMEDIATE_SIZE,
|
||||
num_hidden_layers=TESTING_MODEL_N_LAYERS,
|
||||
num_attention_heads=TESTING_MODEL_N_HEADS,
|
||||
head_dim=TESTING_MODEL_HEAD_DIM,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
rms_norm_eps=1e-5,
|
||||
)
|
||||
return convert_config_qwen3(hf_config)
|
||||
|
||||
|
||||
def convert_state_dict_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict:
|
||||
llama_state_dict = convert_state_dict_llama(state_dict, config)
|
||||
# model.layers.0.self_attn.k_norm.weight -> 1.attn.k_ln.weight
|
||||
new_state_dict = {}
|
||||
for k, v in llama_state_dict.items():
|
||||
if "k_norm" in k:
|
||||
k = k.replace("k_norm", "k_ln")
|
||||
if "q_norm" in k:
|
||||
k = k.replace("q_norm", "q_ln")
|
||||
new_state_dict[k] = v
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def convert_state_dict_back_qwen3(state_dict: Dict, config: ReaLModelConfig) -> Dict:
|
||||
new_sd = to_llama_state_dict(state_dict, config)
|
||||
layer_indices = list(set([int(k.split(".")[0]) for k in state_dict.keys()]))
|
||||
for i in layer_indices:
|
||||
if i == 0 or i == config.n_layers + 1:
|
||||
continue
|
||||
new_sd[f"model.layers.{i - 1}.self_attn.k_norm.weight"] = state_dict[
|
||||
f"{i}.attn.k_ln.weight"
|
||||
]
|
||||
new_sd[f"model.layers.{i - 1}.self_attn.q_norm.weight"] = state_dict[
|
||||
f"{i}.attn.q_ln.weight"
|
||||
]
|
||||
return new_sd
|
||||
|
||||
|
||||
def qwen3_transformer_block_param_name(config: ReaLModelConfig, idx: int) -> List[str]:
    names = []
    for k in ["weight", "bias"]:
        names += [
            f"model.layers.{idx}.input_layernorm.{k}",
            f"model.layers.{idx}.mlp.down_proj.{k}",
            f"model.layers.{idx}.mlp.gate_proj.{k}",
            f"model.layers.{idx}.mlp.up_proj.{k}",
            f"model.layers.{idx}.post_attention_layernorm.{k}",
            f"model.layers.{idx}.self_attn.k_proj.{k}",
            f"model.layers.{idx}.self_attn.o_proj.{k}",
            f"model.layers.{idx}.self_attn.q_proj.{k}",
            # f"model.layers.{idx}.self_attn.rotary_emb.inv_freq",
            f"model.layers.{idx}.self_attn.v_proj.{k}",
        ]
        if idx == config.n_layers - 1:
            names += [f"model.norm.{k}"]
    # Qwen3
    if config.qk_layernorm:
        names += [
            f"model.layers.{idx}.self_attn.q_norm.weight",
            f"model.layers.{idx}.self_attn.k_norm.weight",
        ]
    return names


register_hf_family(
    name="qwen3",
    hf_cls_name="Qwen3ForCausalLM",
    config_from_hf_converter=convert_config_qwen3,
    config_to_hf_converter=convert_config_back_qwen3,
    sd_from_hf_converter=convert_state_dict_qwen3,
    sd_to_hf_converter=convert_state_dict_back_qwen3,
    embedding_param_names=llama_embedding_layer_names,
    tblock_param_names=qwen3_transformer_block_param_name,
    head_param_names=llama_output_head_param_name,
    real_config_maker=qwen3_config_maker,
)

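A minimal round-trip sketch of the converters registered above, assuming the same module scope and that a `Qwen3Config` class is available from transformers (>= 4.51); it is illustrative only and not part of the patch.

# Hedged usage sketch, not part of the patch: round-trip a tiny Qwen3 config
# through the converters defined above. All constant values are illustrative.
from transformers import Qwen3Config  # assumption: transformers >= 4.51

tiny_hf = Qwen3Config(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
    max_position_embeddings=256,
)
real_cfg = convert_config_qwen3(tiny_hf)       # HF config -> ReaLModelConfig
hf_back = convert_config_back_qwen3(real_cfg)  # ReaLModelConfig -> HF config
# The round trip is expected to preserve the shape-defining fields.
print(hf_back.num_key_value_heads, hf_back.head_dim, hf_back.tie_word_embeddings)
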
@@ -224,18 +224,18 @@ def find_parallel_strategies(
 ) -> List[ParallelismConfig]:
     n_gpus = np.sum(device_mesh.mapping)
     res = []
-    for num_mp in [1, 2, 4, 8]:
-        if n_gpus >= num_mp:
-            assert n_gpus % num_mp == 0
-            num_dp_pp = n_gpus // num_mp
+    for num_tp in [1, 2, 4, 8]:
+        if n_gpus >= num_tp:
+            assert n_gpus % num_tp == 0
+            num_dp_pp = n_gpus // num_tp
             num_pp = 1
             while num_pp <= num_dp_pp:
-                num_dp_mp = n_gpus // num_pp
+                num_dp_tp = n_gpus // num_pp
                 valid = (
-                    num_dp_mp in [1, 2, 4, 8] or num_dp_mp % 8 == 0
+                    num_dp_tp in [1, 2, 4, 8] or num_dp_tp % 8 == 0
                 ) and num_dp_pp % num_pp == 0
                 if valid:
-                    res.append(ParallelismConfig(num_pp, num_mp, num_dp_pp // num_pp))
+                    res.append(ParallelismConfig(num_pp, num_tp, num_dp_pp // num_pp))
                 num_pp += 1
     return res

@@ -248,7 +248,7 @@ class RPCAllocation:

     def __post_init__(self):
         world_size = (
-            self.parallel.model_parallel_size
+            self.parallel.tensor_parallel_size
             * self.parallel.pipeline_parallel_size
             * self.parallel.data_parallel_size
         )

@ -8,12 +8,10 @@ import functools
|
|||
import inspect
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable
|
||||
|
||||
import hydra
|
||||
import omegaconf
|
||||
import yaml
|
||||
from hydra.core.config_store import ConfigStore
|
||||
from omegaconf import MISSING, OmegaConf
|
||||
|
||||
|
@ -29,6 +27,9 @@ def kind_reminder(config_name, logger, args):
|
|||
logger.info(
|
||||
f"Logs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name)}"
|
||||
)
|
||||
logger.info(
|
||||
f"Experiment configs will be dumped to {os.path.join(LOG_ROOT, args.experiment_name, args.trial_name, 'config.yaml')}"
|
||||
)
|
||||
logger.info(
|
||||
f"Model checkpoints will be saved to {os.path.join(MODEL_SAVE_ROOT, args.experiment_name, args.trial_name)}"
|
||||
)
|
||||
|
@ -69,7 +70,7 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
|
|||
|
||||
logger = logging.getLogger("quickstart", "colored")
|
||||
|
||||
print_runtime_helper(OmegaConf.to_object(args))
|
||||
# print_runtime_helper(OmegaConf.to_object(args))
|
||||
|
||||
exp_name = args.experiment_name
|
||||
if args.trial_name == MISSING:
|
||||
|
@ -80,6 +81,17 @@ def register_quickstart_exp(config_name: str, exp_cls: Callable):
|
|||
trial_name = args.trial_name
|
||||
from realhf.apps.main import main_start, main_stop
|
||||
|
||||
config_save_path = os.path.join(
|
||||
LOG_ROOT, args.experiment_name, args.trial_name, "config.yaml"
|
||||
)
|
||||
with open(config_save_path, "w") as f:
|
||||
yaml.dump(
|
||||
dataclasses.asdict(OmegaConf.to_object(args)),
|
||||
f,
|
||||
default_flow_style=False,
|
||||
sort_keys=False,
|
||||
)
|
||||
|
||||
kind_reminder(config_name, logger, args)
|
||||
|
||||
exp_fn = functools.partial(exp_cls, **args)
|
||||
|
|
|
@ -87,7 +87,7 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
|
|||
job_group_id = str(uuid.uuid4())
|
||||
logger.info(f"AReaL Version: {get_full_version_with_dirty_description()}")
|
||||
logger.info(f"AReaL Job Group ID: {job_group_id}")
|
||||
logger.info(f"AReaL Job Group Index: {recover_count}")
|
||||
logger.info(f"AReaL Job Group Index (recover count): {recover_count}")
|
||||
if recover_count == 0:
|
||||
constants.set_experiment_trial_names(args.experiment_name, args.trial_name)
|
||||
experiment = config_package.make_experiment(args.experiment_name)
|
||||
|
@ -110,10 +110,6 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
|
|||
assert (
|
||||
args.recover_mode == "disabled"
|
||||
), "Recover mode is not supported for local runs!"
|
||||
# Use search cache for recover runs
|
||||
force_allocation_use_cache = (
|
||||
recover_count > 1 or args.recover_mode == "resume"
|
||||
) and args.allocation_mode == "search"
|
||||
# handle args
|
||||
args.ignore_worker_error = (
|
||||
args.ignore_worker_error and args.recover_mode == "disabled"
|
||||
|
@ -174,12 +170,8 @@ def main_start(args, job_group_id: str = "", recover_count: int = 0):
|
|||
)
|
||||
for k, v in BASE_ENVIRONS.items():
|
||||
os.environ[k] = v
|
||||
os.environ["REAL_IS_REMOTE"] = "0" if not force_allocation_use_cache else "1"
|
||||
|
||||
# setup experiments
|
||||
if args.allocation_mode == "search":
|
||||
experiment._search()
|
||||
|
||||
sched = sched_client.make(
|
||||
mode=scheduler_mode(args.mode),
|
||||
expr_name=expr_name,
|
||||
|
@ -324,80 +316,6 @@ def main_find_config(args):
|
|||
print(exp_name)
|
||||
|
||||
|
||||
def main_profile_layers(args):
|
||||
from realhf.api.cli_args import ModelFamily
|
||||
|
||||
_main_profile_layers(
|
||||
ModelFamily(args.model_class, args.model_size, args.is_critic),
|
||||
args.model_path,
|
||||
)
|
||||
|
||||
|
||||
def _main_profile_layers(model_family, model_path):
|
||||
from realhf.api.cli_args import ModelFamily
|
||||
from realhf.base.slurm_utils import check_slurm_availability
|
||||
from realhf.base.testing import clear_name_resolve
|
||||
|
||||
expr_name = trial_name = "profile"
|
||||
cmd = (
|
||||
f"python3 -m realhf.apps.profile_layers --expr_name {expr_name} --trial_name {trial_name} "
|
||||
f"--model_path {model_path} --model_name {model_family} "
|
||||
)
|
||||
|
||||
if check_slurm_availability():
|
||||
if not os.environ.get("CLUSTER_SPEC_PATH", ""):
|
||||
raise ValueError(
|
||||
"Environment variable CLUSTER_SPEC_PATH must be set for slurm mode! "
|
||||
"See example/cluster_config.json for a template."
|
||||
)
|
||||
BASE_ENVIRONS = constants.get_env_vars(
|
||||
REAL_MODE="slurm",
|
||||
CLUSTER_SPEC_PATH=os.environ.get("CLUSTER_SPEC_PATH", ""),
|
||||
)
|
||||
clear_name_resolve(expr_name, trial_name)
|
||||
sched = sched_client.make(
|
||||
mode="slurm", expr_name=expr_name, trial_name=trial_name
|
||||
)
|
||||
print(
|
||||
f"Profiling {model_family} layers, model path {model_path}, " f"cmd {cmd}"
|
||||
)
|
||||
sched.submit_array(
|
||||
worker_type="profile_layer",
|
||||
cmd=cmd,
|
||||
count=1,
|
||||
cpu=64,
|
||||
gpu=8,
|
||||
gpu_type="tesla",
|
||||
mem=500000,
|
||||
env_vars=BASE_ENVIRONS,
|
||||
container_image=config_package._LLM_GPU_IMAGE,
|
||||
)
|
||||
|
||||
try:
|
||||
sched.wait(timeout=None)
|
||||
except (
|
||||
KeyboardInterrupt,
|
||||
sched_client.JobException,
|
||||
TimeoutError,
|
||||
) as e:
|
||||
sched.stop_all()
|
||||
raise e
|
||||
else:
|
||||
try:
|
||||
print(
|
||||
f"Profiling {model_family} layers, model path {model_path}, "
|
||||
f"cmd {cmd}"
|
||||
)
|
||||
clear_name_resolve(expr_name, trial_name)
|
||||
os.system(cmd)
|
||||
except (
|
||||
KeyboardInterrupt,
|
||||
sched_client.JobException,
|
||||
TimeoutError,
|
||||
) as e:
|
||||
raise e
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(prog="ReaLHF")
|
||||
subparsers = parser.add_subparsers(dest="cmd", help="sub-command help")
|
||||
|
@ -482,7 +400,7 @@ def main():
|
|||
type=str,
|
||||
required=False,
|
||||
default="pipe_model",
|
||||
choices=["manual", "search", "heuristic", "pipe_model", "pipe_data"],
|
||||
choices=["manual", "heuristic", "pipe_model", "pipe_data"],
|
||||
help="Mode of GPU resource/model parallel strategy allocation.",
|
||||
)
|
||||
subparser.set_defaults(ignore_worker_error=False)
|
||||
|
@ -514,15 +432,6 @@ def main():
|
|||
subparser.add_argument("--regex", "-r", type=str, required=True)
|
||||
subparser.set_defaults(func=main_find_config)
|
||||
|
||||
subparser = subparsers.add_parser(
|
||||
"profile_layers", help="profile layers of a model."
|
||||
)
|
||||
subparser.add_argument("--model_class", type=str, required=True)
|
||||
subparser.add_argument("--model_size", type=int, required=True)
|
||||
subparser.add_argument("--is_critic", action="store_true")
|
||||
subparser.add_argument("--model_path", type=str, required=True)
|
||||
subparser.set_defaults(func=main_profile_layers)
|
||||
|
||||
args = parser.parse_args()
|
||||
args.func(args)
|
||||
|
||||
|
|
|
@ -1,91 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
import time
|
||||
|
||||
import realhf.base.testing as testing
|
||||
|
||||
BATCH_SIZE_RANGE = [1, 2, 4, 8, 16, 32, 64, 128]
|
||||
SEQ_LEN_RANGE = [128, 256, 512]
|
||||
|
||||
|
||||
def profile_layer_func(
|
||||
world_size,
|
||||
model_path,
|
||||
model_name,
|
||||
warm_up_rounds,
|
||||
profile_rounds,
|
||||
batch_size_range,
|
||||
seq_len_range,
|
||||
use_sequence_parallel=False,
|
||||
use_gradient_checkpointing=False,
|
||||
):
|
||||
# FIXME: use_sequence_parallel=True and use_gradient_checkpointing=True will cause bugs
|
||||
import torch
|
||||
|
||||
import realhf.base.constants as constants
|
||||
|
||||
testing.init_global_constants(
|
||||
1, world_size, 1, sequence_parallel=False, gradient_checkpointing=False
|
||||
)
|
||||
device = torch.device("cuda")
|
||||
with constants.model_scope(testing.MODEL_NAME):
|
||||
from realhf.search_engine.layers import make_profile_layers
|
||||
|
||||
profile_layers = make_profile_layers(device, model_path, model_name)
|
||||
|
||||
st = time.monotonic_ns()
|
||||
for i in range(warm_up_rounds + profile_rounds):
|
||||
for bs, seq_len in itertools.product(batch_size_range, seq_len_range):
|
||||
profile_layers.fwd_gen(bs, seq_len)
|
||||
profile_layers.fwd_bwd_opt(bs, seq_len)
|
||||
|
||||
if i < warm_up_rounds:
|
||||
profile_layers.reset_stats()
|
||||
profile_layers.make_dataframe_and_print()
|
||||
profile_layers.dump_stats(world_size)
|
||||
t = (time.monotonic_ns() - st) / int(1e9)
|
||||
print(f"profile world size {world_size} cost {t:4f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
st = time.monotonic_ns()
|
||||
parser = argparse.ArgumentParser(prog="profile_layers")
|
||||
parser.add_argument(
|
||||
"--model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument("--expr_name", type=str, default="profile")
|
||||
parser.add_argument("--trial_name", type=str, default="profile")
|
||||
parser.add_argument("--model_name", type=str, default="Llama-2-70b")
|
||||
parser.add_argument("--warm_up_rounds", type=int, default=1)
|
||||
parser.add_argument("--profile_rounds", type=int, default=3)
|
||||
# parser.add_argument("--use_sequence_parallel", action="store_true")
|
||||
# parser.add_argument("--use_gradient_checkpointing", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
world_sizes = [1, 2, 4, 8]
|
||||
|
||||
for world_size in world_sizes:
|
||||
testing.clear_name_resolve(args.expr_name, args.trial_name)
|
||||
mp = testing.LocalMultiProcessTest(
|
||||
world_size,
|
||||
profile_layer_func,
|
||||
world_size,
|
||||
args.model_path,
|
||||
args.model_name,
|
||||
args.warm_up_rounds,
|
||||
args.profile_rounds,
|
||||
BATCH_SIZE_RANGE,
|
||||
SEQ_LEN_RANGE,
|
||||
expr_name=args.expr_name,
|
||||
trial_name=args.trial_name,
|
||||
)
|
||||
mp.launch()
|
||||
|
||||
t = (time.monotonic_ns() - st) / int(1e9)
|
||||
print(f"profile model {args.model_name} time cost {t:4f} seconds")
|
|
@ -72,6 +72,7 @@ MODEL_SAVE_ROOT = f"{cluster_spec.fileroot}/checkpoints/{getpass.getuser()}"
|
|||
LOG_ROOT = f"{cluster_spec.fileroot}/logs/{getpass.getuser()}"
|
||||
RECOVER_ROOT = f"{cluster_spec.fileroot}/recover/{getpass.getuser()}"
|
||||
SLURM_LOCK_FILE_NAME = f"{cluster_spec.fileroot}/logs/slurm_scheduler.lock"
|
||||
PORT_LOCK_FILE_ROOT = f"{cluster_spec.fileroot}/.cache/{getpass.getuser()}/ports"
|
||||
PYTORCH_KERNEL_CACHE_PATH = (
|
||||
f"{LOCAL_CACHE_DIR}/.cache/{getpass.getuser()}/torch/kernels"
|
||||
)
|
||||
|
@ -120,6 +121,9 @@ BASE_ENVIRONS = {
|
|||
"REAL_GPU_MEMORY_KILL_THRESHOLD": os.getenv(
|
||||
"REAL_GPU_MEMORY_KILL_THRESHOLD", "0.95"
|
||||
),
|
||||
"LC_ALL": "C",
|
||||
"LANG": "C",
|
||||
"NCCL_DEBUG": "WARN",
|
||||
}
|
||||
|
||||
# Set PPU-specific environment variables for stable training.
|
||||
|
@ -146,7 +150,6 @@ elif cluster_spec.name == "na132":
|
|||
"NCCL_IB_SL": "5",
|
||||
"NCCL_IB_TC": "136",
|
||||
"NCCL_IB_HCA": "mlx5_bond",
|
||||
"NCCL_DEBUG": "WARN",
|
||||
"NCCL_IB_QPS_PER_CONNECTION": "8",
|
||||
"NCCL_SET_THREAD_NAME": "1",
|
||||
"NCCL_DEBUG_SUBSYS": "INIT,TUNING,GRAPH",
|
||||
|
@ -165,6 +168,7 @@ os.makedirs(DATASET_CACHE_PATH, exist_ok=True)
|
|||
os.makedirs(PROFILER_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(TORCH_EXTENSIONS_DIR, exist_ok=True)
|
||||
os.makedirs(QUICKSTART_EXPR_CACHE_PATH, exist_ok=True)
|
||||
os.makedirs(PORT_LOCK_FILE_ROOT, exist_ok=True)
|
||||
os.makedirs(SGLANG_CACHE_PATH, exist_ok=True)
|
||||
|
||||
# _model_name will be changed in the model_scope context manager
|
||||
|
@ -186,16 +190,12 @@ _self_group = None
|
|||
_rank_mapping: Dict["ModelName", Dict["ModelShardID", int]] = {}
|
||||
_global_memory_buffer: GlobalMemoryBuffer = GlobalMemoryBuffer()
|
||||
|
||||
# used only in scripts and tests
|
||||
_fake_mp_world_size = None
|
||||
_fake_mp_rank = None
|
||||
|
||||
|
||||
# TODO: As in Megatron, we can set NCCL group options. Is it necessary?
|
||||
|
||||
|
||||
def reset_run():
|
||||
global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer, _fake_mp_world_size, _fake_mp_rank
|
||||
global _model_name, _grids, _pgroups, _pgroup_ranks, _self_group, _rank_mapping, _global_memory_buffer
|
||||
_model_name = None
|
||||
_grids = {}
|
||||
_pgroups = {}
|
||||
|
@ -203,8 +203,6 @@ def reset_run():
|
|||
_self_group = None
|
||||
_rank_mapping = {}
|
||||
_global_memory_buffer = GlobalMemoryBuffer()
|
||||
_fake_mp_world_size = None
|
||||
_fake_mp_rank = None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
@ -284,7 +282,7 @@ def set_rank_mapping(
|
|||
else:
|
||||
msid2mwid = {k: v for k, v in msid2mwid.items() if k.model_name == model_name}
|
||||
_rank_mapping[model_name] = {
|
||||
topo.get_rank(data=s.dp_rank, model=s.mp_rank, pipe=s.pp_rank): mw_id
|
||||
topo.get_rank(data=s.dp_rank, tensor=s.tp_rank, pipe=s.pp_rank): mw_id
|
||||
for s, mw_id in msid2mwid.items()
|
||||
}
|
||||
|
||||
|
@ -408,7 +406,7 @@ def parallelism_group_ranks():
|
|||
|
||||
def parallelism_group_size() -> int:
|
||||
"""The 3D parallelism group size of a specific model, normally dp_size *
|
||||
pp_size * mp_size."""
|
||||
pp_size * tp_size."""
|
||||
import torch.distributed as dist
|
||||
|
||||
return dist.get_world_size(group=parallelism_group())
|
||||
|
@ -470,37 +468,25 @@ def prev_pipe_stage():
|
|||
|
||||
|
||||
def is_dp_head():
|
||||
return is_last_pipe_stage() and model_parallel_rank() == 0
|
||||
return is_last_pipe_stage() and tensor_parallel_rank() == 0
|
||||
|
||||
|
||||
def model_parallel_rank() -> int:
|
||||
def tensor_parallel_rank() -> int:
|
||||
"""Return the rank inside the tensor parallelism group."""
|
||||
try:
|
||||
return grid().get_tensor_model_parallel_rank()
|
||||
except RuntimeError as e: # used only in scripts and tests
|
||||
if _fake_mp_rank is not None:
|
||||
return _fake_mp_rank
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def model_parallel_world_size() -> int:
|
||||
def tensor_parallel_world_size() -> int:
|
||||
"""Return the world size of the tensor parallelism group."""
|
||||
try:
|
||||
return grid().get_tensor_model_parallel_world_size()
|
||||
except RuntimeError as e: # used only in scripts and tests
|
||||
if _fake_mp_world_size is not None:
|
||||
return _fake_mp_world_size
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
def model_parallel_group():
|
||||
def tensor_parallel_group():
|
||||
"""Return the NCCL tensor parallelism process group."""
|
||||
return grid().get_tensor_model_parallel_group()
|
||||
|
||||
|
||||
def model_parallel_cpu_group():
|
||||
def tensor_parallel_cpu_group():
|
||||
"""Return the GLOO tensor parallelism process group."""
|
||||
return grid().get_tensor_model_parallel_cpu_group()
|
||||
|
||||
|
@ -536,26 +522,6 @@ def data_parallel_group():
|
|||
return grid().get_data_parallel_group()
|
||||
|
||||
|
||||
def set_fake_mp_world_size(world_size):
|
||||
# used only in scripts and tests
|
||||
global _fake_mp_world_size
|
||||
_fake_mp_world_size = world_size
|
||||
|
||||
|
||||
def set_fake_mp_rank(rank):
|
||||
# used only in scripts and tests
|
||||
global _fake_mp_rank
|
||||
_fake_mp_rank = rank
|
||||
|
||||
|
||||
def set_fake_grid(model_name, rank, topo):
|
||||
# used only in scripts and tests
|
||||
from realhf.base.topology import FakeGrid
|
||||
|
||||
global _grids
|
||||
_grids[model_name] = FakeGrid(rank=rank, topo=topo)
|
||||
|
||||
|
||||
def get_global_memory_buffer():
|
||||
global _global_memory_buffer
|
||||
assert _global_memory_buffer is not None, "global memory buffer is not set"
|
||||
|
|
|
@ -62,7 +62,7 @@ def reveal_pg_identity(expr_name, trial_name, worker_index):
|
|||
master_group_name = names.distributed_peer(
|
||||
expr_name, trial_name, GLOBAL_PROCESS_GROUP_NAME
|
||||
)
|
||||
name_resolve.add_subentry(master_group_name, str(worker_index), keepalive_ttl=300)
|
||||
name_resolve.add_subentry(master_group_name, str(worker_index))
|
||||
|
||||
|
||||
def isolate_cuda_device(
|
||||
|
@ -100,12 +100,10 @@ def isolate_cuda_device(
|
|||
name_resolve_identifier,
|
||||
),
|
||||
rank,
|
||||
keepalive_ttl=60,
|
||||
)
|
||||
name_resolve.add_subentry(
|
||||
names.distributed_peer(experiment_name, trial_name, name_resolve_identifier),
|
||||
rank,
|
||||
keepalive_ttl=30,
|
||||
)
|
||||
logger.debug(
|
||||
f"Worker type {worker_type} rank {rank} waiting for peers, world size {world_size}..."
|
||||
|
|
|
@ -141,9 +141,18 @@ def getLogger(
|
|||
return logging.getLogger(name)
|
||||
|
||||
|
||||
_LATEST_WANDB_STEP = 0
|
||||
|
||||
|
||||
def log_wandb_tensorboard(data, step=None, summary_writer=None):
|
||||
import wandb
|
||||
|
||||
global _LATEST_WANDB_STEP
|
||||
if step is None:
|
||||
step = _LATEST_WANDB_STEP
|
||||
else:
|
||||
_LATEST_WANDB_STEP = max(_LATEST_WANDB_STEP, step)
|
||||
|
||||
wandb.log(data, step=step)
|
||||
if summary_writer is not None:
|
||||
for key, val in data.items():
|
||||
|
|
|
@ -618,7 +618,7 @@ class Etcd3NameRecordRepository(NameRecordRepository):
|
|||
|
||||
self._to_delete = set()
|
||||
|
||||
logger.info(f"Connected to etcd3 at {self._host}:{self._port}")
|
||||
logger.debug(f"Connected to etcd3 at {self._host}:{self._port}")
|
||||
|
||||
def __del__(self):
|
||||
"""Clean up resources when the object is deleted."""
|
||||
|
@ -945,12 +945,13 @@ def make_repository(type_="nfs", **kwargs):
|
|||
|
||||
# DEFAULT_REPOSITORY_TYPE = "redis" if socket.gethostname().startswith("frl") else "nfs"
|
||||
DEFAULT_REPOSITORY_TYPE = "nfs"
|
||||
if (
|
||||
etcd3 is not None
|
||||
and cluster.spec.name in ["wa180", "na132", "su18"]
|
||||
and os.getenv("REAL_ETCD_ADDR", "")
|
||||
):
|
||||
if etcd3 is not None and os.getenv("REAL_ETCD_ADDR", ""):
|
||||
DEFAULT_REPOSITORY_TYPE = "etcd3"
|
||||
if os.getenv("REAL_ETCD_ADDR", "") and etcd3 is None:
|
||||
logger.warning(
|
||||
f"Detected REAL_ETCD_ADDR but etcd3 client is not available. "
|
||||
"Please run `pip install -r requirements.txt` if you want to use etcd name resolve."
|
||||
)
|
||||
DEFAULT_REPOSITORY = make_repository(DEFAULT_REPOSITORY_TYPE)
|
||||
add = DEFAULT_REPOSITORY.add
|
||||
add_subentry = DEFAULT_REPOSITORY.add_subentry
|
||||
|
|
|
@ -93,5 +93,9 @@ def gen_servers(experiment_name, trial_name):
|
|||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_servers"
|
||||
|
||||
|
||||
def used_ports(experiment_name, trial_name, host_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/{host_name}/"
|
||||
|
||||
|
||||
def gen_server_manager(experiment_name, trial_name):
|
||||
return f"{USER_NAMESPACE}/{experiment_name}/{trial_name}/gen_server_manager"
|
||||
|
|
|
@ -2,31 +2,16 @@
|
|||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import fcntl
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from contextlib import closing
|
||||
from functools import wraps
|
||||
|
||||
from realhf.base import constants, logging, name_resolve, names
|
||||
|
||||
def find_free_port(low=1, high=65536, exclude_ports=None):
|
||||
"""Find a free port within the specified range, excluding certain ports."""
|
||||
if exclude_ports is None:
|
||||
exclude_ports = set()
|
||||
|
||||
while True:
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("", 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
port = s.getsockname()[1]
|
||||
if low <= port <= high and port not in exclude_ports:
|
||||
return port
|
||||
|
||||
|
||||
def find_multiple_free_ports(count, low=1, high=65536):
|
||||
"""Find multiple mutually exclusive free ports."""
|
||||
free_ports = set()
|
||||
for _ in range(count):
|
||||
port = find_free_port(low, high, exclude_ports=free_ports)
|
||||
free_ports.add(port)
|
||||
return list(free_ports)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def gethostname():
|
||||
|
@ -35,3 +20,54 @@ def gethostname():
|
|||
|
||||
def gethostip():
|
||||
return socket.gethostbyname(socket.gethostname())
|
||||
|
||||
|
||||
def find_free_port(
|
||||
low=1, high=65536, exclude_ports=None, experiment_name="port", trial_name="port"
|
||||
):
|
||||
"""Find a free port within the specified range, excluding certain ports."""
|
||||
|
||||
ports_name = names.used_ports(experiment_name, trial_name, gethostip())
|
||||
used_ports = list(map(int, name_resolve.get_subtree(ports_name)))
|
||||
if exclude_ports is None:
|
||||
exclude_ports = set(used_ports)
|
||||
else:
|
||||
exclude_ports = exclude_ports.union(set(used_ports))
|
||||
|
||||
free_port = None
|
||||
lockfile = os.path.join(constants.PORT_LOCK_FILE_ROOT, gethostip())
|
||||
while True:
|
||||
with open(lockfile, "w") as fd:
|
||||
# This will block until lock is acquired
|
||||
fcntl.flock(fd, fcntl.LOCK_EX)
|
||||
try:
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
||||
s.bind(("", 0))
|
||||
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
port = s.getsockname()[1]
|
||||
if low <= port <= high and port not in exclude_ports:
|
||||
name_resolve.add_subentry(ports_name, str(port))
|
||||
logger.info(f"Found free port {port}")
|
||||
free_port = port
|
||||
break
|
||||
finally:
|
||||
fcntl.flock(fd, fcntl.LOCK_UN)
|
||||
time.sleep(0.05)
|
||||
return free_port
|
||||
|
||||
|
||||
def find_multiple_free_ports(
|
||||
count, low=1, high=65536, experiment_name="port", trial_name="port"
|
||||
):
|
||||
"""Find multiple mutually exclusive free ports."""
|
||||
free_ports = set()
|
||||
for _ in range(count):
|
||||
port = find_free_port(
|
||||
low=low,
|
||||
high=high,
|
||||
exclude_ports=free_ports,
|
||||
experiment_name=experiment_name,
|
||||
trial_name=trial_name,
|
||||
)
|
||||
free_ports.add(port)
|
||||
return list(free_ports)
|
||||
|
|
|
@@ -171,6 +171,9 @@ class DistributedStatsTracker:
         else:
             raise ValueError(f"Unknown reduce type: {reduce_type}")

+        keys_to_pop = [k for k, v in result.items() if v is None]
+        for k in keys_to_pop:
+            result.pop(k)
         return result

     def _sum_of(self, key, reduce_group):

@@ -209,7 +212,7 @@ class DistributedStatsTracker:
             dist.all_reduce(x, group=reduce_group)
             dist.all_reduce(d, group=reduce_group)
         if d == 0:
-            return 0
+            return None
         return x / d

     def _min_of(self, key, reduce_group):

@@ -224,7 +227,7 @@ class DistributedStatsTracker:
         if reduce_group is not None:
             dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MIN)
         if torch.isinf(x):
-            return float("nan")
+            return None
         return float(x)

     def _max_of(self, key, reduce_group):

@@ -239,7 +242,7 @@ class DistributedStatsTracker:
         if reduce_group is not None:
             dist.all_reduce(x, group=reduce_group, op=dist.ReduceOp.MAX)
         if torch.isinf(x):
-            return float("nan")
+            return None
         return float(x)

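The intent of the change above, sketched under assumed key names: reductions that had no data now come back as None and are dropped from the exported result instead of surfacing as 0 or NaN. The snippet below is illustrative only and not part of the patch.

# Illustrative expectation (hypothetical keys and values, not part of the patch):
result = {"ppo/clip_ratio": 0.12, "ppo/kl": None}  # None: the denominator was 0
keys_to_pop = [k for k, v in result.items() if v is None]
for k in keys_to_pop:
    result.pop(k)
assert result == {"ppo/clip_ratio": 0.12}
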
@ -22,9 +22,9 @@ import torch.utils.data
|
|||
from realhf.api.core.data_api import SequenceSample
|
||||
from realhf.base import constants, gpu_utils, logging, name_resolve, names, topology
|
||||
from realhf.base.topology import (
|
||||
DataPipeModelParallelTopology,
|
||||
DataPipeTensorParallelTopology,
|
||||
ParallelGrid,
|
||||
PipeDataModelParallelTopology,
|
||||
PipeDataTensorParallelTopology,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("testing")
|
||||
|
@ -106,9 +106,6 @@ class StandaloneTestingProcess(mp.Process):
|
|||
self.expr_name, self.trial_name, self.rank, backend=self.dist_backend
|
||||
)
|
||||
|
||||
# setup some useful constants
|
||||
constants.set_experiment_trial_names(self.expr_name, self.trial_name)
|
||||
|
||||
# misc setup
|
||||
if constants.use_cuda():
|
||||
pynvml.nvmlInit()
|
||||
|
@ -207,7 +204,7 @@ class LocalMultiProcessTest:
|
|||
|
||||
def init_global_constants(
|
||||
num_dp=1,
|
||||
num_mp=1,
|
||||
num_tp=1,
|
||||
num_pp=1,
|
||||
topo=None,
|
||||
model_name=None,
|
||||
|
@ -227,9 +224,9 @@ def init_global_constants(
|
|||
|
||||
if topo is None:
|
||||
if is_train:
|
||||
topo = PipeDataModelParallelTopology(
|
||||
topo = PipeDataTensorParallelTopology(
|
||||
num_dp=num_dp,
|
||||
num_mp=num_mp,
|
||||
num_tp=num_tp,
|
||||
num_pp=num_pp,
|
||||
sequence_parallel=sequence_parallel,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
|
@ -237,13 +234,13 @@ def init_global_constants(
|
|||
max_prompt_len=max_prompt_len,
|
||||
)
|
||||
else:
|
||||
topo = DataPipeModelParallelTopology(
|
||||
topo = DataPipeTensorParallelTopology(
|
||||
num_dp=num_dp,
|
||||
num_mp=num_mp,
|
||||
num_tp=num_tp,
|
||||
num_pp=num_pp,
|
||||
sequence_parallel=sequence_parallel,
|
||||
)
|
||||
ws = num_dp * num_mp * num_pp
|
||||
ws = num_dp * num_tp * num_pp
|
||||
else:
|
||||
ws = topo.world_size()
|
||||
|
||||
|
|
|
@ -65,22 +65,22 @@ def decompose_to_three_factors(n: int) -> List[Tuple[int, int, int]]:
|
|||
return factors
|
||||
|
||||
|
||||
class PipeDataModelrocessCoord(NamedTuple):
|
||||
class PipeDataTensorProcessCoord(NamedTuple):
|
||||
pipe: int
|
||||
data: int
|
||||
model: int
|
||||
tensor: int
|
||||
|
||||
|
||||
class DataPipeModelrocessCoord(NamedTuple):
|
||||
class DataPipeTensorProcessCoord(NamedTuple):
|
||||
data: int
|
||||
pipe: int
|
||||
model: int
|
||||
tensor: int
|
||||
|
||||
|
||||
# Explicitly define these class to allow pickling.
|
||||
PROCESS_COORD_REGISTRY = {
|
||||
"pipe#data#model": PipeDataModelrocessCoord,
|
||||
"data#pipe#model": DataPipeModelrocessCoord,
|
||||
"pipe#data#tensor": PipeDataTensorProcessCoord,
|
||||
"data#pipe#tensor": DataPipeTensorProcessCoord,
|
||||
}
|
||||
|
||||
|
||||
|
@ -327,20 +327,20 @@ def _prime_factors(N):
|
|||
return primes
|
||||
|
||||
|
||||
class PipeDataModelParallelTopology(ProcessTopology):
|
||||
class PipeDataTensorParallelTopology(ProcessTopology):
|
||||
"""A topology for hybrid pipeline, model, and data parallelism."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_pp: int,
|
||||
num_mp: int,
|
||||
num_tp: int,
|
||||
num_dp: int,
|
||||
sequence_parallel: bool,
|
||||
gradient_checkpointing: bool,
|
||||
gradient_accumulation_fusion: bool,
|
||||
max_prompt_len: Optional[int] = None,
|
||||
):
|
||||
super().__init__(axes=["pipe", "data", "model"], dims=[num_pp, num_dp, num_mp])
|
||||
super().__init__(axes=["pipe", "data", "tensor"], dims=[num_pp, num_dp, num_tp])
|
||||
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
@ -348,7 +348,7 @@ class PipeDataModelParallelTopology(ProcessTopology):
|
|||
self.gradient_accumulation_fusion = gradient_accumulation_fusion
|
||||
|
||||
|
||||
class DataPipeModelParallelTopology(ProcessTopology):
|
||||
class DataPipeTensorParallelTopology(ProcessTopology):
|
||||
"""A topology for hybrid data, pipeline, and tensor parallelism.
|
||||
|
||||
Note that DP is the most outer dimension. Used for inference only.
|
||||
|
@ -357,12 +357,12 @@ class DataPipeModelParallelTopology(ProcessTopology):
|
|||
def __init__(
|
||||
self,
|
||||
num_pp: int,
|
||||
num_mp: int,
|
||||
num_tp: int,
|
||||
num_dp: int,
|
||||
sequence_parallel: bool,
|
||||
max_prompt_len: Optional[int] = None,
|
||||
):
|
||||
super().__init__(axes=["data", "pipe", "model"], dims=[num_dp, num_pp, num_mp])
|
||||
super().__init__(axes=["data", "pipe", "tensor"], dims=[num_dp, num_pp, num_tp])
|
||||
self.sequence_parallel = sequence_parallel
|
||||
self.max_prompt_len = max_prompt_len
|
||||
|
||||
|
@ -414,7 +414,7 @@ class ParallelGrid:
|
|||
|
||||
self.data_parallel_size = max(self._topo.get_dim("data"), 1)
|
||||
self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1)
|
||||
self.model_parallel_size = max(self._topo.get_dim("model"), 1)
|
||||
self.model_parallel_size = max(self._topo.get_dim("tensor"), 1)
|
||||
self.slice_parallel_size = self.model_parallel_size
|
||||
assert self._is_grid_valid(), (
|
||||
"Invalid Grid",
|
||||
|
@ -520,7 +520,7 @@ class ParallelGrid:
|
|||
self.slice_group = None
|
||||
self.slice_proc_group = self.slice_proc_group_gloo = None
|
||||
self.mp_group = []
|
||||
self.model_groups = self._topo.get_axis_comm_lists("model")
|
||||
self.model_groups = self._topo.get_axis_comm_lists("tensor")
|
||||
for g in self.model_groups:
|
||||
proc_group = new_or_get_group(ranks=[rank_mapping[x] for x in g])
|
||||
# NOTE: We must create the GLOO group for vLLM's usage.
|
||||
|
@ -634,8 +634,8 @@ class ParallelGrid:
|
|||
def get_tensor_model_parallel_rank(self):
|
||||
if self.global_rank == -1:
|
||||
return -1
|
||||
if "model" in self._topo.get_axis_names():
|
||||
return self._topo.get_coord(rank=self.global_rank).model
|
||||
if "tensor" in self._topo.get_axis_names():
|
||||
return self._topo.get_coord(rank=self.global_rank).tensor
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
@ -662,12 +662,12 @@ class FakeGrid:
|
|||
|
||||
self.data_parallel_size = max(self._topo.get_dim("data"), 1)
|
||||
self.pipe_parallel_size = max(self._topo.get_dim("pipe"), 1)
|
||||
self.model_parallel_size = max(self._topo.get_dim("model"), 1)
|
||||
self.model_parallel_size = max(self._topo.get_dim("tensor"), 1)
|
||||
|
||||
self.coord: ProcessCoord = self._topo.get_coord(self.rank)
|
||||
self.coord = self._topo.get_coord(self.rank)
|
||||
self.dp_id = self.coord.data
|
||||
self.pp_id = self.coord.pipe
|
||||
self.mp_id = self.coord.model
|
||||
self.mp_id = self.coord.tensor
|
||||
|
||||
self.world_size = (
|
||||
self.data_parallel_size * self.pipe_parallel_size * self.model_parallel_size
|
||||
|
|
|
@ -6,7 +6,11 @@ from typing import Any, Dict, List, Tuple
|
|||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import ModelTrainEvalConfig, PPOMATHExperimentOptions
|
||||
from realhf.api.core.config import AgentAbstraction, EnvServiceAbstraction
|
||||
from realhf.api.core.config import (
|
||||
AgentAbstraction,
|
||||
EnvServiceAbstraction,
|
||||
ModelInterfaceAbstraction,
|
||||
)
|
||||
from realhf.api.core.model_api import GenerationHyperparameters
|
||||
from realhf.api.quickstart.entrypoint import register_quickstart_exp
|
||||
from realhf.experiments.async_exp.async_rl_exp import AsyncRLExperimentConfig
|
||||
|
@ -36,7 +40,7 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
|
|||
@property
|
||||
def env(self) -> EnvServiceAbstraction:
|
||||
return EnvServiceAbstraction(
|
||||
"math-single-step", args=dict(dataset_path=self.dataset.path)
|
||||
"math-code-single-step", args=dict(dataset_path=self.dataset.path)
|
||||
)
|
||||
|
||||
@property
|
||||
|
@ -71,6 +75,11 @@ class AsyncPPOMATHConfig(AsyncRLExperimentConfig, PPOMATHConfig):
|
|||
rpcs["ref_inf"].output_keys = ("packed_ref_logprobs",)
|
||||
if "rew_inf" in rpcs:
|
||||
rpcs.pop("rew_inf")
|
||||
if self.no_training:
|
||||
rpcs["actor_train"].interface_impl = ModelInterfaceAbstraction("null")
|
||||
rpcs["actor_gen"].interface_impl = ModelInterfaceAbstraction("null")
|
||||
if "actor_inf" in rpcs:
|
||||
rpcs["actor_inf"].interface_impl = ModelInterfaceAbstraction("null")
|
||||
return rpcs
|
||||
|
||||
@property
|
||||
|
|
|
@ -60,7 +60,6 @@ GEN_WORKER_DEFAULT_CAPACITY = 512
|
|||
|
||||
@dataclasses.dataclass
|
||||
class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
||||
|
||||
@property
|
||||
def generation_config(self) -> GenerationHyperparameters:
|
||||
raise NotImplementedError()
|
||||
|
@ -203,16 +202,17 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
|||
"config_from_hf_converter"
|
||||
](hf_config)
|
||||
if (
|
||||
model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size
|
||||
model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size
|
||||
!= 0
|
||||
) or (
|
||||
model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0
|
||||
model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size
|
||||
!= 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"The number of KV heads {model_config.n_kv_heads} or "
|
||||
f"Q heads {model_config.n_q_heads} is not"
|
||||
f" divisible by the configured TP size "
|
||||
f"({rpc_alloc.parallel.model_parallel_size}). "
|
||||
f"({rpc_alloc.parallel.tensor_parallel_size}). "
|
||||
f"Please decrease TP size."
|
||||
)
|
||||
mapping = rpc_alloc.device_mesh.mapping
|
||||
|
@ -250,7 +250,7 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
|||
topo=topo,
|
||||
dp_rank=topo.get_coord(shard_idx).data,
|
||||
pp_rank=topo.get_coord(shard_idx).pipe,
|
||||
mp_rank=topo.get_coord(shard_idx).model,
|
||||
tp_rank=topo.get_coord(shard_idx).tensor,
|
||||
),
|
||||
model=model,
|
||||
backend=backend,
|
||||
|
@ -308,15 +308,18 @@ class AsyncRLExperimentConfig(CommonExperimentConfig, AsyncRLOptions):
|
|||
model_name = gen_rpc_alloc.rpc.model_name
|
||||
train_rpcs = [alloc.rpc for alloc in rpc_allocs if alloc.rpc.is_train()]
|
||||
assert all(rpc.n_seqs == train_rpcs[0].n_seqs for rpc in train_rpcs)
|
||||
max_concurrent_rollouts = self.max_concurrent_rollouts
|
||||
if max_concurrent_rollouts is None:
|
||||
max_concurrent_rollouts = train_rpcs[0].n_seqs
|
||||
return [
|
||||
GserverManager(
|
||||
model_name=model_name,
|
||||
flush_request_timeout=self.flush_request_timeout,
|
||||
n_servers=gen_world_size // gen_tp_size,
|
||||
schedule_policy="round_robin",
|
||||
schedule_policy=self.schedule_policy,
|
||||
max_head_offpolicyness=self.max_head_offpolicyness,
|
||||
train_batch_size=train_rpcs[0].n_seqs,
|
||||
max_concurrent_rollouts=self.max_concurrent_rollouts,
|
||||
max_concurrent_rollouts=max_concurrent_rollouts,
|
||||
)
|
||||
]
|
||||
|
||||
|
|
|
@ -37,21 +37,21 @@ def default_parallel_config(n_gpus: int) -> List[Dict[str, Any]]:
|
|||
x = [
|
||||
{
|
||||
"data_parallel_size": dp,
|
||||
"model_parallel_size": mp,
|
||||
"tensor_parallel_size": tp,
|
||||
"pipeline_parallel_size": pp,
|
||||
"use_sequence_parallel": mp > 1,
|
||||
"use_sequence_parallel": tp > 1,
|
||||
}
|
||||
for dp, mp, pp in factors
|
||||
for dp, tp, pp in factors
|
||||
]
|
||||
x += [
|
||||
{
|
||||
"data_parallel_size": dp,
|
||||
"model_parallel_size": mp,
|
||||
"tensor_parallel_size": tp,
|
||||
"pipeline_parallel_size": pp,
|
||||
"use_sequence_parallel": False,
|
||||
}
|
||||
for dp, mp, pp in factors
|
||||
if mp > 1
|
||||
for dp, tp, pp in factors
|
||||
if tp > 1
|
||||
]
|
||||
return x
|
||||
|
||||
|
@ -122,7 +122,7 @@ class ProfileConfig(CommonExperimentConfig):
|
|||
k
|
||||
in [
|
||||
"data_parallel_size",
|
||||
"model_parallel_size",
|
||||
"tensor_parallel_size",
|
||||
"pipeline_parallel_size",
|
||||
"use_sequence_parallel",
|
||||
]
|
||||
|
@ -130,7 +130,7 @@ class ProfileConfig(CommonExperimentConfig):
|
|||
), pcfg.keys()
|
||||
assert (self.n_nodes * self.n_gpus_per_node) == (
|
||||
pcfg.get("data_parallel_size", 1)
|
||||
* pcfg.get("model_parallel_size", 1)
|
||||
* pcfg.get("tensor_parallel_size", 1)
|
||||
* pcfg.get("pipeline_parallel_size", 1)
|
||||
)
|
||||
|
||||
|
@ -246,8 +246,6 @@ class ProfileConfig(CommonExperimentConfig):
|
|||
model_name="default",
|
||||
input_keys=["packed_prompts"],
|
||||
log_return_value=False,
|
||||
model_type=self._tmp_model.type,
|
||||
model_path=self._tmp_model.path,
|
||||
balanced_dp=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -70,7 +70,7 @@ def check_valid_parallel_batch_size(rpc_alloc: RPCAllocation):
|
|||
mb_spec = rpc.mb_spec
|
||||
|
||||
dp_size = rpc_alloc.parallel.data_parallel_size
|
||||
tp_size = rpc_alloc.parallel.model_parallel_size
|
||||
tp_size = rpc_alloc.parallel.tensor_parallel_size
|
||||
pp_size = rpc_alloc.parallel.pipeline_parallel_size
|
||||
|
||||
factor = 1
|
||||
|
|
|
@ -44,7 +44,6 @@ from realhf.api.quickstart.device_mesh import (
|
|||
)
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.experiments.common.check import (
|
||||
check_is_realhf_native_model_interface,
|
||||
check_valid_model_and_path,
|
||||
check_valid_optimizer,
|
||||
check_valid_parallel_batch_size,
|
||||
|
@ -61,7 +60,6 @@ from realhf.experiments.common.utils import (
|
|||
resolve_replica_ids,
|
||||
resolve_rpc_hooks,
|
||||
)
|
||||
from realhf.search_engine.search import search_rpc_allocations
|
||||
|
||||
# Register all HF models
|
||||
import realhf.api.from_hf # isort:skip
|
||||
|
@ -144,10 +142,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
"""
|
||||
return None
|
||||
|
||||
@property
|
||||
def search_kwargs(self) -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def global_device_mesh(self) -> DeviceMesh:
|
||||
return DeviceMesh(
|
||||
|
@ -161,20 +155,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
f"_heuristic_rpc_allocation is not implemented in {self.__class__}"
|
||||
)
|
||||
|
||||
def _search(self):
|
||||
# called in both api.main and controller
|
||||
gradient_checkpointing = any(
|
||||
model.gradient_checkpointing for model in self.models.values()
|
||||
)
|
||||
rpc_allocs: List[RPCAllocation] = search_rpc_allocations(
|
||||
device_mesh=self.global_device_mesh,
|
||||
rpcs=list(self.rpcs.values()),
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
use_cache=self.allocation_use_cache,
|
||||
**self.search_kwargs,
|
||||
)
|
||||
return rpc_allocs
|
||||
|
||||
def scheduling_setup(self) -> ExperimentScheduling:
|
||||
"""The resourced occupied by each worker.
|
||||
|
||||
|
@ -221,24 +201,11 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
self._check_legal_allocation_options()
|
||||
|
||||
rpcs = self.rpcs
|
||||
if self.allocation_mode == "search":
|
||||
# assert self.mode == "slurm"
|
||||
# assumes gradient checkpointing for all training RPCs if one is enabled
|
||||
# for the simplicity of search configurations
|
||||
rpc_allocs = self._search()
|
||||
for rpc_alloc in rpc_allocs:
|
||||
assert isinstance(rpc_alloc.rpc, str)
|
||||
for rpc in rpcs.values():
|
||||
if rpc.name == rpc_alloc.rpc:
|
||||
rpc_alloc.rpc = rpc
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"RPC {rpc_alloc.rpc} not found in rpcs.")
|
||||
elif self._allocation_mode.is_decoupled():
|
||||
if self._allocation_mode.is_decoupled():
|
||||
paras = self._allocation_mode.parallel_strat
|
||||
|
||||
gdp, gpp, gmp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
|
||||
gen_world_size = gdp * gpp * gmp
|
||||
gdp, gpp, gtp = paras["gen"]["d"], paras["gen"]["p"], paras["gen"]["m"]
|
||||
gen_world_size = gdp * gpp * gtp
|
||||
assert (
|
||||
gen_world_size < self.n_gpus_per_node
|
||||
or gen_world_size % self.n_gpus_per_node == 0
|
||||
|
@ -268,7 +235,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
parallel=ParallelismConfig(
|
||||
data_parallel_size=gdp,
|
||||
pipeline_parallel_size=gpp,
|
||||
model_parallel_size=gmp,
|
||||
tensor_parallel_size=gtp,
|
||||
use_sequence_parallel=False,
|
||||
),
|
||||
)
|
||||
|
@ -276,7 +243,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
else:
|
||||
rpc_name = rpc.name
|
||||
if rpc_name in paras:
|
||||
dp, pp, mp = (
|
||||
dp, pp, tp = (
|
||||
paras[rpc_name]["d"],
|
||||
paras[rpc_name]["p"],
|
||||
paras[rpc_name]["m"],
|
||||
|
@ -287,9 +254,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
f"RPC {rpc_name} parallel strategy not given, "
|
||||
"expect a `*` to specify the default parallel strategy."
|
||||
)
|
||||
dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
|
||||
dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
|
||||
if (
|
||||
dp * pp * mp + gdp * gpp * gmp
|
||||
dp * pp * tp + gdp * gpp * gtp
|
||||
!= self.n_nodes * self.n_gpus_per_node
|
||||
):
|
||||
raise ValueError(
|
||||
|
@ -297,7 +264,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
"does not equal to the number of gpus. "
|
||||
"Note that the device mesh of vLLM/SGLang should be disjoint from the device mesh of other MFCs, "
|
||||
"so their summation should be equal to the total number of gpus. "
|
||||
f"dp={dp}, pp={pp}, mp={mp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gmp}, "
|
||||
f"dp={dp}, pp={pp}, mp={tp}, gen.dp={gdp}, gen.pp={gpp}, gen.mp={gtp}, "
|
||||
f"n_nodes={self.n_nodes}, n_gpus_per_node={self.n_gpus_per_node}"
|
||||
)
|
||||
alloc = RPCAllocation(
|
||||
|
@ -306,10 +273,10 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
parallel=ParallelismConfig(
|
||||
data_parallel_size=dp,
|
||||
pipeline_parallel_size=pp,
|
||||
model_parallel_size=mp,
|
||||
tensor_parallel_size=tp,
|
||||
use_sequence_parallel=(
|
||||
rpc.interface_type == ModelInterfaceType.TRAIN_STEP
|
||||
and mp > 1
|
||||
and tp > 1
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -323,7 +290,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
rpc_allocs = []
|
||||
for rpc_name, rpc in self.rpcs.items():
|
||||
if rpc_name in paras:
|
||||
dp, pp, mp = (
|
||||
dp, pp, tp = (
|
||||
paras[rpc_name]["d"],
|
||||
paras[rpc_name]["p"],
|
||||
paras[rpc_name]["m"],
|
||||
|
@ -334,18 +301,18 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
f"RPC {rpc_name} parallel strategy not given, "
|
||||
"expect a `*` to specify the default parallel strategy."
|
||||
)
|
||||
dp, pp, mp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
|
||||
assert dp * pp * mp == self.n_nodes * self.n_gpus_per_node
|
||||
dp, pp, tp = paras["*"]["d"], paras["*"]["p"], paras["*"]["m"]
|
||||
assert dp * pp * tp == self.n_nodes * self.n_gpus_per_node
|
||||
alloc = RPCAllocation(
|
||||
rpc=rpc,
|
||||
device_mesh=self.global_device_mesh,
|
||||
parallel=ParallelismConfig(
|
||||
data_parallel_size=dp,
|
||||
pipeline_parallel_size=pp,
|
||||
model_parallel_size=mp,
|
||||
tensor_parallel_size=tp,
|
||||
use_sequence_parallel=(
|
||||
rpc.interface_type == ModelInterfaceType.TRAIN_STEP
|
||||
and mp > 1
|
||||
and tp > 1
|
||||
),
|
||||
),
|
||||
)
|
||||
|
@ -455,7 +422,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
topo=topo,
|
||||
dp_rank=topo.get_coord(shard_idx).data,
|
||||
pp_rank=topo.get_coord(shard_idx).pipe,
|
||||
mp_rank=topo.get_coord(shard_idx).model,
|
||||
tp_rank=topo.get_coord(shard_idx).tensor,
|
||||
),
|
||||
model=ModelAbstraction(
|
||||
"tokenizer", args=dict(tokenizer_path=model_cfg.path)
|
||||
|
@ -464,7 +431,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
gen_backend_name,
|
||||
args=dict(
|
||||
model_path=model_cfg.path,
|
||||
dtype="bfloat16" if model_cfg.bf16 else "float16",
|
||||
**dict_args,
|
||||
),
|
||||
),
|
||||
|
@ -503,16 +469,17 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
"config_from_hf_converter"
|
||||
](hf_config)
|
||||
if (
|
||||
model_config.n_kv_heads % rpc_alloc.parallel.model_parallel_size
|
||||
model_config.n_kv_heads % rpc_alloc.parallel.tensor_parallel_size
|
||||
!= 0
|
||||
) or (
|
||||
model_config.n_q_heads % rpc_alloc.parallel.model_parallel_size != 0
|
||||
model_config.n_q_heads % rpc_alloc.parallel.tensor_parallel_size
|
||||
!= 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"The number of KV heads {model_config.n_kv_heads} or "
|
||||
f"Q heads {model_config.n_q_heads} is not"
|
||||
f" divisible by the configured TP size "
|
||||
f"({rpc_alloc.parallel.model_parallel_size}). "
|
||||
f"({rpc_alloc.parallel.tensor_parallel_size}). "
|
||||
f"Please decrease TP size."
|
||||
)
|
||||
mapping = rpc_alloc.device_mesh.mapping
|
||||
|
@ -572,7 +539,7 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
topo=topo,
|
||||
dp_rank=topo.get_coord(shard_idx).data,
|
||||
pp_rank=topo.get_coord(shard_idx).pipe,
|
||||
mp_rank=topo.get_coord(shard_idx).model,
|
||||
tp_rank=topo.get_coord(shard_idx).tensor,
|
||||
),
|
||||
model=model,
|
||||
backend=backend,
|
||||
|
@ -612,12 +579,9 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
"please setup slurm for distributed runs."
|
||||
)
|
||||
|
||||
if self.n_gpus_per_node != 8 and self.allocation_mode in [
|
||||
"search",
|
||||
"heuristic",
|
||||
]:
|
||||
if self.n_gpus_per_node != 8 and self.allocation_mode == "heuristic":
|
||||
raise ValueError(
|
||||
f"Cannot run search or heuristic allocation with "
|
||||
f"Cannot run heuristic allocation with "
|
||||
f"n_gpus_per_node {self.n_gpus_per_node}, "
|
||||
"please set n_gpus_per_node to 8."
|
||||
)
|
||||
|
@ -627,13 +591,6 @@ class CommonExperimentConfig(BaseExperimentConfig, Experiment):
|
|||
raise KeyError(
|
||||
f"RPC name {rpc_name} does not match the name in the MFCDef object {rpc.name}."
|
||||
)
|
||||
if not check_is_realhf_native_model_interface(
|
||||
rpc.interface_impl.type_
|
||||
) and self.allocation_mode in ["search"]:
|
||||
raise ValueError(
|
||||
f"RPC {rpc.name} interface is not a realhf native implementation. "
|
||||
f"The search allocation mode are not available."
|
||||
)
|
||||
if self.allocation_mode == "manual" and rpc_name not in self.allocations:
|
||||
if rpc_name not in self.allocations:
|
||||
raise ValueError(
|
||||
|
|
|
@ -65,8 +65,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
|
|||
model_name="actor",
|
||||
mb_spec=self.actor_gen.mb_spec,
|
||||
interface_type=ModelInterfaceType.GENERATE,
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=("packed_prompts", "task_ids"),
|
||||
output_keys=("packed_input_ids",),
|
||||
|
@ -79,8 +77,6 @@ class MathCodeEvalConfig(MathCodeEvalOptions, CommonExperimentConfig):
|
|||
mb_spec=self.rew_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=rw_interface,
|
||||
model_type=self.rew.type,
|
||||
model_path=self.rew.path,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=("packed_input_ids", "packed_prompts", "task_ids"),
|
||||
output_keys=("rewards",),
|
||||
|
|
|
@ -39,8 +39,6 @@ class NullSFTConfig(CommonExperimentConfig, SFTExperimentOptions):
|
|||
model_name="default",
|
||||
input_keys=("packed_input_ids", "prompt_mask"),
|
||||
log_return_value=True,
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
)
|
||||
return {"trainDefault": rpc}
|
||||
|
||||
|
@ -88,8 +86,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
|
|||
model_name="default",
|
||||
input_keys=("packed_prompts",),
|
||||
output_keys=("rewards",),
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
)
|
||||
rpc = MFCDef(
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
|
@ -100,8 +96,6 @@ class NullPPOConfig(CommonExperimentConfig, NullPPOExperimentOptions):
|
|||
model_name="default",
|
||||
input_keys=("packed_prompts", "rewards"),
|
||||
log_return_value=True,
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
)
|
||||
return {"trainDefault": rpc, "reward": rw}
|
||||
|
||||
|
|
|
@ -148,32 +148,31 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
"packed_logprobs",
|
||||
"prompt_mask",
|
||||
]
|
||||
if self.ppo.recompute_logprob:
|
||||
if self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
|
||||
rollout_output_keys.remove("packed_logprobs")
|
||||
rollout = MFCDef(
|
||||
name="actor_gen",
|
||||
model_name="actor",
|
||||
mb_spec=self.actor_gen.mb_spec,
|
||||
interface_type=ModelInterfaceType.GENERATE,
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=("packed_prompts", "task_ids"),
|
||||
output_keys=tuple(rollout_output_keys),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
actor_inf_outputs = ("packed_logprobs",)
|
||||
if self.ppo.use_decoupled_loss:
|
||||
actor_inf_outputs = ("proximal_logprobs",)
|
||||
actor_inf = MFCDef(
|
||||
name="actor_inf",
|
||||
model_name="actor",
|
||||
mb_spec=self.actor_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=("packed_input_ids",),
|
||||
output_keys=("packed_logprobs",),
|
||||
output_key_remap=dict(logprobs="packed_logprobs"),
|
||||
output_keys=actor_inf_outputs,
|
||||
output_key_remap=dict(logprobs=actor_inf_outputs[0]),
|
||||
n_seqs=self.dataset.train_bs_n_seqs,
|
||||
)
|
||||
|
||||
|
@ -200,8 +199,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
model_name="ref",
|
||||
mb_spec=self.ref_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
model_type=self.ref.type,
|
||||
model_path=self.ref.path,
|
||||
interface_impl=ref_interface,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=tuple(inf_ref_inputs),
|
||||
|
@ -216,8 +213,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
mb_spec=self.critic_inf.mb_spec,
|
||||
interface_type=ModelInterfaceType.INFERENCE,
|
||||
interface_impl=critic_interface,
|
||||
model_type=self.critic.type,
|
||||
model_path=self.critic.path,
|
||||
min_n_seqs_per_pass=1 / self.group_size,
|
||||
input_keys=("packed_input_ids", "seq_no_eos_mask"),
|
||||
output_keys=("values",),
|
||||
|
@ -238,13 +233,13 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
train_actor_inputs.remove("values")
|
||||
if self.ppo.kl_ctl == 0:
|
||||
train_actor_inputs.remove("packed_ref_logprobs")
|
||||
if self.ppo.use_decoupled_loss:
|
||||
train_actor_inputs.append("proximal_logprobs")
|
||||
train_actor = MFCDef(
|
||||
name="actor_train",
|
||||
model_name="actor",
|
||||
mb_spec=self.actor_train.mb_spec,
|
||||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
||||
model_type=self.actor.type,
|
||||
model_path=self.actor.path,
|
||||
interface_impl=actor_interface,
|
||||
input_keys=tuple(train_actor_inputs),
|
||||
log_return_value=True,
|
||||
|
@ -269,8 +264,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
mb_spec=self.critic_train.mb_spec,
|
||||
interface_type=ModelInterfaceType.TRAIN_STEP,
|
||||
interface_impl=critic_interface,
|
||||
model_type=self.critic.type,
|
||||
model_path=self.critic.path,
|
||||
input_keys=tuple(train_critic_inputs),
|
||||
log_return_value=True,
|
||||
min_n_seqs_per_pass=self.ppo.ppo_n_minibatches / self.group_size,
|
||||
|
@ -289,7 +282,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
if self.ppo.disable_value:
|
||||
rpcs.pop("critic_inf")
|
||||
rpcs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
|
||||
rpcs.pop("actor_inf")
|
||||
if self.ppo.kl_ctl == 0:
|
||||
rpcs.pop("ref_inf")
|
||||
|
@ -311,7 +304,7 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
if self.ppo.disable_value:
|
||||
allocs.pop("critic_inf")
|
||||
allocs.pop("critic_train")
|
||||
if not self.ppo.recompute_logprob:
|
||||
if not self.ppo.recompute_logprob and not self.ppo.use_decoupled_loss:
|
||||
allocs.pop("actor_inf")
|
||||
if self.ppo.kl_ctl == 0:
|
||||
allocs.pop("ref_inf")
|
||||
|
@ -337,14 +330,6 @@ class PPOMATHConfig(CommonExperimentConfig, PPOMATHExperimentOptions):
|
|||
def tokenizer_name_or_path(self) -> str:
|
||||
return self.actor.path
|
||||
|
||||
@property
|
||||
def search_kwargs(self):
|
||||
return {
|
||||
"num_gen_tokens": self.ppo.gen.max_new_tokens,
|
||||
"n_ppo_minibatches": self.ppo.ppo_n_minibatches,
|
||||
"seq_len": self.dataset.max_prompt_len,
|
||||
}
|
||||
|
||||
@property
|
||||
def max_prompt_len(self):
|
||||
return self.dataset.max_prompt_len
|
||||
|
|
|
@ -36,8 +36,6 @@ class SFTConfig(CommonExperimentConfig, SFTExperimentOptions):
|
|||
model_name="default",
|
||||
input_keys=("packed_input_ids", "prompt_mask"),
|
||||
log_return_value=True,
|
||||
model_type=self.model.type,
|
||||
model_path=self.model.path,
|
||||
)
|
||||
return {"trainDefault": rpc}
|
||||
|
||||
|
|
|
@ -34,8 +34,8 @@ from realhf.api.core.dfg import OffloadHook, ParamReallocHook
|
|||
from realhf.api.quickstart.device_mesh import RPCAllocation
|
||||
from realhf.base import logging
|
||||
from realhf.base.topology import (
|
||||
DataPipeModelParallelTopology,
|
||||
PipeDataModelParallelTopology,
|
||||
DataPipeTensorParallelTopology,
|
||||
PipeDataTensorParallelTopology,
|
||||
ProcessTopology,
|
||||
)
|
||||
|
||||
|
@ -73,8 +73,8 @@ def get_topo(
|
|||
max_prompt_len: Optional[int] = None,
|
||||
) -> ProcessTopology:
|
||||
if is_train:
|
||||
return PipeDataModelParallelTopology(
|
||||
num_mp=parallel.model_parallel_size,
|
||||
return PipeDataTensorParallelTopology(
|
||||
num_tp=parallel.tensor_parallel_size,
|
||||
num_pp=parallel.pipeline_parallel_size,
|
||||
num_dp=parallel.data_parallel_size,
|
||||
sequence_parallel=parallel.use_sequence_parallel,
|
||||
|
@ -82,8 +82,8 @@ def get_topo(
|
|||
max_prompt_len=max_prompt_len,
|
||||
gradient_accumulation_fusion=gradient_accumulation_fusion,
|
||||
)
|
||||
return DataPipeModelParallelTopology(
|
||||
num_mp=parallel.model_parallel_size,
|
||||
return DataPipeTensorParallelTopology(
|
||||
num_tp=parallel.tensor_parallel_size,
|
||||
num_pp=parallel.pipeline_parallel_size,
|
||||
num_dp=parallel.data_parallel_size,
|
||||
sequence_parallel=parallel.use_sequence_parallel,
|
||||
|
@ -93,7 +93,7 @@ def get_topo(
|
|||
|
||||
def get_world_size(parallel: ParallelismConfig) -> int:
|
||||
return (
|
||||
parallel.model_parallel_size
|
||||
parallel.tensor_parallel_size
|
||||
* parallel.pipeline_parallel_size
|
||||
* parallel.data_parallel_size
|
||||
)
|
||||
|
@ -247,9 +247,8 @@ class AllocationType(enum.Enum):
|
|||
GLOBAL_HYBRID = 2
|
||||
MANUAL = 3
|
||||
HEURISTIC = 4
|
||||
SEARCH = 5
|
||||
DECOUPLED_SGLANG = 6
|
||||
DECOUPLED_MOCK = 7
|
||||
DECOUPLED_SGLANG = 5
|
||||
DECOUPLED_MOCK = 6
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
@ -293,8 +292,6 @@ class AllocationMode:
|
|||
return cls(AllocationType.MANUAL, None)
|
||||
if allocation_mode == "heuristic":
|
||||
return cls(AllocationType.HEURISTIC, None)
|
||||
if allocation_mode == "search":
|
||||
return cls(AllocationType.SEARCH, None)
|
||||
|
||||
alloc_3d = AllocationMode.extract_3d_alloc(allocation_mode)
|
||||
alloc_hybrid = AllocationMode.extract_key_value_alloc(allocation_mode)
|
||||
|
|
|
@@ -1 +1 @@
-import realhf.impl.environment.math_single_step_env
+import realhf.impl.environment.math_code_single_step_env

@ -0,0 +1,75 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from functioncall.code.local_verify import code_verify as local_code_verify
|
||||
from functioncall.code.verify import code_verify
|
||||
from functioncall.math.verify import math_verify
|
||||
from realhf.api.core.env_api import EnvironmentService, register_environment
|
||||
from realhf.base import logging
|
||||
from realhf.impl.dataset.math_code_dataset import load_metadata
|
||||
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
|
||||
|
||||
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
|
||||
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
|
||||
code_verify_call = code_verify if ENABLE_FUNCTION_CALL else local_code_verify
|
||||
|
||||
logger = logging.getLogger("Math Single Step Environment")
|
||||
|
||||
|
||||
def extract_code(text, min_length=20):
|
||||
code_pattern = r"(?i)```(?:python|py|cpp|CPP)?\s*\n?(.*?)\n?```"
|
||||
code_blocks = re.findall(code_pattern, text, re.DOTALL)
|
||||
valid_blocks = []
|
||||
for block in code_blocks:
|
||||
clean_block = block.strip()
|
||||
if len(clean_block) < min_length:
|
||||
continue
|
||||
|
||||
valid_blocks.append(clean_block)
|
||||
|
||||
if not valid_blocks:
|
||||
# logger.warning(f"failed to extract python code from {text}")
|
||||
return None
|
||||
# return the last code block
|
||||
return valid_blocks[-1]
|
||||
|
||||
|
||||
class MathCodeSingleStepEnv(EnvironmentService):
|
||||
def __init__(self, dataset_path: str):
|
||||
self.id2info, _ = load_metadata(dataset_path)
|
||||
|
||||
async def reset(self, seed=None, options=None):
|
||||
return None, {}
|
||||
|
||||
async def step(self, action: Tuple[str, List[str]]):
|
||||
qid, answers = action
|
||||
group_size = len(answers)
|
||||
qid = qid.split("@")[0]
|
||||
cur_task = self.id2info[qid]["task"]
|
||||
|
||||
if cur_task == "math":
|
||||
format_rewards = await asyncio.to_thread(
|
||||
math_verify_call,
|
||||
self.id2info,
|
||||
answers,
|
||||
[qid for _ in range(group_size)],
|
||||
)
|
||||
elif cur_task == "code":
|
||||
answers = [extract_code(x) for x in answers]
|
||||
format_rewards = await asyncio.to_thread(
|
||||
code_verify_call,
|
||||
self.id2info,
|
||||
answers,
|
||||
[qid for _ in range(group_size)],
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
return None, format_rewards, True, False, {}
|
||||
|
||||
|
||||
register_environment("math-code-single-step", MathCodeSingleStepEnv)
|
|
@ -1,38 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from functioncall.math.verify import math_verify
|
||||
from realhf.api.core.env_api import EnvironmentService, register_environment
|
||||
from realhf.base import logging
|
||||
from realhf.impl.dataset.math_code_dataset import load_metadata
|
||||
from realhf.impl.dataset.math_parser import parse_lines_in_parallel
|
||||
|
||||
ENABLE_FUNCTION_CALL = True if os.getenv("FUNCTIONCALL_SERVICE_DOMAIN", "") else False
|
||||
math_verify_call = math_verify if ENABLE_FUNCTION_CALL else parse_lines_in_parallel
|
||||
|
||||
logger = logging.getLogger("Math Single Step Environment")
|
||||
|
||||
|
||||
class MathSingleStepEnv(EnvironmentService):
|
||||
def __init__(self, dataset_path: str):
|
||||
self.id2info, _ = load_metadata(dataset_path)
|
||||
|
||||
async def reset(self, seed=None, options=None):
|
||||
return None, {}
|
||||
|
||||
async def step(self, action: Tuple[str, List[str]]):
|
||||
qid, answers = action
|
||||
group_size = len(answers)
|
||||
format_rewards = await asyncio.to_thread(
|
||||
math_verify_call,
|
||||
self.id2info,
|
||||
answers,
|
||||
[qid for _ in range(group_size)],
|
||||
)
|
||||
return None, format_rewards, True, False, {}
|
||||
|
||||
|
||||
register_environment("math-single-step", MathSingleStepEnv)
|
|
@@ -13,7 +13,7 @@ import realhf.api.from_hf
import realhf.base.logging as logging
from realhf.api.core.model_api import HF_MODEL_FAMILY_REGISTRY
from realhf.base.importing import import_module
from realhf.base.pkg_version import is_version_less
from realhf.base.pkg_version import is_available, is_version_less
from realhf.impl.model.conversion.hf_registry import HFModelRegistry
from realhf.impl.model.nn.real_llm_api import ReaLModel

@@ -27,8 +27,9 @@ import_module(os.path.join(_filepath, "nn"), _p)

# NOTE: skip importing vLLM for now to avoid an
# "invalid device context" issue for the 25.01 image
if is_version_less("vllm", "0.6.4"):
if is_available("vllm") and is_version_less("vllm", "0.6.4"):
import realhf.impl.model.backend.vllm

import realhf.impl.model.backend.inference
import realhf.impl.model.backend.megatron
import realhf.impl.model.backend.mock_train

@@ -62,7 +62,7 @@ class PipelinableInferenceEngine(model_api.PipelinableEngine):
f"num_layers(this stage)={self.module.num_layers} "
f"pp_size={constants.pipe_parallel_world_size()} "
f"dp_size={constants.data_parallel_world_size()} "
f"mp_size={constants.model_parallel_world_size()} "
f"tp_size={constants.tensor_parallel_world_size()} "
)
if constants.data_parallel_rank() == 0:
logger.info(

@@ -126,7 +126,7 @@ def megatron_ctx():
# Build the tensor model-parallel groups.
parallel_state._TENSOR_MODEL_PARALLEL_GROUP = g
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
g = constants.model_parallel_group()
g = constants.tensor_parallel_group()
parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = (
dist.get_process_group_ranks(g)
)
@@ -155,7 +155,7 @@ def megatron_ctx():
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
# Build the tensor + context parallel groups
parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP = (
constants.model_parallel_group()
constants.tensor_parallel_group()
)

# Build expert parallel groups.
@@ -173,7 +173,7 @@ def megatron_ctx():
)
else:
parallel_state._TENSOR_AND_EXPERT_PARALLEL_GROUP = (
constants.model_parallel_group()
constants.tensor_parallel_group()
)
parallel_state._DATA_MODULO_EXPERT_PARALLEL_GROUP = (
constants.data_parallel_group()
@@ -227,7 +227,7 @@ class MegatronEngine:

def _all_reduce_layernorm_grads(self):
if not (
constants.sequence_parallel() and constants.model_parallel_world_size() > 1
constants.sequence_parallel() and constants.tensor_parallel_world_size() > 1
):
return
real_model: ReaLModel = self.ddp.module
@@ -255,7 +255,7 @@ class MegatronEngine:

assert all(x is not None for x in grads)
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, group=constants.model_parallel_group())
dist.all_reduce(coalesced, group=constants.tensor_parallel_group())
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)

@@ -362,7 +362,10 @@ class PipeTrainInstrSetForMegatron(PipeTrainInstrSet):
)
dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
grad_norm /= constants.tp_and_pp_world_size()
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
if (
constants.data_parallel_rank() == 0
and constants.tensor_parallel_rank() == 0
):
logger.info(
f"Model name {constants.model_name()}, "
f"Pipeline rank {constants.pipe_parallel_rank()}. "
@@ -539,7 +542,10 @@ class ReaLMegatronEngine(model_api.PipelinableEngine):
)
dist.all_reduce(grad_norm, group=constants.tp_and_pp_group())
grad_norm /= constants.tp_and_pp_world_size()
if constants.data_parallel_rank() == 0 and constants.model_parallel_rank() == 0:
if (
constants.data_parallel_rank() == 0
and constants.tensor_parallel_rank() == 0
):
logger.info(
f"Megatron backend update success? {update_successful}. "
f"Grad Norm: {grad_norm}. "
@@ -700,6 +706,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
# Deleting models directly will not release the memory.
# We must disable hooks at first.
if pkg_version.is_version_greater_or_equal("megatron.core", "0.11.0"):
if self.ddp.use_distributed_optimizer and self.ddp.overlap_param_gather:
model.module.engine.ddp.disable_forward_pre_hook()
else:
optimizer = model.module.engine.optim
@@ -726,7 +733,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):
sd = optimizer.state_dict()
dp = constants.data_parallel_rank()
pp = constants.pipe_parallel_rank()
tp = constants.model_parallel_rank()
tp = constants.tensor_parallel_rank()
# HACK: (bowei) I'm not sure whether there's duplicated information.
torch.save(
sd, pathlib.Path(save_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt"
@@ -742,7 +749,7 @@ class MegatronTrainBackend(model_api.ModelBackend, MegatronConfig):

dp = constants.data_parallel_rank()
pp = constants.pipe_parallel_rank()
tp = constants.model_parallel_rank()
tp = constants.tensor_parallel_rank()

sd = torch.load(
pathlib.Path(load_dir) / f"megatron_optim_sd_d{dp}p{pp}t{tp}.mckpt"

@@ -82,7 +82,7 @@ def _split_and_prefill_pipe_input(
raise PipelineError(
"Partitioned seqlens are not equal across pipeline parallel ranks. "
f"Current rank (dp={constants.data_parallel_rank()},"
f"tp={constants.model_parallel_rank()},pp={constants.pipe_parallel_rank()}), "
f"tp={constants.tensor_parallel_rank()},pp={constants.pipe_parallel_rank()}), "
f"gathered batch seqlens={_batch_seqlen_all_gathered}, "
f"Have you ensured that the order of dataset across ranks is the same?",
)
@@ -118,7 +118,7 @@ def _split_and_prefill_pipe_input(
total_len = (
packed_input_ids.shape[0]
if not constants.sequence_parallel()
else packed_input_ids.shape[0] // constants.model_parallel_world_size()
else packed_input_ids.shape[0] // constants.tensor_parallel_world_size()
)
mb_seq_lens.append(total_len)
return (x, ys)
@@ -569,7 +569,7 @@ class PipeGenInstrSet:
"batch_lengths", micro_batch_id, remove=False
)
batch_length = (
batch_length // constants.model_parallel_world_size()
batch_length // constants.tensor_parallel_world_size()
if constants.sequence_parallel()
else batch_length
)

@@ -198,8 +198,8 @@ class SGLangGenerationEngine(PipelinableEngine):
hybrid_train: bool,
request_timeout: int = 1800,
):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
if constants.tensor_parallel_rank() != 0:
dist.barrier(group=constants.tensor_parallel_cpu_group())
return
# Start the serving process
self.server_proc = mp.Process(
@@ -224,8 +224,8 @@ class SGLangGenerationEngine(PipelinableEngine):
if server_args_dict["enable_metrics"]:
dp_rank = constants.data_parallel_rank()
pp_rank = constants.pipe_parallel_rank()
mp_rank = constants.model_parallel_rank()
metric_server_name = f"d{dp_rank}p{pp_rank}m{mp_rank}"
tp_rank = constants.tensor_parallel_rank()
metric_server_name = f"d{dp_rank}p{pp_rank}t{tp_rank}"
key = names.metric_server(
constants.experiment_name(),
constants.trial_name(),
@@ -243,7 +243,7 @@ class SGLangGenerationEngine(PipelinableEngine):
# offload weights/cache
self.hybrid_train = hybrid_train

dist.barrier(group=constants.model_parallel_cpu_group())
dist.barrier(group=constants.tensor_parallel_cpu_group())

def __del__(self):
if hasattr(self, "server_proc"):
@@ -381,8 +381,8 @@ class SGLangGenerationEngine(PipelinableEngine):
"NOTE: passing in an arbitrary `min_new_tokens` will lead to a bug for SGLang v0.4.3 "
"because we force to skip_tokenizer_init."
)
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
if constants.tensor_parallel_rank() != 0:
dist.barrier(group=constants.tensor_parallel_cpu_group())
return None, None, None

results = asyncio.run(
@@ -393,12 +393,12 @@ class SGLangGenerationEngine(PipelinableEngine):
gconfig=gconfig,
)
)
dist.barrier(group=constants.model_parallel_cpu_group())
dist.barrier(group=constants.tensor_parallel_cpu_group())
return results

def update_weights_from_disk(self, path):
if constants.model_parallel_rank() != 0:
dist.barrier(group=constants.model_parallel_cpu_group())
if constants.tensor_parallel_rank() != 0:
dist.barrier(group=constants.tensor_parallel_cpu_group())
return

async def _fn():
@@ -409,18 +409,17 @@ class SGLangGenerationEngine(PipelinableEngine):
await client.async_update_weights_from_disk(path)

asyncio.run(_fn())
dist.barrier(group=constants.model_parallel_cpu_group())
dist.barrier(group=constants.tensor_parallel_cpu_group())


@dataclasses.dataclass
class SGLangGenerationBackend(ModelBackend, SGLangConfig):
model_path: str = ""
dtype: str = "float16"

def _initialize(self, model: Model, spec: FinetuneSpec) -> Model:
if constants.pipe_parallel_world_size() != 1:
raise RuntimeError("SGLang does not support pipe parallel size > 1.")
if constants.model_parallel_world_size() > cluster.spec.n_gpus_per_node:
if constants.tensor_parallel_world_size() > cluster.spec.n_gpus_per_node:
raise RuntimeError(
"AReaL's SGLang integration does not support model parallel size > n_gpus_per_node."
)
@@ -436,7 +435,13 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
) != len(datapack.flat2d(ports)):
dist.all_gather_object(
ports,
network.find_multiple_free_ports(2, low=20000, high=40000),
network.find_multiple_free_ports(
2,
low=10000,
high=60000,
experiment_name=constants.experiment_name(),
trial_name=constants.trial_name(),
),
group=constants.data_parallel_group(),
)
api_server_port, dist_port = ports[constants.data_parallel_rank()]
@@ -450,13 +455,12 @@ class SGLangGenerationBackend(ModelBackend, SGLangConfig):
tokenizer_mode="auto",
load_format="auto",
trust_remote_code=True,
kv_cache_dtype="auto",
device="cuda",
served_model_name=f"{constants.experiment_name()}/{constants.trial_name()}/{constants.model_name().role}",
is_embedding=False,
skip_tokenizer_init=True,
# Other runtime options
tp_size=constants.model_parallel_world_size(),
tp_size=constants.tensor_parallel_world_size(),
# Because we have set CUDA_VISIBLE_DEVICES to a single GPU in each process
base_gpu_id=int(os.environ["CUDA_VISIBLE_DEVICES"]),
file_storage_path=os.path.join(

@@ -36,7 +36,7 @@ def _vllm_group_rank(group_type: _vLLMGroupType):
if group_type == _vLLMGroupType.WORLD:
return constants.tp_and_pp_rank()
elif group_type == _vLLMGroupType.TP:
return constants.model_parallel_rank()
return constants.tensor_parallel_rank()
elif group_type == _vLLMGroupType.PP:
return constants.pipe_parallel_rank()

@@ -45,7 +45,7 @@ def _vllm_group_size(group_type: _vLLMGroupType):
if group_type == _vLLMGroupType.WORLD:
return constants.tp_and_pp_world_size()
elif group_type == _vLLMGroupType.TP:
return constants.model_parallel_world_size()
return constants.tensor_parallel_world_size()
elif group_type == _vLLMGroupType.PP:
return constants.pipe_parallel_world_size()

@@ -54,7 +54,7 @@ def _vllm_parallel_group(group_type: _vLLMGroupType):
if group_type == _vLLMGroupType.WORLD:
return constants.tp_and_pp_group()
elif group_type == _vLLMGroupType.TP:
return constants.model_parallel_group()
return constants.tensor_parallel_group()
elif group_type == _vLLMGroupType.PP:
return constants.pipe_parallel_group()

@@ -213,7 +213,7 @@ class GPUExecutor_(GPUExecutor):
tok = time.perf_counter()
after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used)
is_dp_head = (
constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0
constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0
)
if is_dp_head:
logger.info(
@@ -241,7 +241,7 @@ class GPUExecutor_(GPUExecutor):
tok = time.perf_counter()
after_mem = float(pynvml.nvmlDeviceGetMemoryInfo(handle).used)
is_dp_head = (
constants.is_last_pipe_stage() and constants.model_parallel_rank() == 0
constants.is_last_pipe_stage() and constants.tensor_parallel_rank() == 0
)
if is_dp_head:
logger.info(

@@ -166,7 +166,6 @@ class vLLMGenerationEngine(model_api.PipelinableEngine, LLM):
@dataclasses.dataclass
class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
model_path: str = ""
dtype: str = "bfloat16"

def _initialize(
self, model: model_api.Model, spec: model_api.FinetuneSpec
@@ -192,7 +191,7 @@ class vLLMGenerationBackend(vLLMConfig, model_api.ModelBackend):
kv_cache_dtype=self.kv_cache_type,
device=constants.current_device(),
# Parallelism.
tensor_parallel_size=constants.model_parallel_world_size(),
tensor_parallel_size=constants.tensor_parallel_world_size(),
pipeline_parallel_size=constants.pipe_parallel_world_size(),
# KV cahce and scheduling.
num_scheduler_steps=self.num_scheduler_steps,
@@ -100,7 +100,7 @@ def setup_global_comm(

if worker_index == 0:
host_ip = socket.gethostbyname(socket.gethostname())
port = network.find_free_port()
port = network.find_free_port(experiment_name=expr_name, trial_name=trial_name)
pg_init_addr = f"tcp://{host_ip}:{port}"
name_resolve.add(pg_master_name, pg_init_addr, keepalive_ttl=300)
else:

@@ -43,10 +43,10 @@ def is_trainable(model_name: ModelName) -> bool:
class ParamReallocPair:
src: ModelName
src_dp_rank: int
src_mp_rank: int
src_tp_rank: int
src_pp_rank: int
dst: ModelName
dst_mp_rank: int
dst_tp_rank: int
dst_pp_rank: int


@@ -171,32 +171,34 @@ def _create_param_realloc_groups(
range(from_topo.get_dim("pipe")), range(to_topo.get_dim("pipe"))
):
# create tensor reshard groups
src_mp_size = from_topo.get_dim("model")
dst_mp_size = to_topo.get_dim("model")
src_tp_size = from_topo.get_dim("tensor")
dst_tp_size = to_topo.get_dim("tensor")

for mp_j in range(dst_mp_size):
for tp_j in range(dst_tp_size):
_all_dst_ranks = filter_match_mwids(
dst, to_topo, msid2mwid, pipe=pp_j, model=mp_j
dst, to_topo, msid2mwid, pipe=pp_j, tensor=tp_j
)
if src_mp_size > dst_mp_size:
factor = src_mp_size // dst_mp_size
mp_is = list(range(factor * mp_j, factor * (mp_j + 1)))
if src_tp_size > dst_tp_size:
factor = src_tp_size // dst_tp_size
tp_is = list(range(factor * tp_j, factor * (tp_j + 1)))
_all_src_ranks = [
filter_match_mwids(src, from_topo, msid2mwid, model=mp_i, pipe=pp_i)
for mp_i in mp_is
filter_match_mwids(
src, from_topo, msid2mwid, tensor=tp_i, pipe=pp_i
)
for tp_i in tp_is
]
else:
factor = dst_mp_size // src_mp_size
factor = dst_tp_size // src_tp_size
_all_src_ranks = [
filter_match_mwids(
src,
from_topo,
msid2mwid,
model=mp_j // factor,
tensor=tp_j // factor,
pipe=pp_i,
)
]
# All GPUs in _src_ranks have the data required by (pp_j, mp_j)
# All GPUs in _src_ranks have the data required by (pp_j, tp_j)
for _src_ranks in _all_src_ranks:
# NOTE: inter-node communication cost is significantly larger than intra-node communication cost.
# We only select one sender per host/node to prevent multiple senders occupying the same network bandwidth.
@@ -209,42 +211,42 @@ def _create_param_realloc_groups(
)
_idle_src_ranks = [r for r in _src_ranks if r not in assignment]
for _src_rank in _idle_src_ranks:
dp_i, mp_i = (
dp_i, tp_i = (
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).data,
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).model,
).tensor,
)
key = ParamReallocPair(
src=src,
src_dp_rank=dp_i,
src_mp_rank=mp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=dst,
dst_mp_rank=mp_j,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
param_realloc_dst_ranks[key] = []
param_realloc_groups[key] = None
param_realloc_src_ranks[key] = _src_rank
for _src_rank, _dst_ranks in assignment.items():
dp_i, mp_i = (
dp_i, tp_i = (
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).data,
from_topo.get_coord(
mwid2msid[_src_rank][src].parallelism_rank
).model,
).tensor,
)
key = ParamReallocPair(
src=src,
src_dp_rank=dp_i,
src_mp_rank=mp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=dst,
dst_mp_rank=mp_j,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
param_realloc_dst_ranks[key] = _dst_ranks
@@ -315,8 +317,8 @@ def setup_param_realloc(
@dataclasses.dataclass
class ReparallelizeSenderStep:
rank: int
sender_mp_portion_id: int
receiver_mp_portion_id: int
sender_tp_portion_id: int
receiver_tp_portion_id: int
param_keys: List[str]
param_intervals_cpu: List[Tuple[int, int]]
param_intervals_cuda: torch.Tensor
@@ -330,8 +332,8 @@ class ReparallelizeSenderStep:
@dataclasses.dataclass
class ReparallelizeReceiverStep:
rank: int
sender_mp_portion_id: int
receiver_mp_portion_id: int
sender_tp_portion_id: int
receiver_tp_portion_id: int
sender_param_intervals_cpu: List[Tuple[int, int]]
sender_param_intervals_cuda: torch.Tensor
sender_max_interval_size: int
@@ -356,9 +358,9 @@ def _derive_reparallelize_comm_plan(
pg_info: ParamReallocInfo,
dtype: Optional[torch.dtype] = torch.float16,
) -> List[ReparallelizeReceiverStep | ReparallelizeSenderStep]:
src_mp_size = from_topo.get_dim("model")
dst_mp_size = to_topo.get_dim("model")
assert src_mp_size % dst_mp_size == 0 or dst_mp_size % src_mp_size == 0
src_tp_size = from_topo.get_dim("tensor")
dst_tp_size = to_topo.get_dim("tensor")
assert src_tp_size % dst_tp_size == 0 or dst_tp_size % src_tp_size == 0
for k, v in dataclasses.asdict(to_model_config).items():
if k not in ["is_critic"] and v != getattr(from_model_config, k):
raise ValueError(
@@ -366,8 +368,8 @@ def _derive_reparallelize_comm_plan(
f"value in checkpoint is `{v}`, current value is `{getattr(from_model_config, k)}`)."
)
if from_model_config.n_kv_heads > 1 and (
from_model_config.n_kv_heads % src_mp_size == 0
) != (from_model_config.n_kv_heads % dst_mp_size == 0):
from_model_config.n_kv_heads % src_tp_size == 0
) != (from_model_config.n_kv_heads % dst_tp_size == 0):
raise ValueError("Whether to partition kv heads should remain the same.")

from_layer_mapping = partition_pipeline_layers(
@@ -400,7 +402,7 @@ def _derive_reparallelize_comm_plan(
from_model_param_specs, _ = build_param_spec(
from_layer_indices,
from_model_config,
mp_size=from_topo.get_dim("model"),
tp_size=from_topo.get_dim("tensor"),
dp_size=from_topo.get_dim("data"),
pp_size=from_topo.get_dim("pipe"),
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
@@ -411,7 +413,7 @@ def _derive_reparallelize_comm_plan(
to_model_param_specs, _ = build_param_spec(
to_layer_indices,
to_model_config,
mp_size=to_topo.get_dim("model"),
tp_size=to_topo.get_dim("tensor"),
pp_size=to_topo.get_dim("pipe"),
dp_size=to_topo.get_dim("data"),
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
@@ -428,25 +430,25 @@ def _derive_reparallelize_comm_plan(
if len(layer_indices) == 0:
continue

for mp_i in range(src_mp_size):
if dst_mp_size > src_mp_size:
factor = dst_mp_size // src_mp_size
mp_js = [i + factor * mp_i for i in range(factor)]
receiver_mp_portion_id = 0
for tp_i in range(src_tp_size):
if dst_tp_size > src_tp_size:
factor = dst_tp_size // src_tp_size
tp_js = [i + factor * tp_i for i in range(factor)]
receiver_tp_portion_id = 0
else:
factor = src_mp_size // dst_mp_size
mp_js = [mp_i // factor]
receiver_mp_portion_id = mp_i % factor
for sender_mp_portion_id, mp_j in enumerate(mp_js):
factor = src_tp_size // dst_tp_size
tp_js = [tp_i // factor]
receiver_tp_portion_id = tp_i % factor
for sender_tp_portion_id, tp_j in enumerate(tp_js):

for dp_i in range(src_dp_size):
key = ParamReallocPair(
src=from_model_name,
src_dp_rank=dp_i,
src_mp_rank=mp_i,
src_tp_rank=tp_i,
src_pp_rank=pp_i,
dst=to_model_name,
dst_mp_rank=mp_j,
dst_tp_rank=tp_j,
dst_pp_rank=pp_j,
)
src = pg_info.param_realloc_src_ranks[key]
@@ -462,10 +464,10 @@ def _derive_reparallelize_comm_plan(
)
param_size = param_size_from_keys(
config=from_model_config,
src_mp_size=src_mp_size,
src_tp_size=src_tp_size,
sd_keys=param_keys,
src2dst_tp_size=max(dst_mp_size // src_mp_size, 1),
src2dst_tp_rank=sender_mp_portion_id,
src2dst_tp_size=max(dst_tp_size // src_tp_size, 1),
src2dst_tp_rank=sender_tp_portion_id,
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
)
if torch.distributed.is_initialized():
@@ -474,11 +476,11 @@ def _derive_reparallelize_comm_plan(
param_intervals_cpu = param_intervals_from_keys(
model_name=from_model_name,
config=from_model_config,
mp_size=src_mp_size,
tp_size=src_tp_size,
param_spec=from_model_param_specs,
sd_keys=param_keys,
portion_size=max(dst_mp_size // src_mp_size, 1),
portion_rank=sender_mp_portion_id,
portion_size=max(dst_tp_size // src_tp_size, 1),
portion_rank=sender_tp_portion_id,
head_param_point_to_embedding=from_model_head_param_point_to_embedding,
)
param_intervals_cuda = torch.tensor(
@@ -493,11 +495,11 @@ def _derive_reparallelize_comm_plan(
receiver_param_intervals_cpu = param_intervals_from_keys(
model_name=to_model_name,
config=to_model_config,
mp_size=dst_mp_size,
tp_size=dst_tp_size,
param_spec=to_model_param_specs,
sd_keys=param_keys,
portion_size=max(src_mp_size // dst_mp_size, 1),
portion_rank=receiver_mp_portion_id,
portion_size=max(src_tp_size // dst_tp_size, 1),
portion_rank=receiver_tp_portion_id,
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
)
receiver_param_intervals_cuda = torch.tensor(
@@ -513,8 +515,8 @@ def _derive_reparallelize_comm_plan(
comm_plan.append(
ReparallelizeReceiverStep(
rank=dst_rank,
sender_mp_portion_id=sender_mp_portion_id,
receiver_mp_portion_id=receiver_mp_portion_id,
sender_tp_portion_id=sender_tp_portion_id,
receiver_tp_portion_id=receiver_tp_portion_id,
param_keys=param_keys,
sender_param_intervals_cpu=param_intervals_cpu,
sender_param_intervals_cuda=param_intervals_cuda,
@@ -532,8 +534,8 @@ def _derive_reparallelize_comm_plan(
comm_plan.append(
ReparallelizeSenderStep(
rank=src,
sender_mp_portion_id=sender_mp_portion_id,
receiver_mp_portion_id=receiver_mp_portion_id,
sender_tp_portion_id=sender_tp_portion_id,
receiver_tp_portion_id=receiver_tp_portion_id,
param_keys=param_keys,
param_intervals_cpu=param_intervals_cpu,
param_intervals_cuda=param_intervals_cuda,
@@ -22,8 +22,8 @@ from realhf.base.saveload_utils import (
)
from realhf.impl.model.nn.real_llm_api import ReaLModel
from realhf.impl.model.nn.real_llm_parallel import (
mp_merge_key,
mp_partition_real_model_state_dict,
tp_merge_key,
tp_partition_real_model_state_dict,
)

logger = logging.getLogger("HF Registry")
@@ -141,11 +141,11 @@ class HFModelRegistry:
partition_tik = time.perf_counter()
sd = {k: v for k, v in sd.items() if k in required_hf_sd_names}
sd = self.sd_from_hf_converter(sd, model.config)
psd = mp_partition_real_model_state_dict(
psd = tp_partition_real_model_state_dict(
sd,
model.config,
constants.model_parallel_world_size(),
constants.model_parallel_rank(),
constants.tensor_parallel_world_size(),
constants.tensor_parallel_rank(),
)
return psd, partition_tik - load_tik, time.perf_counter() - partition_tik

@@ -222,8 +222,8 @@ class HFModelRegistry:

dp_rank = constants.data_parallel_rank()
pp_rank = constants.pipe_parallel_rank()
mp_rank = constants.model_parallel_rank()
mp_size = constants.model_parallel_world_size()
tp_rank = constants.tensor_parallel_rank()
tp_size = constants.tensor_parallel_world_size()
pp_size = constants.pipe_parallel_world_size()
dp_size = constants.data_parallel_world_size()

@@ -234,7 +234,7 @@ class HFModelRegistry:
# of each pipeline stage into smaller shards.
approx_param_size = (
sum(v.numel() * v.element_size() for v in model.state_dict().values())
* mp_size
* tp_size
)

# By default a shard is at most 1GB. A small size enables parallel saving during training.
@@ -274,9 +274,9 @@ class HFModelRegistry:
and k == f"{model.config.n_layers + 1}.weight"
):
continue
gather_list = [torch.zeros_like(v) for _ in range(mp_size)]
dist.all_gather(gather_list, v, group=constants.model_parallel_group())
gathered = mp_merge_key(k, gather_list, model.config)
gather_list = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_list, v, group=constants.tensor_parallel_group())
gathered = tp_merge_key(k, gather_list, model.config)
cpu_sd[k] = gathered.cpu()

t2 = time.perf_counter()
@@ -299,7 +299,7 @@ class HFModelRegistry:
param_size = param_size.item()

# Save tokenizer and huggingface model config.
if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
hf_config.save_pretrained(save_dir)
if tokenizer is not None:
tokenizer.save_pretrained(save_dir)
@@ -307,7 +307,7 @@ class HFModelRegistry:
# Dump parameters to disk.
if len(pp_stage_n_shards) == 1 and pp_stage_n_shards[0] == 1:
fn = "pytorch_model.bin"
if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
torch.save(hf_sd, os.path.join(save_dir, fn))
else:
output_fn = (
@@ -326,8 +326,8 @@ class HFModelRegistry:
bin_index["weight_map"] = {}
weight_map = {}

mesh_size = dp_size * mp_size
mesh_idx = dp_rank * mp_size + mp_rank
mesh_size = dp_size * tp_size
mesh_idx = dp_rank * tp_size + tp_rank
n_shards_per_gpu = (n_shards + mesh_size - 1) // mesh_size
if mesh_idx < len(range(0, n_shards, n_shards_per_gpu)):
s = list(range(0, n_shards, n_shards_per_gpu))[mesh_idx]
@@ -357,7 +357,7 @@ class HFModelRegistry:
for wm in weight_map_list:
bin_index["weight_map"].update(wm)

if pp_rank == 0 and dp_rank == 0 and mp_rank == 0:
if pp_rank == 0 and dp_rank == 0 and tp_rank == 0:
with open(
os.path.join(save_dir, "pytorch_model.bin.index.json"), "w"
) as f:

@@ -240,7 +240,7 @@ class MultiTaskRewardInterface(model_api.ModelInterface):
return data
local_rank = constants.grid().topo.get_rank(
data=constants.data_parallel_rank(),
model=0,
tensor=0,
pipe=constants.pipe_parallel_world_size() - 1,
)
dst = constants.to_global_pg_rank(local_rank)
@@ -3,7 +3,7 @@
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
from typing import Dict, Literal, Optional
from typing import Dict, List, Literal, Optional

import torch
import torch.distributed as dist
@@ -86,6 +86,7 @@ def _ppo_actor_loss_from_model_outputs(
eps_clip=eps_clip,
loss_mask=ppo_loss_mask,
c_clip=c_clip,
proximal_logprobs=input_.data.get("prox_logp", None),
)

# Log training statistics
@@ -106,13 +107,20 @@ def _ppo_actor_loss_from_model_outputs(
dual_clip_ratio=ppo_stat["dual_clip_mask"].float(),
denominator="n_valid_tokens",
)
if "behave_imp_weight" in ppo_stat:
stats_tracker.denominator(unclipped_behave_tokens=ppo_stat["behave_mask"])
stats_tracker.stat(
behave_imp_weight=ppo_stat["behave_imp_weight"],
behave_approx_kl=ppo_stat["behave_approx_kl"],
denominator="unclipped_behave_tokens",
)
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
dist.all_reduce(
vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN
)
dist.all_reduce(
vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX
)
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
@@ -505,7 +513,7 @@ class PPOActorInterface(model_api.ModelInterface):
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> Dict:
) -> Dict | List[Dict]:
module = model.module
# We call module.eval() because dropout causes the computation of incorrect of log probs.
module.eval()
@@ -656,15 +664,20 @@ class PPOActorInterface(model_api.ModelInterface):
advantages = torch.cat(adv_list, 0)

# Prepare data to be splitted into mini-batches.
flat_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
data=dict(
flat_data = dict(
advantages=advantages,
old_logp=old_logp,
ppo_loss_mask=loss_mask,
packed_input_ids=input_.data["packed_input_ids"],
kl_rewards=kl_rewards,
),
)
use_prox_logp = "proximal_logprobs" in input_.data
if use_prox_logp:
flat_data["prox_logp"] = input_.data["proximal_logprobs"].float()

flat_input = SequenceSample.from_default(
ids=list(range(input_.bs * self.group_size)),
data=flat_data,
seqlens=[int(x) for x in input_lens.cpu().numpy().tolist()],
)

@@ -672,6 +685,7 @@ class PPOActorInterface(model_api.ModelInterface):
dense_reward_score = dense_reward_score[shift_one_indices]

### Logging code starts. ###
all_stats = []
with stats_tracker.scope("ppo_actor"):
assert (
task_ids.shape == reward_score.shape
@@ -682,12 +696,13 @@ class PPOActorInterface(model_api.ModelInterface):
for idx, task in enumerate(RL_TASKS)
}

stats_tracker.denominator(
global_denominators = dict(
n_seqs=torch.ones_like(reward_score, dtype=torch.bool),
n_tokens=torch.ones_like(prompt_mask, dtype=torch.bool),
n_valid_tokens=loss_mask.bool(),
**task_denominators,
)
stats_tracker.denominator(**global_denominators)

for task in RL_TASKS:
stats_tracker.stat(
@@ -721,6 +736,22 @@ class PPOActorInterface(model_api.ModelInterface):
**seq_stats,
denominator="n_seqs",
)
scalars = dict(
disable_value=self.disable_value,
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
eps_clip=self.eps_clip,
use_prox_logp=use_prox_logp,
)
if self.c_clip is not None:
scalars["c_clip"] = self.c_clip
scalars["use_dual_clip"] = 1
else:
scalars["use_dual_clip"] = 0
stats_tracker.scalar(**scalars)

global_stats = stats_tracker.export()
for k in global_denominators:
global_stats.pop(f"ppo_actor/{k}")

# Run mini-batched PPO training!
def _loss_fn(logits, input_):
@@ -736,7 +767,6 @@ class PPOActorInterface(model_api.ModelInterface):
)

for reuse in range(self.sample_reuse):
with stats_tracker.scope(f"reuse{reuse}"):
# NOTE: We split PPO minibatches in terms of #seqs instead of #tokens.
flat_input = SequenceSample.shuffled(flat_input)
bs = flat_input.bs
@@ -751,7 +781,6 @@ class PPOActorInterface(model_api.ModelInterface):
f"#tokens: {[sum([sum(lens) for lens in s.seqlens[s._get_split_key()]]) for s in datas]}"
)
for mb_i, data in enumerate(datas):
with stats_tracker.scope(f"mb{mb_i}"):
train_stat = module.train_batch(
input_=data,
mb_spec=mb_spec,
@@ -763,16 +792,12 @@ class PPOActorInterface(model_api.ModelInterface):
token_normalize_scope=self.token_normalize_scope,
)
stats_tracker.scalar(**train_stat)
all_stats.append(stats_tracker.export())

stats_tracker.scalar(
disable_value=self.disable_value,
mask_no_eos_with_zero=self.mask_no_eos_with_zero,
c_clip=self.c_clip if self.c_clip is not None else float("nan"),
eps_clip=self.eps_clip,
)
model.inc_version()
all_stats[0].update(global_stats)

return stats_tracker.export()
return all_stats

# Mock methods for profiling only.
def _mock_inference(
@@ -1033,7 +1058,7 @@ class PPOCriticInterface(model_api.ModelInterface):
model: model_api.Model,
input_: SequenceSample,
mb_spec: MicroBatchSpec,
) -> Dict:
) -> Dict | List[Dict]:
assert model.module.module.config.is_critic

if self.disable_value:
@@ -3,7 +3,7 @@
# Licensed under the Apache License, Version 2.0 (the "License").

import dataclasses
from typing import Dict, Literal
from typing import Dict, List, Literal

import torch
import torch.distributed as dist
@@ -68,10 +68,10 @@ def compute_packed_sft_loss(
vocab_min_logits = logits.detach().min(-1).values.float()
vocab_max_logits = logits.detach().max(-1).values.float()
dist.all_reduce(
vocab_min_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MIN
vocab_min_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MIN
)
dist.all_reduce(
vocab_max_logits, group=constants.model_parallel_group(), op=dist.ReduceOp.MAX
vocab_max_logits, group=constants.tensor_parallel_group(), op=dist.ReduceOp.MAX
)
stats_tracker.stat(
vocab_min_logits=vocab_min_logits,
@@ -88,7 +88,7 @@ class SFTInterface(model_api.ModelInterface):

def train_step(
self, model: model_api.Model, data: SequenceSample, mb_spec: MicroBatchSpec
) -> Dict:
) -> Dict | List[Dict]:
module = model.module

module.train()
@ -10,14 +10,14 @@ import torch.utils.checkpoint
|
|||
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.logging as logging
|
||||
from realhf.impl.model.parallelism.model_parallel.modules import RowParallelLinear
|
||||
from realhf.impl.model.parallelism.tensor_parallel.modules import RowParallelLinear
|
||||
from realhf.impl.model.utils.functional import (
|
||||
apply_rotary_varlen,
|
||||
compute_varlen_position_indices,
|
||||
torch_attn_func,
|
||||
)
|
||||
|
||||
from .mlp import LayerNormQKVLinear
|
||||
from .mlp import GemmaRMSNorm, LayerNormQKVLinear, LlamaRMSNorm
|
||||
from .rotary import RotaryEmbedding
|
||||
|
||||
try:
|
||||
|
@ -53,6 +53,8 @@ class CausalSelfAttentionLayer(nn.Module):
|
|||
layer_norm_type: Optional[str] = None,
|
||||
# opt applies layer norm after attn
|
||||
do_layernorm_before: bool = True,
|
||||
# qk layer norm (Qwen3)
|
||||
qk_layernorm: bool = False,
|
||||
# rotary embedding
|
||||
apply_rotary: bool = False,
|
||||
rotary_base: float = 10000.0,
|
||||
|
@ -67,7 +69,7 @@ class CausalSelfAttentionLayer(nn.Module):
|
|||
super().__init__()
|
||||
if dtype is None:
|
||||
dtype = torch.float16
|
||||
assert hidden_dim % head_dim == 0
|
||||
assert hidden_dim % head_dim == 0, (hidden_dim, head_dim)
|
||||
self.c_attn = LayerNormQKVLinear(
|
||||
input_dim=hidden_dim,
|
||||
head_dim=head_dim,
|
||||
|
@ -82,7 +84,7 @@ class CausalSelfAttentionLayer(nn.Module):
|
|||
layer_index=layer_index,
|
||||
)
|
||||
|
||||
if constants.model_parallel_world_size() > 1:
|
||||
if constants.tensor_parallel_world_size() > 1:
|
||||
self.c_proj = RowParallelLinear(
|
||||
n_q_heads * head_dim,
|
||||
hidden_dim,
|
||||
|
@ -100,6 +102,21 @@ class CausalSelfAttentionLayer(nn.Module):
|
|||
device=device,
|
||||
)
|
||||
|
||||
self.qk_layernorm = qk_layernorm
|
||||
if qk_layernorm:
|
||||
if layer_norm_type is None:
|
||||
layer_norm_fn = nn.LayerNorm
|
||||
elif layer_norm_type == "rms":
|
||||
layer_norm_fn = LlamaRMSNorm
|
||||
elif layer_norm_type == "gemma":
|
||||
layer_norm_fn = GemmaRMSNorm
|
||||
self.q_ln = layer_norm_fn(
|
||||
head_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
|
||||
)
|
||||
self.k_ln = layer_norm_fn(
|
||||
head_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.resid_dropout = nn.Dropout(resid_pdrop)
|
||||
|
||||
self.attn_pdrop = attn_pdrop
|
||||
|
@ -173,6 +190,10 @@ class CausalSelfAttentionLayer(nn.Module):
|
|||
|
||||
q, k, v = self.c_attn(hidden_states)
|
||||
|
||||
if self.qk_layernorm:
|
||||
q = self.q_ln(q)
|
||||
k = self.k_ln(k)
|
||||
|
||||
if self.apply_rotary and (k_cache is None or str(q.device) == "cpu"):
|
||||
# otherwise, we input rotary cos/sin directly into flash_attn_with_kvcache
|
||||
rotary_cache_len = max_seqlen
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.nn import init
|
||||
|
||||
from realhf.impl.model.parallelism.model_parallel.modules import ParallelEmbedding
|
||||
from realhf.impl.model.parallelism.tensor_parallel.modules import ParallelEmbedding
|
||||
|
||||
|
||||
class OffsetPositionalEmbedding(nn.Embedding):
|
||||
|
|
|
@ -15,7 +15,7 @@ from transformers.activations import ACT2FN
|
|||
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.logging as logging
|
||||
from realhf.impl.model.parallelism.model_parallel.modules import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.modules import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
merged_linear_with_grad_accumulation_and_async_allreduce,
|
||||
|
@ -49,10 +49,10 @@ class LayerNormQKVLinear(nn.Module):
|
|||
layer_index=None,
|
||||
):
|
||||
super().__init__()
|
||||
model_parallel = constants.model_parallel_world_size() > 1
|
||||
tensor_parallel = constants.tensor_parallel_world_size() > 1
|
||||
sequence_parallel = constants.sequence_parallel()
|
||||
gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
|
||||
if not model_parallel and (sequence_parallel or gradient_accumulation_fusion):
|
||||
if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
|
||||
global SEQUENCE_PARALLEL_WARNED
|
||||
if not SEQUENCE_PARALLEL_WARNED:
|
||||
logger.warning(
|
||||
|
@ -73,16 +73,16 @@ class LayerNormQKVLinear(nn.Module):
|
|||
input_dim, eps=layer_norm_epsilon, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
self.model_parallel = model_parallel
|
||||
self.tensor_parallel = tensor_parallel
|
||||
self.layer_index = layer_index
|
||||
self.mp_worldsize = constants.model_parallel_world_size()
|
||||
assert n_q_heads % self.mp_worldsize == 0, (
|
||||
self.tp_worldsize = constants.tensor_parallel_world_size()
|
||||
assert n_q_heads % self.tp_worldsize == 0, (
|
||||
f"n_q_heads {n_q_heads} must be divisible by "
|
||||
f"mp_worldsize {self.mp_worldsize}"
|
||||
f"tp_worldsize {self.tp_worldsize}"
|
||||
)
|
||||
assert n_kv_heads % self.mp_worldsize == 0, (
|
||||
assert n_kv_heads % self.tp_worldsize == 0, (
|
||||
f"n_kv_heads {n_kv_heads} must be divisible by "
|
||||
f"mp_worldsize {self.mp_worldsize}"
|
||||
f"tp_worldsize {self.tp_worldsize}"
|
||||
)
|
||||
hidden_dim = input_dim
|
||||
# TODO: we can fuse the forward of qkv attention
|
||||
|
@ -141,9 +141,9 @@ class LayerNormQKVLinear(nn.Module):
|
|||
self.v_attn.weight,
|
||||
self.v_attn.bias,
|
||||
)
|
||||
q = q.view(*q.shape[:-1], self.nq // self.mp_worldsize, self.d)
|
||||
k = k.view(*k.shape[:-1], self.nkv // self.mp_worldsize, self.d)
|
||||
v = v.view(*v.shape[:-1], self.nkv // self.mp_worldsize, self.d)
|
||||
q = q.view(*q.shape[:-1], self.nq // self.tp_worldsize, self.d)
|
||||
k = k.view(*k.shape[:-1], self.nkv // self.tp_worldsize, self.d)
|
||||
v = v.view(*v.shape[:-1], self.nkv // self.tp_worldsize, self.d)
|
||||
return q, k, v
|
||||
|
||||
|
||||
|
@ -163,10 +163,10 @@ class LayerNormMLP(nn.Module):
|
|||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
model_parallel = constants.model_parallel_world_size() > 1
|
||||
tensor_parallel = constants.tensor_parallel_world_size() > 1
|
||||
sequence_parallel = constants.sequence_parallel()
|
||||
gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
|
||||
if not model_parallel and (sequence_parallel or gradient_accumulation_fusion):
|
||||
if not tensor_parallel and (sequence_parallel or gradient_accumulation_fusion):
|
||||
global SEQUENCE_PARALLEL_WARNED
|
||||
if not SEQUENCE_PARALLEL_WARNED:
|
||||
logger.warning(
|
||||
|
@ -228,12 +228,12 @@ class LlamaLayerNormMLP(nn.Module):
|
|||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_parallel = constants.model_parallel_world_size() > 1
|
||||
self.tensor_parallel = constants.tensor_parallel_world_size() > 1
|
||||
gradient_accumulation_fusion = constants.gradient_accumulation_fusion()
|
||||
self.is_expert = is_expert
|
||||
# when used as experts the MLP always compute without sequence parallel
|
||||
sequence_parallel = constants.sequence_parallel() and not is_expert
|
||||
if not self.model_parallel and (
|
||||
if not self.tensor_parallel and (
|
||||
sequence_parallel or gradient_accumulation_fusion
|
||||
):
|
||||
global SEQUENCE_PARALLEL_WARNED
|
||||
|
@ -418,13 +418,13 @@ if constants.use_te_impl():
|
|||
eps=layer_norm_epsilon,
|
||||
sequence_parallel=constants.sequence_parallel(),
|
||||
return_bias=False,
|
||||
tp_group=constants.model_parallel_group(),
|
||||
tp_size=constants.model_parallel_world_size(),
|
||||
tp_group=constants.tensor_parallel_group(),
|
||||
tp_size=constants.tensor_parallel_world_size(),
|
||||
bias=False,
|
||||
normalization="RMSNorm",
|
||||
activation="swiglu",
|
||||
fuse_wgrad_accumulation=constants.gradient_accumulation_fusion(),
|
||||
params_dtype=dtype,
|
||||
set_parallel_mode=constants.model_parallel_world_size() > 1,
|
||||
set_parallel_mode=constants.tensor_parallel_world_size() > 1,
|
||||
device=device,
|
||||
)
|
||||
|
|
|
@ -10,11 +10,11 @@ from torch.nn.parameter import Parameter
|
|||
import realhf.base.constants as constants
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.impl.model.modules.mlp import LlamaLayerNormMLP, get_activation_fn
|
||||
from realhf.impl.model.parallelism.model_parallel.mappings import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
reduce_from_tensor_model_parallel_region,
|
||||
)
|
||||
from realhf.impl.model.parallelism.model_parallel.utils import divide
|
||||
from realhf.impl.model.parallelism.tensor_parallel.utils import divide
|
||||
from realhf.impl.model.utils.random import _initialize_affine_weight_gpu
|
||||
|
||||
try:
|
||||
|
@ -125,7 +125,7 @@ class GroupedMLP(torch.nn.Module):
|
|||
self.activation_func = get_activation_fn(self.config.activation_function)
|
||||
|
||||
# How many feature each rank holds for fc1 and fc2, respectively.
|
||||
tp_size = constants.model_parallel_world_size()
|
||||
tp_size = constants.tensor_parallel_world_size()
|
||||
intermediate_dim_per_partition = divide(self.config.intermediate_dim, tp_size)
|
||||
|
||||
# Note: The current kernel implementations of grouped_gemm
|
||||
|
@ -186,7 +186,7 @@ class GroupedMLP(torch.nn.Module):
|
|||
):
|
||||
tokens_per_expert = tokens_per_expert.cpu()
|
||||
if permuted_local_hidden_states.nelement() != 0:
|
||||
if constants.model_parallel_world_size() > 1:
|
||||
if constants.tensor_parallel_world_size() > 1:
|
||||
permuted_local_hidden_states = copy_to_tensor_model_parallel_region(
|
||||
permuted_local_hidden_states
|
||||
)
|
||||
|
@ -208,7 +208,7 @@ class GroupedMLP(torch.nn.Module):
|
|||
output = grouped_gemm.ops.gmm(
|
||||
inter, self.grouped_down_proj, tokens_per_expert, trans_b=False
|
||||
)
|
||||
if constants.model_parallel_world_size() > 1:
|
||||
if constants.tensor_parallel_world_size() > 1:
|
||||
output = reduce_from_tensor_model_parallel_region(output)
|
||||
else:
|
||||
# No token is allocated for local experts.
|
||||
|
|
|
@ -8,7 +8,7 @@ import torch.nn.init as init
|
|||
|
||||
import realhf.base.constants as constants
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.impl.model.parallelism.model_parallel.mappings import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
|
||||
gather_from_sequence_parallel_region,
|
||||
)
|
||||
from realhf.impl.model.utils.moe import (
|
||||
|
@ -117,10 +117,10 @@ class TopKRouter(torch.nn.Module):
|
|||
torch.Tensor: The activation tensor with the attached gradient function.
|
||||
"""
|
||||
moe_aux_loss_coeff = self.config.moe.aux_loss_coeff
|
||||
moe_aux_loss_coeff /= constants.model_parallel_world_size()
|
||||
moe_aux_loss_coeff /= constants.tensor_parallel_world_size()
|
||||
scale_for_logging = 1.0
|
||||
if constants.sequence_parallel():
|
||||
scale_for_logging *= constants.model_parallel_world_size()
|
||||
scale_for_logging *= constants.tensor_parallel_world_size()
|
||||
|
||||
aux_loss = switch_load_balancing_loss_func(
|
||||
probs,
|
||||
|
@ -128,7 +128,7 @@ class TopKRouter(torch.nn.Module):
|
|||
self.config.moe.top_k,
|
||||
moe_aux_loss_coeff,
|
||||
sequence_partition_group=(
|
||||
constants.model_parallel_group()
|
||||
constants.tensor_parallel_group()
|
||||
if constants.sequence_parallel()
|
||||
else None
|
||||
),
|
||||
|
@ -155,7 +155,7 @@ class TopKRouter(torch.nn.Module):
|
|||
"""
|
||||
if self.config.moe.z_loss_coeff > 0:
|
||||
moe_z_loss_coeff = (
|
||||
self.config.moe.z_loss_coeff / constants.model_parallel_world_size()
|
||||
self.config.moe.z_loss_coeff / constants.tensor_parallel_world_size()
|
||||
)
|
||||
z_loss = z_loss_func(logits, moe_z_loss_coeff)
|
||||
logits = MoEAuxLossAutoScaler.apply(logits, z_loss)
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
|
||||
import realhf.base.constants as constants
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.impl.model.parallelism.model_parallel.mappings import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
|
||||
gather_from_sequence_parallel_region,
|
||||
scatter_to_sequence_parallel_region,
|
||||
)
|
||||
|
|
|
@ -23,8 +23,8 @@ from .real_llm_base import ReaLModelParamKeys
|
|||
from .real_llm_parallel import (
|
||||
get_real_model_param_shape,
|
||||
intervals_partition_fn,
|
||||
mp_partition_key,
|
||||
shape_partition_fn,
|
||||
tp_partition_key,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -188,7 +188,7 @@ def set_intervals(
|
|||
|
||||
def param_size_from_keys(
|
||||
config: model_api.ReaLModelConfig,
|
||||
src_mp_size: int,
|
||||
src_tp_size: int,
|
||||
sd_keys: List[str],
|
||||
src2dst_tp_size: int,
|
||||
src2dst_tp_rank: int,
|
||||
|
@ -202,9 +202,9 @@ def param_size_from_keys(
|
|||
and "0.wte.weight" in sd_keys
|
||||
):
|
||||
continue
|
||||
new_shape = mp_partition_key(
|
||||
new_shape = tp_partition_key(
|
||||
k,
|
||||
get_real_model_param_shape(k, config, src_mp_size),
|
||||
get_real_model_param_shape(k, config, src_tp_size),
|
||||
src2dst_tp_rank,
|
||||
src2dst_tp_size,
|
||||
config,
|
||||
|
@ -218,7 +218,7 @@ def build_param_spec(
|
|||
layer_indices: List[int],
|
||||
config: model_api.ReaLModelConfig,
|
||||
dp_size: int,
|
||||
mp_size: int,
|
||||
tp_size: int,
|
||||
pp_size: int,
|
||||
head_param_point_to_embedding: bool,
|
||||
bucket_size: int = 40000000,
|
||||
|
@ -273,7 +273,7 @@ def build_param_spec(
|
|||
if head_param_point_to_embedding and k == f"{config.n_layers + 1}.weight":
|
||||
continue
|
||||
|
||||
shape = get_real_model_param_shape(k, config, mp_size)
|
||||
shape = get_real_model_param_shape(k, config, tp_size)
|
||||
numel = int(np.prod(shape))
|
||||
data_end_index = data_start_index + numel
|
||||
|
||||
|
@ -307,14 +307,14 @@ def param_intervals_from_keys(
|
|||
config: model_api.ReaLModelConfig,
|
||||
head_param_point_to_embedding: bool,
|
||||
param_spec: Dict[str, ContiguousParamSpec],
|
||||
mp_size: int,
|
||||
tp_size: int,
|
||||
sd_keys: List[str],
|
||||
portion_size: int,
|
||||
portion_rank: int,
|
||||
) -> List[int]:
|
||||
param_size = param_size_from_keys(
|
||||
config=config,
|
||||
src_mp_size=mp_size,
|
||||
src_tp_size=tp_size,
|
||||
sd_keys=sd_keys,
|
||||
src2dst_tp_size=portion_size,
|
||||
src2dst_tp_rank=portion_rank,
|
||||
|
@ -333,13 +333,13 @@ def param_intervals_from_keys(
|
|||
if (
|
||||
model_name,
|
||||
k.split(".", 1)[1],
|
||||
mp_size,
|
||||
tp_size,
|
||||
portion_rank,
|
||||
portion_size,
|
||||
) not in _FLAT_PARAM_INDICES_CACHE:
|
||||
zero_start_intervals = mp_partition_key(
|
||||
zero_start_intervals = tp_partition_key(
|
||||
k,
|
||||
get_real_model_param_shape(k, config, mp_size),
|
||||
get_real_model_param_shape(k, config, tp_size),
|
||||
portion_rank,
|
||||
portion_size,
|
||||
config,
|
||||
|
@ -349,7 +349,7 @@ def param_intervals_from_keys(
|
|||
(
|
||||
model_name,
|
||||
k.split(".", 1)[1],
|
||||
mp_size,
|
||||
tp_size,
|
||||
portion_rank,
|
||||
portion_size,
|
||||
)
|
||||
|
@ -359,7 +359,7 @@ def param_intervals_from_keys(
|
|||
(
|
||||
model_name,
|
||||
k.split(".", 1)[1],
|
||||
mp_size,
|
||||
tp_size,
|
||||
portion_rank,
|
||||
portion_size,
|
||||
)
|
||||
|
|
|
@ -167,7 +167,7 @@ class ReaLModel(nn.Module):
|
|||
self._param_spec, self._param_size = build_param_spec(
|
||||
list(range(self.layer_idx_start, self.layer_idx_end)),
|
||||
self.config,
|
||||
mp_size=constants.model_parallel_world_size(),
|
||||
tp_size=constants.tensor_parallel_world_size(),
|
||||
pp_size=constants.pipe_parallel_world_size(),
|
||||
dp_size=constants.data_parallel_world_size(),
|
||||
head_param_point_to_embedding=self.head_param_point_to_embedding,
|
||||
|
@ -282,7 +282,7 @@ class ReaLModel(nn.Module):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
elif not config.is_critic and constants.model_parallel_world_size() > 1:
|
||||
elif not config.is_critic and constants.tensor_parallel_world_size() > 1:
|
||||
l = ParallelActorHead(
|
||||
config.hidden_dim,
|
||||
config.vocab_size,
|
||||
|
@@ -428,14 +428,14 @@ class ReaLModel(nn.Module):
         x.cu_seqlens = x.cu_seqlens.int()

         # Copy input tensor to a pinned buffer.
-        mp_size = constants.model_parallel_world_size()
+        tp_size = constants.tensor_parallel_world_size()
         batch_length = None
         if ys[0].packed_input_ids is not None:
             batch_length = ys[0].packed_input_ids.shape[0]
         if x.pp_input is not None:
             batch_length = x.pp_input.shape[0]
         assert batch_length is not None
-        padded_batch_length = (batch_length + mp_size - 1) // mp_size * mp_size
+        padded_batch_length = (batch_length + tp_size - 1) // tp_size * tp_size
         pad_size = padded_batch_length - batch_length

         if (
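The padded batch length above is just the batch length rounded up to the next multiple of the tensor-parallel world size, so sequence-parallel scatter/gather can split the token dimension evenly. A small standalone check of the arithmetic:

    def pad_to_multiple(batch_length: int, tp_size: int) -> int:
        # Smallest multiple of tp_size that is >= batch_length.
        return (batch_length + tp_size - 1) // tp_size * tp_size

    assert pad_to_multiple(10, 4) == 12  # pad_size = 2
    assert pad_to_multiple(12, 4) == 12  # already aligned, pad_size = 0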
@ -609,7 +609,7 @@ class ReaLModel(nn.Module):
|
|||
to_param_spec, to_param_size = build_param_spec(
|
||||
to_layer_indices,
|
||||
to_model_config,
|
||||
mp_size=to_topo.get_dim("model"),
|
||||
tp_size=to_topo.get_dim("tensor"),
|
||||
dp_size=to_topo.get_dim("data"),
|
||||
pp_size=to_topo.get_dim("pipe"),
|
||||
head_param_point_to_embedding=to_model_head_param_point_to_embedding,
|
||||
|
|
|
@ -17,7 +17,7 @@ import transformers
|
|||
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.logging as logging
|
||||
import realhf.impl.model.parallelism.model_parallel.mappings as tensor_parallel
|
||||
import realhf.impl.model.parallelism.tensor_parallel.mappings as tensor_parallel
|
||||
from realhf.api.core import model_api
|
||||
from realhf.impl.model.modules import (
|
||||
CausalSelfAttentionLayer,
|
||||
|
@ -28,9 +28,8 @@ from realhf.impl.model.modules import (
|
|||
LlamaRMSNorm,
|
||||
OffsetParallelPositionalEmbedding,
|
||||
OffsetPositionalEmbedding,
|
||||
scatter_to_sequence_parallel_region,
|
||||
)
|
||||
from realhf.impl.model.parallelism.model_parallel.modules import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.modules import (
|
||||
ColumnParallelLinear,
|
||||
ParallelEmbedding,
|
||||
gather_from_sequence_parallel_region,
|
||||
|
@ -139,6 +138,7 @@ class ReaLModelBlock(nn.Module):
|
|||
use_attention_bias=config.use_attention_bias,
|
||||
use_attn_proj_bias=config.use_attn_proj_bias,
|
||||
do_layernorm_before=config.do_layernorm_before,
|
||||
qk_layernorm=config.qk_layernorm,
|
||||
apply_rotary=config.apply_rotary,
|
||||
rotary_base=config.rotary_base,
|
||||
rotary_interleaved=config.rotary_interleaved,
|
||||
|
@ -281,8 +281,8 @@ class VocabPositionEmbedding(nn.Module):
|
|||
self.n_positions = config.n_positions
|
||||
self.hidden_dim = config.hidden_dim
|
||||
|
||||
model_parallel = constants.model_parallel_world_size() > 1
|
||||
if model_parallel:
|
||||
tensor_parallel = constants.tensor_parallel_world_size() > 1
|
||||
if tensor_parallel:
|
||||
embed_cls = ParallelEmbedding
|
||||
else:
|
||||
embed_cls = nn.Embedding
|
||||
|
@ -295,7 +295,7 @@ class VocabPositionEmbedding(nn.Module):
|
|||
if self.apply_abs_pos_embed:
|
||||
p_embed_cls = (
|
||||
OffsetParallelPositionalEmbedding
|
||||
if model_parallel
|
||||
if tensor_parallel
|
||||
else OffsetPositionalEmbedding
|
||||
)
|
||||
self.wpe = p_embed_cls(
|
||||
|
@ -416,7 +416,7 @@ class ParallelActorHead(ColumnParallelLinear):
|
|||
def _forward(self, x: torch.Tensor):
|
||||
weight = self.weight
|
||||
if self._norm_head:
|
||||
from realhf.impl.model.parallelism.model_parallel.mappings import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.mappings import (
|
||||
gather_from_sequence_parallel_region,
|
||||
)
|
||||
|
||||
|
@ -431,7 +431,7 @@ class ParallelActorHead(ColumnParallelLinear):
|
|||
).transpose(1, 0)
|
||||
head_norm = unnormed_head.norm(dim=0, keepdim=True, p=2)
|
||||
normed_head = unnormed_head / (head_norm + 1e-7)
|
||||
weight = scatter_to_sequence_parallel_region(normed_head)
|
||||
weight = tensor_parallel.scatter_to_sequence_parallel_region(normed_head)
|
||||
|
||||
output = parallel_lm_logits(
|
||||
x,
|
||||
|
@ -486,6 +486,12 @@ class ReaLModelParamKeys:
|
|||
keys += [f"{idx + 1}.attn.c_proj.weight"]
|
||||
if config.use_attn_proj_bias:
|
||||
keys += [f"{idx + 1}.attn.c_proj.bias"]
|
||||
if config.qk_layernorm:
|
||||
keys += [f"{idx + 1}.attn.q_ln.weight"]
|
||||
keys += [f"{idx + 1}.attn.k_ln.weight"]
|
||||
if config.layer_norm_type is None:
|
||||
keys += [f"{idx + 1}.attn.q_ln.bias"]
|
||||
keys += [f"{idx + 1}.attn.k_ln.bias"]
|
||||
keys += [f"{idx + 1}.mlp.ln.weight"]
|
||||
if config.layer_norm_type is None:
|
||||
keys += [f"{idx + 1}.mlp.ln.bias"]
|
||||
|
|
|
@@ -55,8 +55,8 @@ def genstep(
         unfinished_sequences: Bool tensor indicator of whether a sequence is finished.
             Shape [bs].
     """
-    if constants.model_parallel_world_size() > 1:
-        from realhf.impl.model.parallelism.model_parallel.mappings import (
+    if constants.tensor_parallel_world_size() > 1:
+        from realhf.impl.model.parallelism.tensor_parallel.mappings import (
             gather_from_tensor_model_parallel_region,
         )

@@ -95,20 +95,20 @@ def genstep(
     next_tokens = distrb.mode if gconfig.greedy else distrb.sample()
     logprob = distrb.log_prob(next_tokens)

-    if constants.model_parallel_world_size() > 1:
-        if constants.model_parallel_rank() > 0:
+    if constants.tensor_parallel_world_size() > 1:
+        if constants.tensor_parallel_rank() > 0:
             logprob[:] = 0
             next_tokens[:] = 0
         handle = torch.distributed.all_reduce(
             logprob,
             torch.distributed.ReduceOp.SUM,
             async_op=True,
-            group=constants.model_parallel_group(),
+            group=constants.tensor_parallel_group(),
         )
         torch.distributed.all_reduce(
             next_tokens,
             torch.distributed.ReduceOp.SUM,
-            group=constants.model_parallel_group(),
+            group=constants.tensor_parallel_group(),
         )

     if tokenizer.eos_token_id is not None:
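During generation only tensor-parallel rank 0 samples meaningful tokens; the other ranks zero their buffers and a SUM all-reduce then acts as a broadcast from rank 0. A minimal sketch of that trick, assuming an already-initialized process group (the group handle here is hypothetical):

    import torch
    import torch.distributed as dist

    def broadcast_from_rank0_via_sum(t: torch.Tensor, rank: int, group=None) -> torch.Tensor:
        # Non-zero ranks contribute zeros, so the SUM all-reduce leaves every rank
        # holding rank 0's values -- equivalent to a broadcast, but easy to overlap
        # with other work via async_op=True.
        if rank > 0:
            t.zero_()
        dist.all_reduce(t, op=dist.ReduceOp.SUM, group=group)
        return t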
@@ -139,7 +139,7 @@ def genstep(
         if not logits_mask.any():
             logits_mask = None

-    if constants.model_parallel_world_size() > 1:
+    if constants.tensor_parallel_world_size() > 1:
         handle.wait()

     return next_tokens, logprob, logits_mask, terminate, unfinished_sequences
@ -42,27 +42,27 @@ if constants.use_te_impl():
|
|||
|
||||
def tensor_slice_partition_fn(
|
||||
tensor: torch.Tensor,
|
||||
mp_rank: Optional[int],
|
||||
mp_world_size: int,
|
||||
tp_rank: Optional[int],
|
||||
tp_world_size: int,
|
||||
dim: Optional[int],
|
||||
) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
"""Partition a tensor by slicing along a dimension for tensor-model
|
||||
parallelism."""
|
||||
if dim is None:
|
||||
splits = [tensor for _ in range(mp_world_size)]
|
||||
splits = [tensor for _ in range(tp_world_size)]
|
||||
else:
|
||||
assert tensor.shape[dim] % mp_world_size == 0
|
||||
splits = torch.split(tensor, tensor.shape[dim] // mp_world_size, dim=dim)
|
||||
if mp_rank is None:
|
||||
assert tensor.shape[dim] % tp_world_size == 0
|
||||
splits = torch.split(tensor, tensor.shape[dim] // tp_world_size, dim=dim)
|
||||
if tp_rank is None:
|
||||
return [s.contiguous() for s in splits]
|
||||
else:
|
||||
return splits[mp_rank].contiguous()
|
||||
return splits[tp_rank].contiguous()
|
||||
|
||||
|
||||
def intervals_partition_fn(
|
||||
shape: torch.Size,
|
||||
mp_rank: Optional[int],
|
||||
mp_world_size: int,
|
||||
tp_rank: Optional[int],
|
||||
tp_world_size: int,
|
||||
dim: Optional[int],
|
||||
) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
"""Get the intervals of a MP-partitioned tensor in the flatten view.
|
||||
|
@ -74,34 +74,34 @@ def intervals_partition_fn(
|
|||
Used by parameter reallocation. Return a numpy array of shape [N,
|
||||
2], where N is the number of intervals.
|
||||
"""
|
||||
assert mp_rank is not None
|
||||
assert tp_rank is not None
|
||||
param_size = int(np.prod(shape))
|
||||
if dim is None:
|
||||
return np.array([(0, param_size)], dtype=np.int64)
|
||||
|
||||
if dim < 0:
|
||||
dim = len(shape) + dim
|
||||
assert shape[dim] % mp_world_size == 0
|
||||
assert shape[dim] % tp_world_size == 0
|
||||
|
||||
if len(shape) == 1:
|
||||
assert dim == 0
|
||||
partition_size = shape[0] // mp_world_size
|
||||
partition_size = shape[0] // tp_world_size
|
||||
return np.array(
|
||||
[(partition_size * mp_rank, partition_size * (mp_rank + 1))],
|
||||
[(partition_size * tp_rank, partition_size * (tp_rank + 1))],
|
||||
dtype=np.int64,
|
||||
)
|
||||
else:
|
||||
assert len(shape) == 2, shape
|
||||
if dim == 0:
|
||||
row_start = mp_rank * shape[0] // mp_world_size
|
||||
row_end = (mp_rank + 1) * shape[0] // mp_world_size
|
||||
row_start = tp_rank * shape[0] // tp_world_size
|
||||
row_end = (tp_rank + 1) * shape[0] // tp_world_size
|
||||
return np.array(
|
||||
[(row_start * shape[1], row_end * shape[1])], dtype=np.int64
|
||||
)
|
||||
else:
|
||||
assert dim == 1
|
||||
col_start = mp_rank * shape[1] // mp_world_size
|
||||
col_end = (mp_rank + 1) * shape[1] // mp_world_size
|
||||
col_start = tp_rank * shape[1] // tp_world_size
|
||||
col_end = (tp_rank + 1) * shape[1] // tp_world_size
|
||||
return np.arange(shape[0], dtype=np.int64)[:, None] * shape[1] + np.array(
|
||||
[(col_start, col_end)], dtype=np.int64
|
||||
)
|
||||
|
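intervals_partition_fn above maps a tensor-parallel shard to index intervals in the flattened parameter: for a 2-D weight split along dim 1, each row contributes one interval, which is what the np.arange expression computes. A small numpy sketch of the same idea (not the repository function itself):

    import numpy as np

    def column_shard_intervals(shape, tp_rank, tp_world_size):
        # Intervals (start, end) in the flattened [rows * cols] view that belong
        # to the tp_rank-th column shard of a 2-D tensor.
        rows, cols = shape
        assert cols % tp_world_size == 0
        col_start = tp_rank * cols // tp_world_size
        col_end = (tp_rank + 1) * cols // tp_world_size
        return np.arange(rows, dtype=np.int64)[:, None] * cols + np.array(
            [(col_start, col_end)], dtype=np.int64
        )

    # Rank 1 of 2 for a (2, 4) weight owns columns 2..4 of each row.
    print(column_shard_intervals((2, 4), tp_rank=1, tp_world_size=2))
    # [[2 4]
    #  [6 8]]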
@ -109,32 +109,32 @@ def intervals_partition_fn(
|
|||
|
||||
def shape_partition_fn(
|
||||
shape: torch.Size,
|
||||
mp_rank: Optional[int],
|
||||
mp_world_size: int,
|
||||
tp_rank: Optional[int],
|
||||
tp_world_size: int,
|
||||
dim: Optional[int],
|
||||
):
|
||||
"""Get the partitioned shape of a tensor for tensor-model parallelism."""
|
||||
if dim is None:
|
||||
splits = [shape for _ in range(mp_world_size)]
|
||||
splits = [shape for _ in range(tp_world_size)]
|
||||
else:
|
||||
if dim < 0:
|
||||
dim = len(shape) + dim
|
||||
assert shape[dim] % mp_world_size == 0
|
||||
assert shape[dim] % tp_world_size == 0
|
||||
splits = [
|
||||
(*shape[:dim], shape[dim] // mp_world_size, *shape[dim + 1 :])
|
||||
for _ in range(mp_world_size)
|
||||
(*shape[:dim], shape[dim] // tp_world_size, *shape[dim + 1 :])
|
||||
for _ in range(tp_world_size)
|
||||
]
|
||||
if mp_rank is None:
|
||||
if tp_rank is None:
|
||||
return [s for s in splits]
|
||||
else:
|
||||
return splits[mp_rank]
|
||||
return splits[tp_rank]
|
||||
|
||||
|
||||
def mp_partition_key(
|
||||
def tp_partition_key(
|
||||
key: str,
|
||||
tensor_or_shape: torch.Tensor | torch.Size,
|
||||
mp_rank: Optional[int],
|
||||
mp_size: Optional[int],
|
||||
tp_rank: Optional[int],
|
||||
tp_size: Optional[int],
|
||||
config: model_api.ReaLModelConfig,
|
||||
partition_fn: Callable[
|
||||
[torch.Tensor, Optional[int], int, Optional[int]],
|
||||
|
@ -149,7 +149,7 @@ def mp_partition_key(
|
|||
|
||||
if any([ek in key for ek in EMBEDDING_KEYS]):
|
||||
assert "weight" in key
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0)
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
|
||||
elif key == f"{config.n_layers + 1}.weight": # output head
|
||||
if (
|
||||
isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape.shape[0] == 1
|
||||
|
@ -157,88 +157,90 @@ def mp_partition_key(
|
|||
not isinstance(tensor_or_shape, torch.Tensor) and tensor_or_shape[0] == 1
|
||||
):
|
||||
assert config.is_critic
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
|
||||
else:
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0)
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
|
||||
elif any([ck in key for ck in COLUMN_LINEAR_KEYS]):
|
||||
if (
|
||||
("k_attn" in key) or ("v_attn" in key)
|
||||
) and config.n_kv_heads % mp_size != 0:
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
|
||||
) and config.n_kv_heads % tp_size != 0:
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
|
||||
# partition both weight and bias
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=0)
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=0)
|
||||
elif any([rk in key for rk in ROW_LINEAR_KEYS]):
|
||||
# only partition weight
|
||||
if "weight" in key:
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=1)
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=1)
|
||||
else:
|
||||
assert "bias" in key, key
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
|
||||
else:
|
||||
return partition_fn(tensor_or_shape, mp_rank, mp_size, dim=None)
|
||||
return partition_fn(tensor_or_shape, tp_rank, tp_size, dim=None)
|
||||
|
||||
|
||||
def mp_partition_real_model_state_dict(
|
||||
def tp_partition_real_model_state_dict(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
config: model_api.ReaLModelConfig,
|
||||
mp_size: int,
|
||||
mp_rank: Optional[int] = None,
|
||||
tp_size: int,
|
||||
tp_rank: Optional[int] = None,
|
||||
) -> Union[Dict, List[Dict]]:
|
||||
"""A helper function to partition a state dict using `mp_partition_key`."""
|
||||
if mp_size == 1:
|
||||
if mp_rank is None:
|
||||
"""A helper function to partition a state dict using `tp_partition_key`."""
|
||||
if tp_size == 1:
|
||||
if tp_rank is None:
|
||||
return [state_dict]
|
||||
else:
|
||||
return state_dict
|
||||
|
||||
new_state_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
new_state_dict[k] = mp_partition_key(k, v, mp_rank, mp_size, config)
|
||||
new_state_dict[k] = tp_partition_key(k, v, tp_rank, tp_size, config)
|
||||
|
||||
if mp_rank is None:
|
||||
if tp_rank is None:
|
||||
return [
|
||||
{k: v[mp_rank] for k, v in new_state_dict.items()}
|
||||
for mp_rank in range(mp_size)
|
||||
{k: v[tp_rank] for k, v in new_state_dict.items()}
|
||||
for tp_rank in range(tp_size)
|
||||
]
|
||||
else:
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def get_real_model_param_shape(
|
||||
k: str, config: model_api.ReaLModelConfig, mp_size: int
|
||||
k: str, config: model_api.ReaLModelConfig, tp_size: int
|
||||
) -> Tuple:
|
||||
if "wte.weight" in k:
|
||||
assert config.vocab_size % mp_size == 0
|
||||
return (config.vocab_size // mp_size, config.hidden_dim)
|
||||
assert config.vocab_size % tp_size == 0
|
||||
return (config.vocab_size // tp_size, config.hidden_dim)
|
||||
elif "wpe.weight" in k:
|
||||
assert config.n_positions % mp_size == 0
|
||||
if (config.n_positions + config.abs_position_embedding_offset) % mp_size != 0:
|
||||
assert config.n_positions % tp_size == 0
|
||||
if (config.n_positions + config.abs_position_embedding_offset) % tp_size != 0:
|
||||
raise ValueError(
|
||||
f"The dimenstion of position embedding "
|
||||
f"({config.n_positions} + offset {config.abs_position_embedding_offset}) "
|
||||
f"is not divisible by mp_size ({mp_size}). "
|
||||
f"is not divisible by tp_size ({tp_size}). "
|
||||
"Models like this (e.g. OPT-350m) inherently do not support tensor parallelism."
|
||||
)
|
||||
return (
|
||||
(config.n_positions + config.abs_position_embedding_offset) // mp_size,
|
||||
(config.n_positions + config.abs_position_embedding_offset) // tp_size,
|
||||
config.hidden_dim,
|
||||
)
|
||||
elif ".ln." in k or ".ln_f." in k:
|
||||
return (config.hidden_dim,)
|
||||
elif ".q_ln." in k or ".k_ln." in k:
|
||||
return (config.head_dim,)
|
||||
elif k == f"{config.n_layers + 1}.weight": # output head
|
||||
if config.is_critic:
|
||||
return (1, config.hidden_dim)
|
||||
elif mp_size > 1:
|
||||
assert config.vocab_size % mp_size == 0
|
||||
return (config.vocab_size // mp_size, config.hidden_dim)
|
||||
elif tp_size > 1:
|
||||
assert config.vocab_size % tp_size == 0
|
||||
return (config.vocab_size // tp_size, config.hidden_dim)
|
||||
else:
|
||||
return (config.vocab_size, config.hidden_dim)
|
||||
elif any([ck in k for ck in COLUMN_LINEAR_KEYS]):
|
||||
if "k_attn" in k or "v_attn" in k:
|
||||
if "weight" in k:
|
||||
if config.n_kv_heads % mp_size == 0:
|
||||
if config.n_kv_heads % tp_size == 0:
|
||||
return (
|
||||
config.head_dim * config.n_kv_heads // mp_size,
|
||||
config.head_dim * config.n_kv_heads // tp_size,
|
||||
config.hidden_dim,
|
||||
)
|
||||
else:
|
||||
|
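tp_partition_real_model_state_dict above applies tp_partition_key to every entry and, when tp_rank is None, returns one shard per rank. A minimal sketch of that fan-out pattern using plain torch.chunk (column-parallel weights only; the key name is hypothetical):

    import torch

    def shard_state_dict(state_dict: dict, tp_size: int) -> list:
        # Split every 2-D weight along dim 0 into tp_size shards and return
        # one state dict per tensor-parallel rank.
        shards = {k: torch.chunk(v, tp_size, dim=0) for k, v in state_dict.items()}
        return [{k: v[r].contiguous() for k, v in shards.items()} for r in range(tp_size)]

    sd = {"0.attn.c_attn.q_attn.weight": torch.randn(8, 4)}
    per_rank = shard_state_dict(sd, tp_size=2)
    assert per_rank[0]["0.attn.c_attn.q_attn.weight"].shape == (4, 4)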
@ -248,27 +250,27 @@ def get_real_model_param_shape(
|
|||
)
|
||||
else:
|
||||
assert "bias" in k
|
||||
if config.n_kv_heads % mp_size == 0:
|
||||
return (config.head_dim * config.n_kv_heads // mp_size,)
|
||||
if config.n_kv_heads % tp_size == 0:
|
||||
return (config.head_dim * config.n_kv_heads // tp_size,)
|
||||
else:
|
||||
return (config.head_dim * config.n_kv_heads,)
|
||||
if "mlp" in k:
|
||||
if "weight" in k:
|
||||
return (config.intermediate_dim // mp_size, config.hidden_dim)
|
||||
return (config.intermediate_dim // tp_size, config.hidden_dim)
|
||||
else:
|
||||
assert "bias" in k
|
||||
return (config.intermediate_dim // mp_size,)
|
||||
return (config.intermediate_dim // tp_size,)
|
||||
if "weight" in k:
|
||||
assert config.n_q_heads % mp_size == 0
|
||||
return (config.n_q_heads * config.head_dim // mp_size, config.hidden_dim)
|
||||
assert config.n_q_heads % tp_size == 0
|
||||
return (config.n_q_heads * config.head_dim // tp_size, config.hidden_dim)
|
||||
else:
|
||||
assert "bias" in k
|
||||
return (config.n_q_heads * config.head_dim // mp_size,)
|
||||
return (config.n_q_heads * config.head_dim // tp_size,)
|
||||
elif any([rk in k for rk in ROW_LINEAR_KEYS]):
|
||||
if "mlp" in k and "weight" in k:
|
||||
return (config.hidden_dim, config.intermediate_dim // mp_size)
|
||||
return (config.hidden_dim, config.intermediate_dim // tp_size)
|
||||
elif "attn" in k and "weight" in k:
|
||||
return (config.hidden_dim, config.n_q_heads * config.head_dim // mp_size)
|
||||
return (config.hidden_dim, config.n_q_heads * config.head_dim // tp_size)
|
||||
elif "bias" in k:
|
||||
return (config.hidden_dim,)
|
||||
else:
|
||||
|
@ -280,7 +282,7 @@ def get_real_model_param_shape(
|
|||
raise NotImplementedError(f"unkown shape of key {k}.")
|
||||
|
||||
|
||||
def mp_merge_key(
|
||||
def tp_merge_key(
|
||||
k: str,
|
||||
tensors: List[torch.Tensor],
|
||||
config: model_api.ReaLModelConfig,
|
||||
|
@ -297,17 +299,17 @@ def mp_merge_key(
|
|||
return tensors[0]
|
||||
|
||||
|
||||
def mp_merge_real_model_state_dict(
|
||||
def tp_merge_real_model_state_dict(
|
||||
state_dicts: List[Dict[str, torch.Tensor]],
|
||||
config: model_api.ReaLModelConfig,
|
||||
) -> Dict:
|
||||
mp_size = len(state_dicts)
|
||||
if mp_size == 1:
|
||||
tp_size = len(state_dicts)
|
||||
if tp_size == 1:
|
||||
return state_dicts[0]
|
||||
|
||||
new_state_dict = {}
|
||||
for k in state_dicts[0].keys():
|
||||
new_state_dict[k] = mp_merge_key(k, [sd[k] for sd in state_dicts], config)
|
||||
new_state_dict[k] = tp_merge_key(k, [sd[k] for sd in state_dicts], config)
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
@ -317,37 +319,37 @@ class ReaLModelParamCount:
|
|||
|
||||
@staticmethod
|
||||
def _derive_count_from_keys(
|
||||
keys: List[str], config: model_api.ReaLModelConfig, mp_size: int
|
||||
keys: List[str], config: model_api.ReaLModelConfig, tp_size: int
|
||||
) -> int:
|
||||
count = 0
|
||||
for k in keys:
|
||||
count += np.prod(get_real_model_param_shape(k, config, mp_size))
|
||||
count += np.prod(get_real_model_param_shape(k, config, tp_size))
|
||||
return int(count)
|
||||
|
||||
@staticmethod
|
||||
def embed(config: model_api.ReaLModelConfig, mp_size: int) -> int:
|
||||
def embed(config: model_api.ReaLModelConfig, tp_size: int) -> int:
|
||||
return ReaLModelParamCount._derive_count_from_keys(
|
||||
ReaLModelParamKeys.embed(config), config, mp_size
|
||||
ReaLModelParamKeys.embed(config), config, tp_size
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def tblock(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int:
|
||||
def tblock(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int:
|
||||
return ReaLModelParamCount._derive_count_from_keys(
|
||||
ReaLModelParamKeys.tblock(config, idx), config, mp_size
|
||||
ReaLModelParamKeys.tblock(config, idx), config, tp_size
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def head(config: model_api.ReaLModelConfig, mp_size: int) -> int:
|
||||
def head(config: model_api.ReaLModelConfig, tp_size: int) -> int:
|
||||
return ReaLModelParamCount._derive_count_from_keys(
|
||||
ReaLModelParamKeys.head(config), config, mp_size
|
||||
ReaLModelParamKeys.head(config), config, tp_size
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def total(config: model_api.ReaLModelConfig, idx: int, mp_size: int) -> int:
|
||||
def total(config: model_api.ReaLModelConfig, idx: int, tp_size: int) -> int:
|
||||
return (
|
||||
config.n_layers * ReaLModelParamCount.tblock(config, idx, mp_size)
|
||||
+ ReaLModelParamCount.head(config, mp_size)
|
||||
+ ReaLModelParamCount.embed(config, mp_size)
|
||||
config.n_layers * ReaLModelParamCount.tblock(config, idx, tp_size)
|
||||
+ ReaLModelParamCount.head(config, tp_size)
|
||||
+ ReaLModelParamCount.embed(config, tp_size)
|
||||
)
|
||||
|
||||
|
||||
|
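ReaLModelParamCount above is bookkeeping over get_real_model_param_shape: each per-rank count is the product of the (already TP-sharded) shapes, and the total is n_layers transformer blocks plus the head plus the embedding. A toy check of the sharded-embedding arithmetic (numbers are illustrative, not from the repository):

    def embed_param_count(vocab_size: int, hidden_dim: int, tp_size: int) -> int:
        # The word embedding is sharded along the vocab dimension, so each
        # tensor-parallel rank holds vocab_size // tp_size rows.
        assert vocab_size % tp_size == 0
        return (vocab_size // tp_size) * hidden_dim

    # 32k vocab, hidden 1024, TP=4 -> each rank stores a quarter of the table.
    assert embed_param_count(32000, 1024, 4) == 8_192_000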
@ -356,7 +358,7 @@ def partition_pipeline_layers(
|
|||
num_stages: int,
|
||||
method: str = "parameters",
|
||||
) -> Dict[int, Tuple[int, int]]:
|
||||
# Ignoring mp_size in param count because tensor parallel equally partitions parameters.
|
||||
# Ignoring tp_size in param count because tensor parallel equally partitions parameters.
|
||||
# It is irrelevant to how we partition pipeline stages.
|
||||
param_counts = (
|
||||
[ReaLModelParamCount.embed(config, 1)]
|
||||
|
|
|
@@ -13,11 +13,11 @@ def _reduce(input_):
     """All-reduce the input tensor across model parallel group."""

     # Bypass the function if we are using only 1 GPU.
-    if constants.model_parallel_world_size() == 1:
+    if constants.tensor_parallel_world_size() == 1:
         return input_

     # All-reduce.
-    torch.distributed.all_reduce(input_, group=constants.model_parallel_group())
+    torch.distributed.all_reduce(input_, group=constants.tensor_parallel_group())
     return input_

@ -25,7 +25,7 @@ def _split_along_last_dim(input_):
|
|||
"""Split the tensor along its last dimension and keep the corresponding
|
||||
slice."""
|
||||
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
@ -34,7 +34,7 @@ def _split_along_last_dim(input_):
|
|||
input_list = split_tensor_along_last_dim(input_, world_size)
|
||||
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
rank = constants.model_parallel_rank()
|
||||
rank = constants.tensor_parallel_rank()
|
||||
output = input_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
@ -44,7 +44,7 @@ def _split_along_first_dim(input_):
|
|||
"""Split the tensor along its first dimension and keep the corresponding
|
||||
slice."""
|
||||
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
@ -55,7 +55,7 @@ def _split_along_first_dim(input_):
|
|||
dim_size % world_size == 0
|
||||
), "First dimension of the tensor should be divisible by tensor parallel size"
|
||||
local_dim_size = dim_size // world_size
|
||||
rank = constants.model_parallel_rank()
|
||||
rank = constants.tensor_parallel_rank()
|
||||
dim_offset = rank * local_dim_size
|
||||
|
||||
output = input_[dim_offset : dim_offset + local_dim_size].contiguous()
|
||||
|
@@ -66,19 +66,19 @@ def _split_along_first_dim(input_):
 def _gather_along_last_dim(input_):
     """Gather tensors and concatinate along the last dimension."""

-    world_size = constants.model_parallel_world_size()
+    world_size = constants.tensor_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Size and dimension.
     last_dim = input_.dim() - 1
-    rank = constants.model_parallel_rank()
+    rank = constants.tensor_parallel_rank()

     tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
     tensor_list[rank] = input_
     torch.distributed.all_gather(
-        tensor_list, input_, group=constants.model_parallel_group()
+        tensor_list, input_, group=constants.tensor_parallel_group()
     )

     # Note: torch.cat already creates a contiguous tensor.
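_gather_along_last_dim collects each rank's shard and concatenates along the last dimension; placing the local shard into tensor_list[rank] before all_gather simply avoids an extra copy back. A minimal sketch of the collective pattern, assuming an initialized process group:

    import torch
    import torch.distributed as dist

    def gather_last_dim(x: torch.Tensor, world_size: int, rank: int, group=None) -> torch.Tensor:
        if world_size == 1:
            return x
        tensor_list = [torch.empty_like(x) for _ in range(world_size)]
        tensor_list[rank] = x  # local shard, already on this rank
        dist.all_gather(tensor_list, x, group=group)
        # torch.cat already returns a contiguous tensor.
        return torch.cat(tensor_list, dim=x.dim() - 1)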
@ -90,7 +90,7 @@ def _gather_along_last_dim(input_):
|
|||
def _gather_along_first_dim(input_):
|
||||
"""Gather tensors and concatinate along the first dimension."""
|
||||
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
@ -102,7 +102,7 @@ def _gather_along_first_dim(input_):
|
|||
dim_size, dtype=input_.dtype, device=constants.current_device()
|
||||
)
|
||||
torch.distributed._all_gather_base(
|
||||
output, input_.contiguous(), group=constants.model_parallel_group()
|
||||
output, input_.contiguous(), group=constants.tensor_parallel_group()
|
||||
)
|
||||
|
||||
return output
|
||||
|
@ -110,7 +110,7 @@ def _gather_along_first_dim(input_):
|
|||
|
||||
def _reduce_scatter_along_first_dim(input_):
|
||||
"""Reduce-scatter the input tensor across model parallel group."""
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
@ -128,7 +128,7 @@ def _reduce_scatter_along_first_dim(input_):
|
|||
dim_size, dtype=input_.dtype, device=constants.current_device()
|
||||
)
|
||||
torch.distributed._reduce_scatter_base(
|
||||
output, input_.contiguous(), group=constants.model_parallel_group()
|
||||
output, input_.contiguous(), group=constants.tensor_parallel_group()
|
||||
)
|
||||
return output
|
||||
|
|
@ -44,7 +44,7 @@ except ImportError:
|
|||
|
||||
import realhf.base.logging as logging
|
||||
|
||||
logger = logging.getLogger("model_parallel.modules")
|
||||
logger = logging.getLogger("tensor_parallel.modules")
|
||||
|
||||
|
||||
def get_activation_fn(activation_function: str) -> Callable:
|
||||
|
@ -95,12 +95,12 @@ class ParallelEmbedding(torch.nn.Module):
|
|||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
self.tensor_model_parallel_size = constants.model_parallel_world_size()
|
||||
self.tensor_model_parallel_size = constants.tensor_parallel_world_size()
|
||||
# Divide the weight matrix along the vocaburaly dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = (
|
||||
VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings,
|
||||
constants.model_parallel_rank(),
|
||||
constants.tensor_parallel_rank(),
|
||||
self.tensor_model_parallel_size,
|
||||
)
|
||||
)
|
||||
|
@ -110,7 +110,7 @@ class ParallelEmbedding(torch.nn.Module):
|
|||
|
||||
logger.debug(
|
||||
f"ParallelEmbedding: num_embeddings={num_embeddings}, per_partition={self.num_embeddings_per_partition}, embedding_dim={embedding_dim},"
|
||||
f"tp_rank={constants.model_parallel_rank()},tp_world_size={constants.model_parallel_world_size()}"
|
||||
f"tp_rank={constants.tensor_parallel_rank()},tp_world_size={constants.tensor_parallel_world_size()}"
|
||||
)
|
||||
# Allocate weights and initialize.
|
||||
self.weight = Parameter(
|
||||
|
@ -264,7 +264,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
|
|||
assert (
|
||||
not ctx.async_grad_allreduce
|
||||
), "async_grad_allreduce and sequence_parallel can not be both True"
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
dim_size = list(input.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
|
@ -272,7 +272,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
|
|||
dim_size, input.dtype, "mpu"
|
||||
)
|
||||
torch.distributed._all_gather_base(
|
||||
all_gather_buffer, input, group=constants.model_parallel_group()
|
||||
all_gather_buffer, input, group=constants.tensor_parallel_group()
|
||||
)
|
||||
total_input = all_gather_buffer
|
||||
else:
|
||||
|
@ -290,7 +290,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
|
|||
use_bias = ctx.use_bias
|
||||
|
||||
if ctx.sequence_parallel:
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
dim_size = list(input.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
|
@ -300,7 +300,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
|
|||
handle = torch.distributed._all_gather_base(
|
||||
all_gather_buffer,
|
||||
input,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
async_op=True,
|
||||
)
|
||||
|
||||
|
@ -327,7 +327,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
|
|||
# Asynchronous all-reduce
|
||||
handle = torch.distributed.all_reduce(
|
||||
grad_input,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
async_op=True,
|
||||
)
|
||||
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
|
||||
|
@ -346,7 +346,7 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
|
|||
handle = torch.distributed._reduce_scatter_base(
|
||||
sub_grad_input,
|
||||
grad_input,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
async_op=True,
|
||||
)
|
||||
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
|
||||
|
@ -525,7 +525,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
|
|||
assert (
|
||||
not ctx.async_grad_allreduce
|
||||
), "async_grad_allreduce and sequence_parallel can not be both True"
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
dim_size = list(input.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
|
@ -533,7 +533,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
|
|||
dim_size, input.dtype, "mpu"
|
||||
)
|
||||
torch.distributed._all_gather_base(
|
||||
all_gather_buffer, input, group=constants.model_parallel_group()
|
||||
all_gather_buffer, input, group=constants.tensor_parallel_group()
|
||||
)
|
||||
total_input = all_gather_buffer
|
||||
else:
|
||||
|
@ -557,7 +557,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
|
|||
is_w_parallel = ctx.is_w_parallel
|
||||
|
||||
if ctx.sequence_parallel:
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
dim_size = list(input.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
|
@ -567,7 +567,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
|
|||
handle = torch.distributed._all_gather_base(
|
||||
all_gather_buffer,
|
||||
input,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
async_op=True,
|
||||
)
|
||||
|
||||
|
@ -578,7 +578,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
|
|||
total_input = input
|
||||
grad_input = 0
|
||||
for w, is_parallel, grad in zip(weights, is_w_parallel, grads):
|
||||
if is_parallel or constants.model_parallel_rank() == 0:
|
||||
if is_parallel or constants.tensor_parallel_rank() == 0:
|
||||
grad_input = grad_input + grad.matmul(w)
|
||||
|
||||
if ctx.sequence_parallel:
|
||||
|
@ -597,7 +597,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
|
|||
# Asynchronous all-reduce
|
||||
handle = torch.distributed.all_reduce(
|
||||
grad_input,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
async_op=True,
|
||||
)
|
||||
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
|
||||
|
@ -616,7 +616,7 @@ class MergedLinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Funct
|
|||
handle = torch.distributed._reduce_scatter_base(
|
||||
sub_grad_input,
|
||||
grad_input,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
async_op=True,
|
||||
)
|
||||
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
|
||||
|
@ -785,7 +785,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.is_expert = is_expert
|
||||
|
@ -852,7 +852,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
|||
# in expert MLPs always behave as sequence parallel is not enabled.
|
||||
sequence_parallel = constants.sequence_parallel() and not self.is_expert
|
||||
async_tensor_model_parallel_allreduce = (
|
||||
constants.model_parallel_world_size() > 1 and not sequence_parallel
|
||||
constants.tensor_parallel_world_size() > 1 and not sequence_parallel
|
||||
)
|
||||
|
||||
if sequence_parallel:
|
||||
|
@ -942,7 +942,7 @@ class RowParallelLinear(torch.nn.Module):
|
|||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = constants.model_parallel_world_size()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.gradient_accumulation_fusion = gradient_accumulation_fusion
|
||||
|
@ -1030,9 +1030,9 @@ def parallel_lm_logits(
|
|||
bias=None,
|
||||
):
|
||||
"""LM logits using word embedding weights."""
|
||||
model_parallel = constants.model_parallel_world_size() > 1
|
||||
tensor_parallel = constants.tensor_parallel_world_size() > 1
|
||||
sequence_parallel = constants.sequence_parallel()
|
||||
async_grad_allreduce = not sequence_parallel and model_parallel
|
||||
async_grad_allreduce = not sequence_parallel and tensor_parallel
|
||||
# Parallel logits.
|
||||
if sequence_parallel:
|
||||
input_parallel = input_
|
||||
|
@ -1066,7 +1066,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
|
|||
torch.distributed.all_reduce(
|
||||
logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
)
|
||||
# Subtract the maximum value.
|
||||
vocab_parallel_logits = vocab_parallel_logits - logits_max.unsqueeze(dim=-1)
|
||||
|
@@ -1074,8 +1074,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Get the partition's vocab indecies
         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
         partition_vocab_size = vocab_parallel_logits.size()[-1]
-        rank = constants.model_parallel_rank()
-        world_size = constants.model_parallel_world_size()
+        rank = constants.tensor_parallel_rank()
+        world_size = constants.tensor_parallel_world_size()
         vocab_start_index, vocab_end_index = get_vocab_range(
             partition_vocab_size, rank, world_size
         )
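The vocab range lookup above determines which slice of the global vocabulary this tensor-parallel rank owns; the parallel cross-entropy then masks targets that fall outside it. A plain-Python sketch of the range computation (VocabUtility itself is Megatron-style code, reproduced here only in spirit):

    def vocab_range(partition_vocab_size: int, rank: int, world_size: int):
        # Each rank owns a contiguous [start, end) slice of the global vocab;
        # world_size is kept only to mirror the original signature.
        start = rank * partition_vocab_size
        end = start + partition_vocab_size
        return start, end

    # Global vocab 32, 4 TP ranks, 8 entries per rank: rank 2 owns [16, 24).
    assert vocab_range(8, rank=2, world_size=4) == (16, 24)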
@ -1101,7 +1101,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
|
|||
torch.distributed.all_reduce(
|
||||
predicted_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
)
|
||||
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
|
@ -1111,7 +1111,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
|
|||
torch.distributed.all_reduce(
|
||||
sum_exp_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=constants.model_parallel_group(),
|
||||
group=constants.tensor_parallel_group(),
|
||||
)
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
|
@ -6,11 +6,7 @@ from typing import List, Sequence
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from realhf.base.constants import (
|
||||
model_parallel_group,
|
||||
model_parallel_rank,
|
||||
model_parallel_world_size,
|
||||
)
|
||||
import realhf.base.constants as constants
|
||||
|
||||
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
|
||||
"tensor_model_parallel": False,
|
||||
|
@ -22,7 +18,7 @@ _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {
|
|||
def param_is_not_model_parallel_duplicate(param):
|
||||
return (
|
||||
hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel
|
||||
) or (model_parallel_rank() == 0)
|
||||
) or (constants.tensor_parallel_rank() == 0)
|
||||
|
||||
|
||||
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
|
||||
|
@ -110,8 +106,8 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
|
|||
If False, returns a view into the existing Tensor.
|
||||
Default is False
|
||||
"""
|
||||
partition_size = torch.numel(tensor) // model_parallel_world_size()
|
||||
start_index = partition_size * model_parallel_rank()
|
||||
partition_size = torch.numel(tensor) // constants.tensor_parallel_world_size()
|
||||
start_index = partition_size * constants.tensor_parallel_rank()
|
||||
end_index = start_index + partition_size
|
||||
if new_buffer:
|
||||
data = torch.empty(
|
||||
|
@ -135,7 +131,7 @@ def gather_split_1d_tensor(tensor):
|
|||
Arguments:
|
||||
tensor: A Tensor or view of this rank's portion of the data.
|
||||
"""
|
||||
numel_gathered = torch.numel(tensor) * model_parallel_world_size()
|
||||
numel_gathered = torch.numel(tensor) * constants.tensor_parallel_world_size()
|
||||
gathered = torch.empty(
|
||||
numel_gathered,
|
||||
dtype=tensor.dtype,
|
||||
|
@ -147,7 +143,9 @@ def gather_split_1d_tensor(tensor):
|
|||
# as opposed to torch.distributed.all_gather for efficiency reasons.
|
||||
# This API calls directly NCCL all-gather versus the former does
|
||||
# internal copies and can potentially cause slow down.
|
||||
torch.distributed._all_gather_base(gathered, tensor, group=model_parallel_group())
|
||||
torch.distributed._all_gather_base(
|
||||
gathered, tensor, group=constants.tensor_parallel_group()
|
||||
)
|
||||
return gathered
|
||||
|
||||
|
|
@ -210,11 +210,11 @@ def gather_packed_shifted_log_probs(
|
|||
"""
|
||||
labels = torch.nn.functional.pad(labels[1:], (0, 1), value=0)
|
||||
leave_one_indices = build_leave_one_indices(logits, cu_seqlens)
|
||||
if constants.model_parallel_world_size() > 1:
|
||||
if constants.tensor_parallel_world_size() > 1:
|
||||
# NOTE: logprobs is freaking sensitive to input_ids. If the input sequence is a natural sequence, everything will be fine.
|
||||
# However, if we input random token IDs, parallel cross entropy can produce VERY different results than the normal
|
||||
# torch.gather based version (e.g., the maximum absolute different can reach ~50).
|
||||
from realhf.impl.model.parallelism.model_parallel.modules import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.modules import (
|
||||
vocab_parallel_cross_entropy,
|
||||
)
|
||||
|
||||
|
@ -239,14 +239,16 @@ def gather_packed_shifted_log_probs(
|
|||
|
||||
|
||||
def apply_logits_mask(logits: torch.HalfTensor, mask: torch.BoolTensor):
|
||||
assert mask.shape[-1] == logits.shape[-1] * constants.model_parallel_world_size(), (
|
||||
constants.model_parallel_world_size(),
|
||||
assert (
|
||||
mask.shape[-1] == logits.shape[-1] * constants.tensor_parallel_world_size()
|
||||
), (
|
||||
constants.tensor_parallel_world_size(),
|
||||
logits.shape,
|
||||
mask.shape,
|
||||
)
|
||||
parallel_vocab_size = logits.shape[-1]
|
||||
mp_rank = constants.model_parallel_rank()
|
||||
mask = mask[:, mp_rank * parallel_vocab_size : (mp_rank + 1) * parallel_vocab_size]
|
||||
tp_rank = constants.tensor_parallel_rank()
|
||||
mask = mask[:, tp_rank * parallel_vocab_size : (tp_rank + 1) * parallel_vocab_size]
|
||||
logits.masked_fill_(mask, torch.finfo(logits.dtype).min)
|
||||
|
||||
|
||||
|
|
|
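apply_logits_mask receives a mask over the full vocabulary while each rank's logits only cover its own shard, so the mask is sliced by tensor-parallel rank before masked_fill_. A short sketch of the slicing, independent of the constants module:

    import torch

    def apply_sharded_logits_mask(logits: torch.Tensor, full_mask: torch.BoolTensor, tp_rank: int) -> None:
        shard = logits.shape[-1]                      # per-rank vocab size
        local = full_mask[:, tp_rank * shard : (tp_rank + 1) * shard]
        logits.masked_fill_(local, torch.finfo(logits.dtype).min)

    logits = torch.zeros(2, 4)                        # this rank owns 4 vocab entries
    mask = torch.zeros(2, 8, dtype=torch.bool)        # global vocab of 8
    mask[:, 5] = True                                 # forbid global token id 5
    apply_sharded_logits_mask(logits, mask, tp_rank=1)
    assert logits[0, 1] == torch.finfo(torch.float32).min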
@@ -250,7 +250,7 @@ def pad_sequence_parallel_input(
 ):
     """Sequence parallel requires packed_input_ids has a shape of 1 dimension
     [total_seq_len], and total_seq_len should be divisible by
-    model_parallel_world_size. This function is used to pad packed_input_ids to
+    tensor_parallel_world_size. This function is used to pad packed_input_ids to
     suitable length with an empty sequence, and return new packed_input_ids,
     cu_seqlens and max_seqlen.

@@ -262,10 +262,10 @@ def pad_sequence_parallel_input(
     Returns:
         (torch.Tensor, torch.Tensor, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size)
     """
-    mp_world_size = constants.model_parallel_world_size()
+    tp_world_size = constants.tensor_parallel_world_size()
     pad_size = 0
-    if len(packed_input_ids) % mp_world_size != 0:
-        pad_size = mp_world_size - len(packed_input_ids) % mp_world_size
+    if len(packed_input_ids) % tp_world_size != 0:
+        pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
         packed_input_ids = torch.nn.functional.pad(
             packed_input_ids, (0, pad_size), value=1
         )
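pad_sequence_parallel_input pads the flat token buffer so its length is a multiple of the tensor-parallel world size, appending dummy pad tokens (value 1 here). A standalone sketch of just the padding step:

    import torch

    def pad_packed_ids(packed_input_ids: torch.Tensor, tp_world_size: int):
        pad_size = 0
        if len(packed_input_ids) % tp_world_size != 0:
            pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
            packed_input_ids = torch.nn.functional.pad(
                packed_input_ids, (0, pad_size), value=1
            )
        return packed_input_ids, pad_size

    ids, pad = pad_packed_ids(torch.arange(10), tp_world_size=4)
    assert len(ids) == 12 and pad == 2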
@ -281,9 +281,9 @@ def pad_sequence_parallel_generate_input(
|
|||
):
|
||||
"""Only for pipeline generate input when model+seq parallel is enabled. To
|
||||
make sure inputs for seq parallel model have a shape with first dimension
|
||||
divisible by model_parallel_world_size, the packed_input_ids should have
|
||||
length divisible by model_parallel_world_size, and contains number of
|
||||
sequences divisible by model_parallel_world_size.
|
||||
divisible by tensor_parallel_world_size, the packed_input_ids should have
|
||||
length divisible by tensor_parallel_world_size, and contains number of
|
||||
sequences divisible by tensor_parallel_world_size.
|
||||
|
||||
Args:
|
||||
packed_input_ids (torch.Tensor): unpadded packed_input_ids
|
||||
|
@ -293,16 +293,16 @@ def pad_sequence_parallel_generate_input(
|
|||
Returns:
|
||||
(torch.Tensor, torch.Tensor, int, int, int): padded (packed_input_ids, cu_seqlens, max_seqlen, pad_size, pad_seq_size)
|
||||
"""
|
||||
mp_world_size = constants.model_parallel_world_size()
|
||||
tp_world_size = constants.tensor_parallel_world_size()
|
||||
pad_size, pad_seq_size = 0, 0
|
||||
if (
|
||||
len(packed_input_ids) % mp_world_size != 0
|
||||
or (len(cu_seqlens) - 1) % mp_world_size != 0
|
||||
len(packed_input_ids) % tp_world_size != 0
|
||||
or (len(cu_seqlens) - 1) % tp_world_size != 0
|
||||
):
|
||||
pad_size = mp_world_size - len(packed_input_ids) % mp_world_size
|
||||
pad_seq_size = mp_world_size - (len(cu_seqlens) - 1) % mp_world_size
|
||||
pad_size = tp_world_size - len(packed_input_ids) % tp_world_size
|
||||
pad_seq_size = tp_world_size - (len(cu_seqlens) - 1) % tp_world_size
|
||||
if pad_size < pad_seq_size:
|
||||
pad_size += mp_world_size
|
||||
pad_size += tp_world_size
|
||||
pad_cu_seqlens = torch.tensor(list(range(1, pad_seq_size)) + [pad_size]) + len(
|
||||
packed_input_ids
|
||||
)
|
||||
|
|
|
@@ -55,6 +55,7 @@ def actor_loss_fn(
     eps_clip: float,
     loss_mask: Optional[torch.BoolTensor] = None,
     c_clip: Optional[float] = None,
+    proximal_logprobs: Optional[torch.FloatTensor] = None,
 ) -> Tuple[torch.Tensor, Dict]:
     """Compute PPO actor loss function.

@@ -83,13 +84,22 @@ def actor_loss_fn(
         old_logprobs = old_logprobs.clone()
     if advantages.is_inference():
         advantages = advantages.clone()
+    if proximal_logprobs is not None:
+        assert proximal_logprobs.dtype == torch.float32
+        if proximal_logprobs.is_inference():
+            proximal_logprobs = proximal_logprobs.clone()
+        denorm_logprobs = proximal_logprobs
+    else:
+        denorm_logprobs = old_logprobs
+
     # create mask
-    if loss_mask is not None:
-        loss_mask_count = loss_mask.count_nonzero() or 1
-        # For numerical stability.
-        ratio = torch.where(loss_mask, torch.exp(logprobs - old_logprobs), 0)
-    else:
-        ratio = torch.exp(logprobs - old_logprobs)
+    if loss_mask is None:
+        loss_mask = torch.ones_like(logprobs, dtype=torch.bool)
+    loss_mask: torch.BoolTensor
+
+    loss_mask_count = loss_mask.count_nonzero() or 1
+    # For numerical stability.
+    ratio = torch.where(loss_mask, torch.exp(logprobs - denorm_logprobs), 0)

     clipped_ratio = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
     pg_loss1 = -advantages * ratio
@@ -104,24 +114,34 @@ def actor_loss_fn(
         pg_loss = torch.min(pg_loss, pg_loss3)
     else:
         dual_clip_mask = torch.zeros_like(clip_mask)
+    if proximal_logprobs is not None:
+        behav_kl = proximal_logprobs - old_logprobs
+        behav_imp_weight = behav_kl.exp()
+        if c_clip is not None:
+            behav_mask = (behav_imp_weight <= c_clip).logical_and(loss_mask)
+        else:
+            behav_mask = loss_mask
+        behav_kl = torch.where(behav_mask, behav_kl, 0.0)
+        behav_imp_weight = torch.where(behav_mask, behav_imp_weight, 0.0)
+        pg_loss = pg_loss * behav_imp_weight

     logging_loss = pg_loss.detach()
-    if loss_mask is not None:
-        pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count
-    else:
-        pg_loss = pg_loss.mean()
+    pg_loss = torch.where(loss_mask, pg_loss, 0).sum() / loss_mask_count

-    if loss_mask is not None:
-        clip_mask.logical_and_(loss_mask)
-        dual_clip_mask.logical_and_(loss_mask)
+    clip_mask.logical_and_(loss_mask)
+    dual_clip_mask.logical_and_(loss_mask)
     # Remain torch.CudaTensor here for all-reduce after train step.
     stat = dict(
         loss=logging_loss,
         importance_weight=ratio.detach(),
-        approx_kl=(logprobs - old_logprobs).detach(),
+        approx_kl=(logprobs - denorm_logprobs).detach(),
         clip_mask=clip_mask,
         dual_clip_mask=dual_clip_mask,
     )
+    if proximal_logprobs is not None:
+        stat["behave_imp_weight"] = behav_imp_weight
+        stat["behave_approx_kl"] = behav_kl
+        stat["behave_mask"] = behav_mask

     return pg_loss, stat
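The decoupled PPO loss above splits the usual importance ratio in two: the clipped ratio is taken against the proximal (recomputed) log-probabilities, while a separate behavior importance weight exp(proximal - old) corrects for the stale rollout policy and is optionally clipped at c_clip. A compact sketch of the core computation (tensors only, no masking or logging, not the repository's exact function):

    import torch

    def decoupled_ppo_loss(logprobs, proximal_logprobs, old_logprobs, advantages,
                           eps_clip=0.2, c_clip=None):
        # Clip against the proximal policy (the "denominator" policy).
        ratio = torch.exp(logprobs - proximal_logprobs)
        clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip)
        pg_loss = torch.max(-advantages * ratio, -advantages * clipped)
        # Correct for the stale behavior policy that generated the rollout.
        behav_imp_weight = torch.exp(proximal_logprobs - old_logprobs)
        if c_clip is not None:
            behav_imp_weight = behav_imp_weight * (behav_imp_weight <= c_clip)
        return (pg_loss * behav_imp_weight).mean()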
@ -13,7 +13,7 @@ from torch.cuda import device as device_ctx_manager
|
|||
from torch.utils.checkpoint import detach_variable
|
||||
|
||||
import realhf.base.constants as constants
|
||||
from realhf.impl.model.parallelism.model_parallel.utils import (
|
||||
from realhf.impl.model.parallelism.tensor_parallel.utils import (
|
||||
divide,
|
||||
gather_split_1d_tensor,
|
||||
safely_set_viewless_tensor_data,
|
||||
|
@@ -169,10 +169,10 @@ def model_parallel_cuda_manual_seed(seed):
     tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used for example for dropout in model parallel regions.
     """
     # 2718 is just for fun and any POSITIVE value will work.
-    model_parallel_rank = constants.model_parallel_rank()
+    tensor_parallel_rank = constants.tensor_parallel_rank()
     expert_parallel_rank = 0
     offset = seed + 2718
-    tensor_model_parallel_seed = offset + model_parallel_rank
+    tensor_model_parallel_seed = offset + tensor_parallel_rank
     # Data parallel gets the original seed.
     data_parallel_seed = seed

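The RNG setup above derives one seed per parallel dimension from a single base seed, so dropout differs across tensor-parallel ranks but matches across data-parallel replicas. The derivation itself is plain arithmetic, mirroring the offsets in this file:

    def derive_seeds(seed: int, tensor_parallel_rank: int, expert_parallel_rank: int = 0):
        offset = seed + 2718  # any positive offset works
        tensor_model_parallel_seed = offset + tensor_parallel_rank
        data_parallel_seed = seed  # data parallel keeps the original seed
        expert_parallel_seed = seed + 1024 + 100 * expert_parallel_rank + tensor_parallel_rank
        return tensor_model_parallel_seed, data_parallel_seed, expert_parallel_seed

    assert derive_seeds(1, tensor_parallel_rank=3) == (2722, 1, 1028)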
@ -187,7 +187,7 @@ def model_parallel_cuda_manual_seed(seed):
|
|||
)
|
||||
|
||||
expert_parallel_seed = (
|
||||
seed + 1024 + 100 * expert_parallel_rank + model_parallel_rank
|
||||
seed + 1024 + 100 * expert_parallel_rank + tensor_parallel_rank
|
||||
)
|
||||
_CUDA_RNG_STATE_TRACKER.add(_EXPERT_PARALLEL_RNG_TRACKER_NAME, expert_parallel_seed)
|
||||
|
||||
|
@ -331,8 +331,8 @@ def _initialize_affine_weight_cpu(
|
|||
weight_list = torch.split(
|
||||
master_weight, per_partition_per_stride_size, dim=partition_dim
|
||||
)
|
||||
rank = constants.model_parallel_rank()
|
||||
world_size = constants.model_parallel_world_size()
|
||||
rank = constants.tensor_parallel_rank()
|
||||
world_size = constants.tensor_parallel_world_size()
|
||||
my_weight_list = weight_list[rank::world_size]
|
||||
|
||||
with torch.no_grad():
|
||||
|
|
|
@ -5,13 +5,18 @@
|
|||
import fcntl
|
||||
import os
|
||||
import re
|
||||
import select
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
|
||||
import colorama
|
||||
|
||||
import realhf.base.logging as logging
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.base.constants import LOG_ROOT
|
||||
from realhf.base.constants import SLURM_LOCK_FILE_NAME as LOCK_FILE_NAME
|
||||
from realhf.scheduler.client import JobException, JobInfo, JobState, SchedulerClient
|
||||
from realhf.scheduler.evaluator import AutomaticEvaluator
|
||||
|
@ -29,6 +34,49 @@ SCHEDULING_TIMEOUT_MAX_SECONDS = 3600 * 24
|
|||
SCHEDULER_WAIT_CHECK_TIME_INTERVAL = 5
|
||||
|
||||
|
||||
def monitor_log(
|
||||
job_name: str, log_path: str, output_file: str, stop_event: threading.Event
|
||||
):
|
||||
"""Monitor a log file and write its contents to the output file with job name prefix."""
|
||||
# Wait for log file to be created
|
||||
while not os.path.exists(log_path) and not stop_event.is_set():
|
||||
time.sleep(0.1)
|
||||
|
||||
if stop_event.is_set():
|
||||
return
|
||||
|
||||
# Open the log file and follow it
|
||||
with open(log_path, "r") as log_file, open(output_file, "a") as out_file:
|
||||
# Store last position
|
||||
position = 0
|
||||
line_pos = 0
|
||||
|
||||
while not stop_event.is_set():
|
||||
log_file.seek(position)
|
||||
try:
|
||||
new_lines = log_file.readlines()
|
||||
except UnicodeDecodeError:
|
||||
time.sleep(0.5)
|
||||
continue
|
||||
|
||||
if new_lines:
|
||||
# Update position
|
||||
position = log_file.tell()
|
||||
|
||||
worker_type = job_name.split(":")[1]
|
||||
# Write new lines to output file with job name prefix
|
||||
for line in new_lines:
|
||||
if line.strip(): # Skip empty lines
|
||||
out_file.write(
|
||||
f"{colorama.Fore.YELLOW + colorama.Style.DIM}({worker_type} Line {line_pos}){colorama.Style.RESET_ALL} {line}"
|
||||
)
|
||||
line_pos += 1
|
||||
out_file.flush()
|
||||
|
||||
# Sleep briefly to avoid CPU spinning
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
class SlurmSchedulerClient(SchedulerClient):
|
||||
"""Uses Slurm (https://slurm.schedmd.com/overview.html)."""
|
||||
|
||||
|
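monitor_log above tails one worker's Slurm log and appends prefixed lines to a single merged file; the scheduler later runs one such thread per committed job and stops them through a shared Event (see the wait() hunk below). A minimal usage sketch of the same tail-and-merge pattern (file names are hypothetical):

    import threading
    import time

    def tail_into(src: str, dst: str, prefix: str, stop: threading.Event) -> None:
        # Follow src and append new lines to dst with a prefix, until stop is set.
        with open(src, "r") as fin, open(dst, "a") as fout:
            pos = 0
            while not stop.is_set():
                fin.seek(pos)
                for line in fin.readlines():
                    fout.write(f"({prefix}) {line}")
                pos = fin.tell()
                fout.flush()
                time.sleep(0.1)

    stop = threading.Event()
    t = threading.Thread(target=tail_into, args=("worker.log", "main.log", "rollout", stop))
    # t.start(); ...; stop.set(); t.join()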
@ -248,6 +296,26 @@ class SlurmSchedulerClient(SchedulerClient):
|
|||
# before wait, commit all remaining pending jobs
|
||||
# TODO: grab global file lock to avoid multi-experiment deadlocks
|
||||
self.__allocate_and_commit_pending_jobs()
|
||||
# Start monitoring threads
|
||||
threads = []
|
||||
stop_events = []
|
||||
|
||||
merged_log_path = os.path.join(
|
||||
LOG_ROOT, self.expr_name, self.trial_name, "main.log"
|
||||
)
|
||||
|
||||
for job_name, launch_info in self.__committed_jobs.items():
|
||||
stop_event = threading.Event()
|
||||
stop_events.append(stop_event)
|
||||
|
||||
# Thread for monitoring the log file
|
||||
log_thread = threading.Thread(
|
||||
target=monitor_log,
|
||||
args=(job_name, launch_info.log_path, merged_log_path, stop_event),
|
||||
)
|
||||
threads.append(log_thread)
|
||||
log_thread.start()
|
||||
|
||||
# begin wait
|
||||
deadline = None if timeout is None else time.time() + timeout
|
||||
left = set(self.__committed_jobs.keys())
|
||||
|
@ -256,6 +324,11 @@ class SlurmSchedulerClient(SchedulerClient):
|
|||
f"Waiting for {num_jobs_left} jobs. Jobs IDs: "
|
||||
f"{','.join(sorted([x.job_info.slurm_id for x in self.__committed_jobs.values()]))}."
|
||||
)
|
||||
logger.info(
|
||||
f"All slurm logs will be merged. To check the real-time output, "
|
||||
f"run\n\t`tail -f {merged_log_path}`."
|
||||
)
|
||||
try:
|
||||
while len(left) > 0:
|
||||
if len(left) < num_jobs_left:
|
||||
num_jobs_left = len(left)
|
||||
|
@ -294,6 +367,9 @@ class SlurmSchedulerClient(SchedulerClient):
|
|||
if update:
|
||||
self.__committed_jobs.pop(job_slurm_name)
|
||||
time.sleep(SCHEDULER_WAIT_CHECK_TIME_INTERVAL)
|
||||
finally:
|
||||
[s.set() for s in stop_events]
|
||||
[t.join() for t in threads]
|
||||
|
||||
def __update_all(self):
|
||||
states = []
|
||||
|
|
|
@@ -294,21 +294,29 @@ class SlurmLaunchInfo:

     @property
     def multiprog_path(self) -> str:
-        return os.path.join(
+        path = os.path.join(
             LOG_ROOT,
             self.exper_name,
             self.trial_name,
             "slurm",
             "multiprog",
             f"{self.worker_type}-{self.worker_submission_idx}.multiprog",
         )
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        return path

     @property
     def hostfile_path(self) -> str:
-        return os.path.join(
+        path = os.path.join(
             LOG_ROOT,
             self.exper_name,
             self.trial_name,
             "slurm",
             "hostfile",
             f"{self.worker_type}-{self.worker_submission_idx}.hostfile",
         )
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        return path

     def show_log(self):
         try:
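The change above makes multiprog_path and hostfile_path create their parent directory on first access instead of assuming it already exists. The same pattern in isolation (LOG_ROOT and the attribute names below are stand-ins, not the repository's class):

    import os

    class LaunchPaths:
        LOG_ROOT = "/tmp/areal-logs"  # stand-in for the real LOG_ROOT constant

        def __init__(self, expr_name: str, trial_name: str, worker_type: str, idx: int):
            self.expr_name, self.trial_name = expr_name, trial_name
            self.worker_type, self.idx = worker_type, idx

        @property
        def multiprog_path(self) -> str:
            path = os.path.join(self.LOG_ROOT, self.expr_name, self.trial_name,
                                "slurm", "multiprog", f"{self.worker_type}-{self.idx}.multiprog")
            os.makedirs(os.path.dirname(path), exist_ok=True)  # create parents lazily
            return path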
@ -1,12 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
|
||||
def import_profiler_registers():
|
||||
import realhf.search_engine.enumerate
|
||||
import realhf.search_engine.estimate
|
||||
import realhf.search_engine.layers
|
||||
import realhf.search_engine.param_realloc
|
||||
import realhf.search_engine.search
|
||||
import realhf.search_engine.utils
|
|
@ -1,158 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
from typing import List
|
||||
|
||||
from realhf.api.core.dfg import MFCDef, ModelInterfaceType
|
||||
from realhf.api.core.dfg import build_graph as build_dfg
|
||||
from realhf.api.quickstart.device_mesh import DeviceMesh, find_parallel_strategies
|
||||
from realhf.api.quickstart.search import MFCDef, RPCExecution, RPCInstance
|
||||
from realhf.search_engine.estimate import (
|
||||
estimate_rpc_memory_cost,
|
||||
estimate_rpc_time_cost,
|
||||
)
|
||||
|
||||
MEM_INDEX = 1.0 # heuristic value to scale estimated memory
|
||||
|
||||
|
||||
def enumerate_rpc_executions(
|
||||
rpc: MFCDef,
|
||||
device_mesh: DeviceMesh,
|
||||
seq_len: int,
|
||||
num_gen_tokens: int,
|
||||
n_ppo_minibatches: int,
|
||||
gradient_checkpointing: bool,
|
||||
) -> List[RPCExecution]:
|
||||
sub_device_meshes = device_mesh.sub_device_meshes()
|
||||
import pprint
|
||||
|
||||
feasible = []
|
||||
for sub_device_mesh in sub_device_meshes:
|
||||
ps = find_parallel_strategies(sub_device_mesh)
|
||||
for parallel in ps:
|
||||
num_dp = parallel.data_parallel_size
|
||||
num_pp = parallel.pipeline_parallel_size
|
||||
num_mp = parallel.model_parallel_size
|
||||
bs = rpc.n_seqs
|
||||
# seq_len = seq_len
|
||||
min_bs = (
|
||||
2 * num_dp * num_pp * n_ppo_minibatches
|
||||
if rpc.interface_type == ModelInterfaceType.TRAIN_STEP
|
||||
else num_dp * num_pp
|
||||
)
|
||||
if min_bs > bs:
|
||||
# batch size too small
|
||||
continue
|
||||
# heuristic to filter out inherently slow configurations
|
||||
if (
|
||||
num_mp * num_dp > device_mesh.n_gpus_per_node
|
||||
and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
|
||||
):
|
||||
continue
|
||||
if num_mp > 8:
|
||||
continue
|
||||
if num_pp > max(device_mesh.n_nodes, 8):
|
||||
continue
|
||||
# memory and time estimation
|
||||
mem_cost, static_mem = estimate_rpc_memory_cost(
|
||||
rpc,
|
||||
parallel,
|
||||
bs,
|
||||
seq_len,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
n_ppo_minibatches=n_ppo_minibatches,
|
||||
num_gen_tokens=num_gen_tokens,
|
||||
offload=rpc.model_name.role in ["ref", "reward"],
|
||||
)
|
||||
mem_cost = int(mem_cost * MEM_INDEX)
|
||||
static_mem = int(static_mem * MEM_INDEX)
|
||||
time_cost = estimate_rpc_time_cost(
|
||||
rpc,
|
||||
parallel,
|
||||
bs=bs,
|
||||
seq_len=seq_len,
|
||||
num_gen_tokens=num_gen_tokens,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
n_ppo_minibatches=n_ppo_minibatches,
|
||||
)
|
||||
time_cost = int(time_cost)
|
||||
if mem_cost < device_mesh.gpu_memory_capacity:
|
||||
feasible.append(
|
||||
RPCExecution(
|
||||
rpc,
|
||||
sub_device_mesh,
|
||||
parallel,
|
||||
time_cost,
|
||||
mem_cost,
|
||||
static_mem,
|
||||
)
|
||||
)
|
||||
return feasible
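# Editor's note: a hedged, standalone sketch (not part of the original diff) of the
# min_bs feasibility rule used above. The helper name and the numbers are hypothetical:
# a TRAIN_STEP RPC needs at least 2 * dp * pp * n_ppo_minibatches sequences so that every
# pipeline micro-batch of every PPO minibatch receives data; other RPCs only need dp * pp.
def _min_required_batch_size(num_dp: int, num_pp: int, n_ppo_minibatches: int, is_train_step: bool) -> int:
    if is_train_step:
        return 2 * num_dp * num_pp * n_ppo_minibatches
    return num_dp * num_pp

assert _min_required_batch_size(4, 2, 4, True) == 64   # train step: 2 * 4 * 2 * 4
assert _min_required_batch_size(4, 2, 4, False) == 8   # inference/generate: 4 * 2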
|
||||
|
||||
|
||||
def build_graph(
|
||||
rpcs: List[MFCDef],
|
||||
num_epoch: int = 5,
|
||||
epoch_dependency_interval: int = 1,
|
||||
if_print=False,
|
||||
) -> List[RPCInstance]:
|
||||
"""Build model function call graph of multiple training epochs,
|
||||
|
||||
args:
|
||||
rpcs: List[MFCDef], the model function call definitions to connect
|
||||
num_epoch: int, number of training epochs
|
||||
epoch_dependency_interval: int, the interval of epoch dependency,
|
||||
e.g. if epoch_dependency_interval = 2, then the graph will have
|
||||
edges between epoch i and epoch i+2, i+4, ...
|
||||
|
||||
returns:
|
||||
rpc_instances: List[RPCInstance], the list of RPCInstance objects
|
||||
"""
|
||||
# one epoch dependency graph
|
||||
rpcs, edges = build_dfg(rpcs)
|
||||
rpc_names_mapping = {rpc.name: rpc for rpc in rpcs}
|
||||
rpc_instances = []
|
||||
|
||||
# multi epoch graph
|
||||
for epoch_id in range(num_epoch):
|
||||
for rpc in rpcs:
|
||||
children = []
|
||||
parents = []
|
||||
if rpc.is_src and epoch_id >= epoch_dependency_interval:
|
||||
for other in rpcs:
|
||||
if other.is_dst and other.model_name.role == rpc.model_name.role:
|
||||
parents.append(
|
||||
RPCInstance(
|
||||
rpc,
|
||||
epoch_id - epoch_dependency_interval,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
)
|
||||
if rpc.is_dst and rpc.model_name.role == rpc.model_name.role:
|
||||
for other in rpcs:
|
||||
if (
|
||||
other.is_src
|
||||
and epoch_id + epoch_dependency_interval < num_epoch
|
||||
):
|
||||
children.append(
|
||||
RPCInstance(
|
||||
rpc,
|
||||
epoch_id + epoch_dependency_interval,
|
||||
[],
|
||||
[],
|
||||
)
|
||||
)
|
||||
for parent in rpc.parents:
|
||||
p = rpc_names_mapping[parent]
|
||||
parents.append(RPCInstance(p, epoch_id, [], []))
|
||||
for child in rpc.children:
|
||||
c = rpc_names_mapping[child]
|
||||
children.append(RPCInstance(c, epoch_id, [], []))
|
||||
rpc_instance = RPCInstance(rpc, epoch_id, parents, children)
|
||||
rpc_instances.append(rpc_instance)
|
||||
if if_print:
|
||||
for ri in rpc_instances:
|
||||
print(ri)
|
||||
return rpc_instances
|
|
@ -1,499 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
# Estimate a function-call level execution time for device mesh enumerate pruning
|
||||
# assume one batch of data passes through all rpcs once
|
||||
import argparse
|
||||
import getpass
|
||||
import itertools
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import realhf.base.cluster
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.cli_args import ParallelismConfig
|
||||
from realhf.api.core.dfg import MFCDef, ModelFamily, ModelInterfaceType
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.search_engine.param_realloc import estimate_param_realloc_time_cost
|
||||
from realhf.search_engine.utils import load_model_config
|
||||
|
||||
logger = logging.getLogger("estimate", "benchmark")
|
||||
|
||||
PROFILE_RESULT_PATH = os.path.join(
|
||||
realhf.base.cluster.spec.fileroot,
|
||||
"logs",
|
||||
getpass.getuser(),
|
||||
"profile",
|
||||
"profile",
|
||||
"layer_stats",
|
||||
)
|
||||
|
||||
|
||||
def get_param_realloc_stats(
|
||||
model_family: ModelFamily,
|
||||
model_path: str,
|
||||
n_nodes: int,
|
||||
use_cache: bool = True,
|
||||
):
|
||||
non_critic = ModelFamily(model_family._class, model_family.size, False)
|
||||
table_path = os.path.join(
|
||||
constants.PROFILER_CACHE_PATH,
|
||||
"param_realloc",
|
||||
f"prtc_{non_critic}_n{n_nodes}.pkl",
|
||||
)
|
||||
if not os.path.exists(table_path):
|
||||
print(
|
||||
f"Calculating estimation of param realloc time cost for {model_family} at {model_path}"
|
||||
)
|
||||
estimate_param_realloc_time_cost(n_nodes, {non_critic: model_path})
|
||||
|
||||
print(f"Loading param realloc stats from {table_path}")
|
||||
return pickle.load(open(table_path, "rb"))
|
||||
|
||||
|
||||
def get_organized_op_stats(
|
||||
model_family: ModelFamily, model_path: str, use_cache: bool = True
|
||||
):
|
||||
non_critic = ModelFamily(model_family._class, model_family.size, False)
|
||||
# parse raw stats into list of OpInfo used for estimation
|
||||
cache_path = os.path.join(
|
||||
constants.PROFILER_CACHE_PATH, "organized_stats", f"{non_critic}.pkl"
|
||||
)
|
||||
if use_cache and os.path.exists(cache_path):
|
||||
with open(cache_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
raw_result_path = os.path.join(
|
||||
constants.PROFILER_CACHE_PATH, "layer_stats", str(non_critic)
|
||||
)
|
||||
if not os.path.exists(raw_result_path):
|
||||
from realhf.apps.main import _main_profile_layers
|
||||
|
||||
_main_profile_layers(non_critic, model_path)
|
||||
|
||||
raw_stats_list = []
|
||||
for fn in os.listdir(raw_result_path):
|
||||
if not fn.startswith("layer-stats"):
|
||||
continue
|
||||
num_mp = int(fn.replace(".pkl", "").split("_")[1])
|
||||
rank = int(fn.replace(".pkl", "").split("_")[2])
|
||||
with open(os.path.join(raw_result_path, fn), "rb") as f:
|
||||
stats = pickle.load(f)
|
||||
|
||||
if isinstance(stats, dict):
|
||||
stats = pd.DataFrame(stats)
|
||||
elif isinstance(stats, pd.DataFrame):
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type {type(stats)}")
|
||||
|
||||
stats["num_mp"] = num_mp
|
||||
stats["rank"] = rank
|
||||
raw_stats_list.append(stats)
|
||||
|
||||
raw_stats = pd.concat(raw_stats_list)
|
||||
bs_list = raw_stats["bs"].unique()
|
||||
seq_len_list = raw_stats["seq_len"].unique()
|
||||
op_name_list = raw_stats["op_name"].unique()
|
||||
layer_name_list = raw_stats["layer_name"].unique()
|
||||
num_mp_list = raw_stats["num_mp"].unique()
|
||||
organized_stats = defaultdict(list)
|
||||
|
||||
for op_name, bs, seq_len, layer_name, num_mp in itertools.product(
|
||||
op_name_list, bs_list, seq_len_list, layer_name_list, num_mp_list
|
||||
):
|
||||
filter_cond = (
|
||||
(raw_stats["op_name"] == op_name)
|
||||
& (raw_stats["bs"] == bs)
|
||||
& (raw_stats["seq_len"] == seq_len)
|
||||
& (raw_stats["layer_name"] == layer_name)
|
||||
& (raw_stats["num_mp"] == num_mp)
|
||||
)
|
||||
avg_time_ns = raw_stats[filter_cond]["time_ns"].mean()
|
||||
x = int(bs) if op_name == "fwd_gen_1" else int(bs * seq_len)
|
||||
|
||||
organized_stats["op_name"].append(op_name)
|
||||
organized_stats["layer_name"].append(layer_name)
|
||||
organized_stats["bs"].append(bs)
|
||||
organized_stats["seq_len"].append(seq_len)
|
||||
organized_stats["num_mp"].append(num_mp)
|
||||
organized_stats["avg_time_ns"].append(avg_time_ns)
|
||||
organized_stats["x"].append(x)
|
||||
|
||||
df = pd.DataFrame(organized_stats)
|
||||
if use_cache:
|
||||
os.makedirs(os.path.dirname(cache_path), exist_ok=True)
|
||||
with open(cache_path, "wb") as f:
|
||||
pickle.dump(df, f)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def computation_instruction_time_cost(
|
||||
op_stats: pd.DataFrame,
|
||||
op_name: str,
|
||||
num_layers: int,
|
||||
parallel_strategy: ParallelismConfig,
|
||||
bs: int,
|
||||
seqlen: int,
|
||||
):
|
||||
# inst cost unit: ns
|
||||
layer_names = ["embedding_layer", "block_0", "head"]
|
||||
num_pp = parallel_strategy.pipeline_parallel_size
|
||||
num_mp = parallel_strategy.model_parallel_size
|
||||
op_stats = op_stats[
|
||||
(op_stats["op_name"] == op_name) & (op_stats["num_mp"] == num_mp)
|
||||
]
|
||||
|
||||
op_cost = {}
|
||||
embed_stats = op_stats[op_stats["layer_name"] == "embedding_layer"]
|
||||
if embed_stats[
|
||||
(embed_stats["bs"] == bs) & (embed_stats["seq_len"] == seqlen)
|
||||
].empty:
|
||||
# do linear interpolation for data points that do not exist
|
||||
for layer_name in layer_names:
|
||||
layer_stats = op_stats[op_stats["layer_name"] == layer_name]
|
||||
assert layer_stats[
|
||||
(layer_stats["bs"] == bs) & (layer_stats["seq_len"] == seqlen)
|
||||
].empty
|
||||
assert not layer_stats.empty, (
|
||||
layer_name,
|
||||
op_name,
|
||||
num_mp,
|
||||
op_stats,
|
||||
)
|
||||
xs = layer_stats["x"]
|
||||
ys = layer_stats["avg_time_ns"]
|
||||
x = int(bs) if op_name == "fwd_gen_1" else int(bs * seqlen)
|
||||
y = np.interp(x, xs, ys)
|
||||
if max(xs) < x or min(xs) > x:
|
||||
logger.warning(
|
||||
f"Interpolated value outside profiling range, "
|
||||
f"parallel strategy {parallel_strategy}: "
|
||||
f"{x} in {sorted(list(set(xs)))}"
|
||||
)
|
||||
# estimate using largest or smallest value
|
||||
if max(xs) < x:
|
||||
y = ys.max() * (x / xs[ys.idxmax()])
|
||||
else:
|
||||
y = ys.min() * (x / xs[ys.idxmin()])
|
||||
op_cost[layer_name] = y
|
||||
else:
|
||||
for layer_name in layer_names:
|
||||
assert not op_stats[op_stats["layer_name"] == layer_name].empty
|
||||
required_stats = op_stats[
|
||||
(op_stats["layer_name"] == layer_name)
|
||||
& (op_stats["bs"] == bs)
|
||||
& (op_stats["seq_len"] == seqlen)
|
||||
]
|
||||
assert required_stats.shape[0] == 1
|
||||
op_cost[layer_name] = required_stats["avg_time_ns"].values[0]
|
||||
|
||||
embedding_layer_cost = op_cost["embedding_layer"]
|
||||
block_0_cost = op_cost["block_0"]
|
||||
head_cost = op_cost["head"]
|
||||
cost = (embedding_layer_cost + num_layers * block_0_cost + head_cost) / num_pp
|
||||
return cost
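# Editor's note (not part of the original diff): a worked example of the per-micro-batch
# estimate above, with hypothetical per-layer timings. If embedding_layer = 1e6 ns,
# block_0 = 5e6 ns, head = 2e6 ns, num_layers = 32 and num_pp = 4, then
#   cost = (1e6 + 32 * 5e6 + 2e6) / 4 = 163e6 / 4 ≈ 4.1e7 ns ≈ 41 ms.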
|
||||
|
||||
|
||||
def communication_instruction_time_cost(comm_stats, size, comm_type):
|
||||
return size / comm_stats[comm_type] # unit: ns
|
||||
|
||||
|
||||
def estimate_instruction_time_costs(
|
||||
model_family: ModelFamily,
|
||||
model_path: str,
|
||||
num_layers: int, # model configuration, num transformers layer
|
||||
parallel_strategy: ParallelismConfig,
|
||||
hidden_dim: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
n_ppo_minibatches: int = 1,
|
||||
):
|
||||
comm_stats = default_communication_stats()
|
||||
op_stats = get_organized_op_stats(model_family, model_path, use_cache=True)
|
||||
|
||||
num_mp = parallel_strategy.model_parallel_size
|
||||
num_pp = parallel_strategy.pipeline_parallel_size
|
||||
num_dp = parallel_strategy.data_parallel_size
|
||||
num_gpus = num_dp * num_mp * num_pp
|
||||
|
||||
train_mbs = (
|
||||
batch_size / (2 * num_pp * num_dp * n_ppo_minibatches)
|
||||
if num_pp > 1
|
||||
else batch_size / (num_dp * n_ppo_minibatches)
|
||||
)
|
||||
gen_mbs = batch_size / (num_pp * num_dp)
|
||||
|
||||
# pprint.pprint(op_cost, indent=4)
|
||||
inst_keys = [
|
||||
"gen_fwd_0",
|
||||
"gen_fwd_1",
|
||||
"inf_fwd",
|
||||
"train_fwd",
|
||||
"train_bwd",
|
||||
"train_opt",
|
||||
]
|
||||
op_names = ["fwd_gen_0", "fwd_gen_1", "fwd", "fwd", "bwd", "opt"]
|
||||
inst_stats = {}
|
||||
|
||||
for inst_key, op_name in zip(inst_keys, op_names):
|
||||
mbs = train_mbs if "train" in inst_key else gen_mbs
|
||||
inst_stats[inst_key] = computation_instruction_time_cost(
|
||||
op_stats, op_name, num_layers, parallel_strategy, mbs, seq_len
|
||||
)
|
||||
|
||||
comm_type = "remote_send" if num_gpus // num_pp >= 8 else "local_send"
|
||||
|
||||
inst_stats["act_p2p"] = communication_instruction_time_cost(
|
||||
comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type
|
||||
)
|
||||
inst_stats["grad_p2p"] = communication_instruction_time_cost(
|
||||
comm_stats, 2 * hidden_dim * train_mbs * seq_len, comm_type
|
||||
)
|
||||
inst_stats["gen_act_p2p"] = communication_instruction_time_cost(
|
||||
comm_stats, 2 * hidden_dim * gen_mbs, comm_type
|
||||
)
|
||||
return inst_stats
|
||||
|
||||
|
||||
def _estimate_rpc_time_cost(
|
||||
inst_stats,
|
||||
parallel_strategy: ParallelismConfig,
|
||||
model_interface_type: ModelInterfaceType,
|
||||
# model function call args
|
||||
num_gen_tokens: int,
|
||||
gradient_checkpointing: bool,
|
||||
n_ppo_minibatches: int = 1,
|
||||
):
|
||||
# TODO: improve/remove heuristic
|
||||
num_pp = parallel_strategy.pipeline_parallel_size
|
||||
num_dp = parallel_strategy.data_parallel_size
|
||||
num_mp = parallel_strategy.model_parallel_size
|
||||
if model_interface_type == ModelInterfaceType.INFERENCE:
|
||||
num_micro_batches = num_pp
|
||||
compute_cost = inst_stats["inf_fwd"] * (num_pp + num_micro_batches - 1)
|
||||
comm_cost = inst_stats["act_p2p"] * (num_pp + num_micro_batches - 2) * 2
|
||||
if num_mp > 1: # and num_pp * num_dp > 1:
|
||||
compute_cost = compute_cost * (1 - num_mp * 0.03)
|
||||
elif model_interface_type == ModelInterfaceType.TRAIN_STEP:
|
||||
# TODO: add reduce grads, add ppo micro batches
|
||||
num_micro_batches = num_pp * 2 if num_pp > 1 else 1
|
||||
compute_cost = (inst_stats["train_fwd"] + inst_stats["train_bwd"]) * (
|
||||
num_pp + num_micro_batches - 1
|
||||
) + inst_stats["train_opt"]
|
||||
if gradient_checkpointing:
|
||||
compute_cost += inst_stats["train_fwd"] * (num_pp + num_micro_batches - 1)
|
||||
comm_cost = (
|
||||
(inst_stats["grad_p2p"] + inst_stats["act_p2p"])
|
||||
* (num_pp + num_micro_batches - 2)
|
||||
* 2
|
||||
)
|
||||
compute_cost = compute_cost * n_ppo_minibatches
|
||||
comm_cost = comm_cost * n_ppo_minibatches
|
||||
if num_pp * num_dp <= 1:
|
||||
compute_cost = compute_cost * (1 - num_mp * 0.04)
|
||||
if num_mp > 1: # and num_pp * num_dp > 1:
|
||||
compute_cost = compute_cost * (1 - num_mp * 0.03)
|
||||
elif model_interface_type == ModelInterfaceType.GENERATE:
|
||||
num_micro_batches = num_pp
|
||||
num_gen_tokens = num_gen_tokens
|
||||
compute_cost = (
|
||||
inst_stats["gen_fwd_0"] * (num_pp + num_micro_batches - 1)
|
||||
+ inst_stats["gen_fwd_1"] * (num_gen_tokens - 1) * num_micro_batches
|
||||
)
|
||||
|
||||
if num_dp * num_mp > 1:
|
||||
compute_cost = compute_cost * (1 - min(num_dp * num_mp, 8) * 0.03)
|
||||
comm_cost = 0
|
||||
|
||||
# dirty heuristic
|
||||
if num_pp > 8:
|
||||
compute_cost = compute_cost * (1 + num_pp * 0.01)
|
||||
|
||||
# FIXME: disable comm cost because it is not accurate
|
||||
comm_cost = 0
|
||||
|
||||
return compute_cost + comm_cost
|
||||
|
||||
|
||||
def estimate_rpc_time_cost(
|
||||
rpc: MFCDef,
|
||||
parallel_strategy: ParallelismConfig,
|
||||
bs: int,
|
||||
seq_len: int,
|
||||
num_gen_tokens: int = 256,
|
||||
gradient_checkpointing: bool = False,
|
||||
n_ppo_minibatches: int = 1,
|
||||
):
|
||||
# time unit: milliseconds
|
||||
# FIXME: n_ppo_minibatches > 1 will result in bad estimation
|
||||
# when batch size is large enough, n_ppo_minibatches > 1 will not affect the estimation
|
||||
n_ppo_minibatches = 1
|
||||
model_type = rpc.model_type
|
||||
model_path = rpc.model_path
|
||||
model_config = load_model_config(rpc.model_type._class, rpc.model_path)
|
||||
|
||||
inst_cost = estimate_instruction_time_costs(
|
||||
model_type,
|
||||
model_path,
|
||||
model_config.n_layers,
|
||||
parallel_strategy,
|
||||
model_config.hidden_dim,
|
||||
bs,
|
||||
seq_len,
|
||||
n_ppo_minibatches=n_ppo_minibatches,
|
||||
)
|
||||
return (
|
||||
_estimate_rpc_time_cost(
|
||||
inst_cost,
|
||||
parallel_strategy,
|
||||
rpc.interface_type,
|
||||
num_gen_tokens=num_gen_tokens,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
n_ppo_minibatches=n_ppo_minibatches,
|
||||
)
|
||||
) / 1e6
|
||||
|
||||
|
||||
def default_communication_stats(if_print=False):
|
||||
# use default comm stats of cluster
|
||||
r = dict(
|
||||
# between GPUs on the same node
|
||||
local_send=170, # unit: GB/s
|
||||
local_recv=170,
|
||||
# IB between GPUs on different nodes
|
||||
remote_send=20, # unit: GB/s
|
||||
remote_recv=20,
|
||||
)
|
||||
if if_print:
|
||||
print(r)
|
||||
return r
|
||||
|
||||
|
||||
def estimate_model_size(model_config: ReaLModelConfig):
|
||||
h = model_config.hidden_dim
|
||||
i = model_config.intermediate_dim
|
||||
v = model_config.vocab_size
|
||||
L = model_config.n_layers
|
||||
# for llama actor only
|
||||
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
|
||||
return 2 * n_params
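# Editor's note (not part of the original diff): a worked example of the heuristic above
# with hypothetical LLaMA-7B-like dimensions h=4096, i=11008, v=32000, L=32:
#   n_params = 3*32000*4096 + (3*4096*11008 + 4*4096*4096) * 32 ≈ 6.87e9
#   estimate_model_size(...) = 2 * n_params ≈ 1.37e10 bytes ≈ 12.8 GiB of fp16 weights.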
|
||||
|
||||
|
||||
def estimate_rpc_memory_cost(
|
||||
rpc: MFCDef,
|
||||
parallel_strategy: ParallelismConfig,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
offload: bool = False,
|
||||
gradient_checkpointing: bool = False,
|
||||
offload_optimizer: bool = False,
|
||||
n_ppo_minibatches: int = 1,
|
||||
num_gen_tokens: int = 128,
|
||||
):
|
||||
# TODO: improve heuristic
|
||||
interface_type = rpc.interface_type
|
||||
model_config = load_model_config(rpc.model_type._class, rpc.model_path)
|
||||
|
||||
h = model_config.hidden_dim
|
||||
i = model_config.intermediate_dim
|
||||
v = model_config.vocab_size
|
||||
s = seq_len
|
||||
gs = num_gen_tokens
|
||||
b = batch_size
|
||||
L = model_config.n_layers
|
||||
# for llama actor only
|
||||
n_params = 3 * v * h + (3 * h * i + 4 * h * h) * L
|
||||
param_mem = 2 * n_params
|
||||
grad_mem = 2 * n_params
|
||||
optimizer_mem = 20 * n_params if not offload_optimizer else 0
|
||||
|
||||
num_pp = parallel_strategy.pipeline_parallel_size
|
||||
num_mp = parallel_strategy.model_parallel_size
|
||||
num_dp = parallel_strategy.data_parallel_size
|
||||
# zero1, pp and mp divide evenly
|
||||
# enable sequence parallel
|
||||
if interface_type == ModelInterfaceType.TRAIN_STEP:
|
||||
# gradient checkpointing is always enabled for flash attn
|
||||
static_mem = (param_mem + grad_mem) // (num_pp * num_mp) + optimizer_mem // (
|
||||
num_pp * num_dp * num_mp
|
||||
)
|
||||
micro_bs = b // (2 * num_pp * num_dp) if num_pp > 1 else b // (num_dp)
|
||||
if gradient_checkpointing:
|
||||
active_mem = (micro_bs * s * h * num_pp * 2) // (num_pp * num_mp)
|
||||
else:
|
||||
# FIXME: calculate other memory entries
|
||||
active_mem = (micro_bs * s * h * num_pp * 2) * 2 * L // (num_pp * num_mp)
|
||||
return static_mem + active_mem, static_mem
|
||||
elif interface_type == ModelInterfaceType.INFERENCE:
|
||||
static_mem = int(2 * param_mem // (num_pp * num_mp))
|
||||
# if num_dp > 4:
|
||||
# static_mem = static_mem * 1.25
|
||||
if offload:
|
||||
return static_mem, 0 # assume offload
|
||||
else:
|
||||
return static_mem, static_mem
|
||||
elif interface_type == ModelInterfaceType.GENERATE:
|
||||
static_mem = int(2 * param_mem // (num_pp * num_mp))
|
||||
if num_dp > 4 and num_dp * num_mp * num_pp <= 16:
|
||||
static_mem = static_mem * 1.25
|
||||
if num_mp == 0 and num_pp == 0:
|
||||
static_mem = static_mem * 1.25
|
||||
active_mem = (
|
||||
2 * (2 * b * (gs + s) * h) * L // (num_pp * num_mp * num_dp)
|
||||
) # kv cache
|
||||
return static_mem + active_mem, static_mem
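# Editor's note (not part of the original diff): a worked example of the GENERATE branch's
# KV-cache term above, with hypothetical values b=128, s=1024, gs=128, h=4096, L=32 and a
# (pp, mp, dp) = (1, 4, 2) parallel strategy (8 GPUs):
#   active_mem = 2 * (2 * 128 * 1152 * 4096) * 32 // 8 ≈ 9.7e9 bytes ≈ 9 GiB per GPU.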
|
||||
|
||||
|
||||
def example(rpcs):
|
||||
if args.model_size == 7:
|
||||
n_nodes = 1
|
||||
elif args.model_size == 13:
|
||||
n_nodes = 2
|
||||
elif args.model_size == 34:
|
||||
n_nodes = 4
|
||||
elif args.model_size == 70:
|
||||
n_nodes = 8
|
||||
|
||||
expr_name = f"profile-s{args.model_size}p{n_nodes}m1d8"
|
||||
|
||||
rollout, inf, train = rpcs
|
||||
|
||||
bs = 128
|
||||
seq_len = 1024
|
||||
|
||||
p1 = ParallelismConfig(
|
||||
pipeline_parallel_size=1, model_parallel_size=4, data_parallel_size=8
|
||||
)
|
||||
rpc_cost = estimate_rpc_time_cost(
|
||||
train,
|
||||
p1,
|
||||
gradient_checkpointing=True,
|
||||
num_gen_tokens=896,
|
||||
bs=bs,
|
||||
seq_len=seq_len,
|
||||
n_ppo_minibatches=4,
|
||||
)
|
||||
mem_cost, static_mem = estimate_rpc_memory_cost(rollout, p1, bs, seq_len)
|
||||
print(f"{p1} rpc cost {rpc_cost:.2f} seconds mem cost {mem_cost/(1024**3):.2f} GB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run a profiling experiment.")
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--model_size",
|
||||
type=int,
|
||||
default=7,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
example(args)
|
|
@ -1,310 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
|
||||
import realhf.api.core.model_api as model_api
|
||||
import realhf.api.core.system_api as config_package
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.logging as logging
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.impl.model.utils.padding import unpad_input
|
||||
|
||||
logger = logging.getLogger("profile layers", "system")
|
||||
|
||||
|
||||
def make_layers(config: ReaLModelConfig, dtype, device):
|
||||
from realhf.impl.model.nn.real_llm_base import (
|
||||
OutputHead,
|
||||
ReaLModelBlock,
|
||||
VocabPositionEmbedding,
|
||||
)
|
||||
|
||||
embedding_layer = VocabPositionEmbedding(
|
||||
config,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
real_model_blocks = [
|
||||
ReaLModelBlock(
|
||||
config,
|
||||
layer_index=i,
|
||||
output_layernorm=(i == 1),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(1)
|
||||
]
|
||||
head = OutputHead(
|
||||
config.hidden_dim,
|
||||
1 if config.is_critic else config.vocab_size,
|
||||
bias=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
layer_names = ["embedding_layer", "block_0", "head"]
|
||||
return [embedding_layer] + real_model_blocks + [head], layer_names
|
||||
|
||||
|
||||
class ProfileLayers:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
config: ReaLModelConfig,
|
||||
tokenizer: transformers.PreTrainedTokenizerFast = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.config = config
|
||||
self.backend_config = config_package.ModelBackend(
|
||||
type_="deepspeed",
|
||||
args=dict(
|
||||
optimizer_name="adam",
|
||||
optimizer_config=dict(lr=1e-5, weight_decay=0.0, betas=(0.9, 0.95)),
|
||||
warmup_steps_proportion=0.0,
|
||||
min_lr_ratio=0.0,
|
||||
zero_stage=1,
|
||||
bf16=False,
|
||||
),
|
||||
)
|
||||
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.layers, self.layer_names = make_layers(config, dtype, device)
|
||||
self.hidden_dim = config.hidden_dim
|
||||
self.head_dim = config.head_dim
|
||||
self.max_new_tokens = 128 # only useful in kv cache memory alloc
|
||||
self.min_new_tokens = 128
|
||||
|
||||
self.stats = defaultdict(list)
|
||||
self.num_layers = len(self.layers)
|
||||
|
||||
self.layers = [
|
||||
model_api.Model(name, layer, tokenizer, device=device, dtype=dtype)
|
||||
for layer, name in zip(self.layers, self.layer_names)
|
||||
]
|
||||
self.backend = model_api.make_backend(self.backend_config)
|
||||
ft_spec = model_api.FinetuneSpec(10, 100, 10)
|
||||
self.layers = [self.backend.initialize(layer, ft_spec) for layer in self.layers]
|
||||
self.stats = defaultdict(list)
|
||||
|
||||
def reset_stats(self):
|
||||
self.stats = defaultdict(list)
|
||||
|
||||
def insert_data_point(self, layer_name, name, bs, seq_len, time_ns):
|
||||
self.stats["layer_name"].append(layer_name)
|
||||
self.stats["op_name"].append(name)
|
||||
self.stats["bs"].append(bs)
|
||||
self.stats["seq_len"].append(seq_len)
|
||||
self.stats["time_ns"].append(time_ns)
|
||||
|
||||
@torch.no_grad()
|
||||
def fwd_gen(self, bs, seq_len):
|
||||
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
|
||||
|
||||
input_ids = torch.randint(
|
||||
0,
|
||||
self.config.vocab_size,
|
||||
(bs, seq_len),
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
attention_mask = torch.ones_like(input_ids, device=self.device)
|
||||
# fwd_gen_0
|
||||
packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
|
||||
input_ids, attention_mask
|
||||
)
|
||||
cu_seqlens = cu_seqlens.to(device=self.device)
|
||||
packed_input_ids = packed_input_ids.to(device=self.device)
|
||||
x = PipeTransferData(
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=int(max_seqlen),
|
||||
store_kv_cache=True,
|
||||
)
|
||||
ys = [PipeCacheData() for _ in range(self.num_layers)]
|
||||
ys[0].packed_input_ids = packed_input_ids
|
||||
|
||||
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
|
||||
st = time.monotonic_ns()
|
||||
x: PipeTransferData = layer.module(x, y)
|
||||
x.pp_input = x.pp_output
|
||||
torch.cuda.synchronize()
|
||||
self.insert_data_point(
|
||||
layer_name, "fwd_gen_0", bs, seq_len, time.monotonic_ns() - st
|
||||
)
|
||||
|
||||
prompt_logits = x.pp_output
|
||||
logits = prompt_logits[cu_seqlens[1:] - 1]
|
||||
input_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
cache_seqlens = input_lens.clone().to(dtype=torch.int32)
|
||||
layer_indices = range(len(ys))
|
||||
|
||||
for y, layer_idx in zip(ys[1:-1], layer_indices[1:-1]):
|
||||
assert (
|
||||
y.k_cache is not None
|
||||
and y.v_cache is not None
|
||||
and y.cache_seqlens is not None
|
||||
)
|
||||
kvcache_seqlen = max(
|
||||
max_seqlen + self.max_new_tokens,
|
||||
self.hidden_dim // self.head_dim + 10,
|
||||
)
|
||||
# fix of a flash attention bug
|
||||
k_cache = torch.zeros(
|
||||
(bs, kvcache_seqlen, *y.k_cache.shape[1:]),
|
||||
dtype=y.k_cache.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
v_cache = torch.zeros_like(k_cache)
|
||||
indices = (
|
||||
torch.arange(
|
||||
kvcache_seqlen,
|
||||
device=constants.current_device(),
|
||||
dtype=torch.long,
|
||||
)[None, :]
|
||||
< input_lens[:, None]
|
||||
)
|
||||
k_cache[indices] = y.k_cache
|
||||
v_cache[indices] = y.v_cache
|
||||
y.k_cache = k_cache
|
||||
y.v_cache = v_cache
|
||||
y.cache_seqlens = cache_seqlens
|
||||
x = PipeTransferData(store_kv_cache=True)
|
||||
ys[0].cache_seqlens = cache_seqlens
|
||||
|
||||
# fwd_gen_1
|
||||
new_tokens = torch.randint(
|
||||
0,
|
||||
self.config.vocab_size,
|
||||
(bs,),
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
ys[0].packed_input_ids = new_tokens
|
||||
ys[0].packed_position_ids = None
|
||||
x.cu_seqlens = torch.arange(bs + 1, dtype=torch.int32, device=self.device)
|
||||
x.max_seqlen = 1
|
||||
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
|
||||
st = time.monotonic_ns()
|
||||
x = layer.module(x, y)
|
||||
x.pp_input = x.pp_output
|
||||
torch.cuda.synchronize()
|
||||
self.insert_data_point(
|
||||
layer_name, "fwd_gen_1", bs, seq_len, time.monotonic_ns() - st
|
||||
)
|
||||
|
||||
def fwd_bwd_opt(self, bs, seq_len):
|
||||
from realhf.impl.model.nn.real_llm_base import PipeCacheData, PipeTransferData
|
||||
|
||||
input_ids = torch.randint(
|
||||
0,
|
||||
self.config.vocab_size,
|
||||
(bs, seq_len),
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
attention_mask = torch.ones_like(input_ids, device=self.device)
|
||||
packed_input_ids, _, cu_seqlens, max_seqlen = unpad_input(
|
||||
input_ids, attention_mask
|
||||
)
|
||||
cu_seqlens = cu_seqlens.to(device=self.device)
|
||||
packed_input_ids = packed_input_ids.to(device=self.device)
|
||||
x = PipeTransferData(
|
||||
cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, store_kv_cache=False
|
||||
)
|
||||
ys = [PipeCacheData() for _ in range(self.num_layers)]
|
||||
ys[0].packed_input_ids = packed_input_ids
|
||||
|
||||
for layer_name, layer, y in zip(self.layer_names, self.layers, ys):
|
||||
# fwd
|
||||
st = time.monotonic_ns()
|
||||
x: PipeTransferData = layer.module(x, y)
|
||||
torch.cuda.synchronize()
|
||||
self.insert_data_point(
|
||||
layer_name, "fwd", bs, seq_len, time.monotonic_ns() - st
|
||||
)
|
||||
# bwd
|
||||
r = torch.rand(
|
||||
*x.pp_output.shape,
|
||||
device=x.pp_output.device,
|
||||
dtype=x.pp_output.dtype,
|
||||
)
|
||||
loss = torch.max(x.pp_output * r)
|
||||
st = time.monotonic_ns()
|
||||
layer.module.backward(loss)
|
||||
torch.cuda.synchronize()
|
||||
self.insert_data_point(
|
||||
layer_name, "bwd", bs, seq_len, time.monotonic_ns() - st
|
||||
)
|
||||
# opt
|
||||
st = time.monotonic_ns()
|
||||
layer.module.step()
|
||||
torch.cuda.synchronize()
|
||||
self.insert_data_point(
|
||||
layer_name, "opt", bs, seq_len, time.monotonic_ns() - st
|
||||
)
|
||||
x.pp_input = x.pp_output.clone().detach()
|
||||
|
||||
def make_dataframe_and_print(self):
|
||||
df = pd.DataFrame(self.stats)
|
||||
logger.info(f"Current Stats: \nstr{df}")
|
||||
|
||||
def dump_stats(self, world_size):
|
||||
rank = dist.get_rank()
|
||||
# dump full stats
|
||||
dump_dir = os.path.join(
|
||||
constants.PROFILER_CACHE_PATH,
|
||||
"layer_stats",
|
||||
)
|
||||
dump_path = os.path.join(
|
||||
dump_dir, self.model_name, f"layer-stats_{world_size}_{rank}.pkl"
|
||||
)
|
||||
os.makedirs(os.path.dirname(dump_path), exist_ok=True)
|
||||
|
||||
with open(dump_path, "wb") as f:
|
||||
df = pd.DataFrame(self.stats)
|
||||
pickle.dump(df, f)
|
||||
|
||||
|
||||
def make_profile_layers(
|
||||
device: torch.device,
|
||||
model_path: str,
|
||||
model_name: str,
|
||||
dtype: Optional[str] = None,
|
||||
hf_model_type: str = "llama",
|
||||
):
|
||||
from realhf.impl.model.nn.real_llm_api import ReaLModel
|
||||
|
||||
if dtype == "fp16" or dtype == None:
|
||||
dtype = torch.float16
|
||||
elif dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == "fp32":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported dtype {dtype}")
|
||||
tokenizer = None
|
||||
config: ReaLModelConfig = getattr(ReaLModel, f"config_from_{hf_model_type}")(
|
||||
model_path=model_path,
|
||||
)
|
||||
if tokenizer is None:
|
||||
tokenizer = model_api.load_hf_tokenizer(model_path)
|
||||
|
||||
profile_layers = ProfileLayers(
|
||||
model_name, config, tokenizer=tokenizer, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
return profile_layers
|
|
@ -1,272 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import *
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import realhf.api.core.system_api as system_api
|
||||
import realhf.base.constants as constants
|
||||
import realhf.base.topology as topology
|
||||
from realhf.api.core.config import ModelFamily, ModelName
|
||||
from realhf.api.core.model_api import ReaLModelConfig
|
||||
from realhf.base.topology import decompose_to_three_factors
|
||||
|
||||
|
||||
def bcast_cost(
|
||||
param_size: float, bw: float, src: int, dsts: List[int], n_nodes_per_gpu=8
|
||||
):
|
||||
src_node = src // n_nodes_per_gpu
|
||||
dst_nodes = [dst // n_nodes_per_gpu for dst in dsts]
|
||||
if src_node == dst_nodes[0] and all(
|
||||
dst_node == dst_nodes[0] for dst_node in dst_nodes
|
||||
):
|
||||
return param_size * 2 * 8 / (1800 * 1024**3)
|
||||
# return 0.0
|
||||
else:
|
||||
# param size is in float16, bw is in Gbps
|
||||
return param_size * 2 * 8 / (bw * 1024**3) * len(set(dst_nodes))
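# Editor's note (not part of the original diff): a worked example of the inter-node branch
# above with hypothetical numbers. Broadcasting 6.87e9 fp16 parameters at bw = 200 Gbps to
# destinations spread over 2 distinct nodes costs roughly
#   6.87e9 * 2 * 8 / (200 * 1024**3) * 2 ≈ 1.0 s.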
|
||||
|
||||
|
||||
def compute_cost(
|
||||
world_size: int,
|
||||
from_model_name: ModelName,
|
||||
to_model_name: ModelName,
|
||||
from_topo: topology.ProcessTopology,
|
||||
to_topo: topology.ProcessTopology,
|
||||
model_config: ReaLModelConfig,
|
||||
bw: float, # Gbps
|
||||
set_interval_cost: float,
|
||||
) -> int:
|
||||
from realhf.impl.model.comm.param_realloc import (
|
||||
ParamReallocInfo,
|
||||
ReparallelizeReceiverStep,
|
||||
ReparallelizeSenderStep,
|
||||
_create_param_realloc_groups,
|
||||
_derive_reparallelize_comm_plan,
|
||||
)
|
||||
|
||||
param_sync_groups = {}
|
||||
param_sync_src_ranks = {}
|
||||
param_sync_dst_ranks = {}
|
||||
msid2mwid = {}
|
||||
for i in range(from_topo.world_size()):
|
||||
msid2mwid[
|
||||
system_api.ModelShardID.from_parallelism_rank(from_model_name, from_topo, i)
|
||||
] = i
|
||||
for i in range(to_topo.world_size()):
|
||||
msid2mwid[
|
||||
system_api.ModelShardID.from_parallelism_rank(to_model_name, to_topo, i)
|
||||
] = (i + world_size - to_topo.world_size())
|
||||
_create_param_realloc_groups(
|
||||
from_topo,
|
||||
to_topo,
|
||||
from_model_name,
|
||||
to_model_name,
|
||||
msid2mwid,
|
||||
param_sync_groups,
|
||||
param_sync_src_ranks,
|
||||
param_sync_dst_ranks,
|
||||
)
|
||||
pg_info = ParamReallocInfo(
|
||||
param_sync_groups,
|
||||
param_sync_src_ranks,
|
||||
param_sync_dst_ranks,
|
||||
)
|
||||
comm_plan = _derive_reparallelize_comm_plan(
|
||||
from_model_name,
|
||||
to_model_name,
|
||||
from_topo,
|
||||
to_topo,
|
||||
model_config,
|
||||
model_config,
|
||||
pg_info,
|
||||
)
|
||||
|
||||
# Run broadcast!
|
||||
max_cost = max_comm_volume = max_bcast_cnt = 0
|
||||
for _rank in range(world_size):
|
||||
cost = comm_volume = bcast_cnt = 0
|
||||
for step in comm_plan:
|
||||
if isinstance(step, ReparallelizeReceiverStep) and step.rank == _rank:
|
||||
if step.rank != step.src:
|
||||
cost += bcast_cost(step.param_size, bw, step.src, step.dst_ranks)
|
||||
comm_volume += step.param_size
|
||||
bcast_cnt += 1
|
||||
cost += set_interval_cost
|
||||
if isinstance(step, ReparallelizeSenderStep) and step.rank == _rank:
|
||||
if step.group is not None:
|
||||
cost += bcast_cost(step.param_size, bw, step.rank, step.dst_ranks)
|
||||
bcast_cnt += 1
|
||||
max_cost = max(max_cost, cost)
|
||||
max_comm_volume = max(max_comm_volume, comm_volume)
|
||||
max_bcast_cnt = max(max_bcast_cnt, bcast_cnt)
|
||||
|
||||
return max_cost
|
||||
|
||||
|
||||
def dump_table(
|
||||
n_nodes: int,
|
||||
model_family: ModelFamily,
|
||||
model_path: str,
|
||||
rank: int = 0,
|
||||
parallel: int = 1,
|
||||
):
|
||||
from_model_name = ModelName("actor", 0)
|
||||
to_model_name = ModelName("actor", 1)
|
||||
|
||||
def hash_tuple_into_str(t) -> str:
|
||||
return ",".join([str(i) for i in t])
|
||||
|
||||
import tqdm
|
||||
|
||||
res = {}
|
||||
device_mesh_sizes = [4] + [8 * i for i in range(1, n_nodes + 1)]
|
||||
space = list(itertools.product(device_mesh_sizes, device_mesh_sizes))
|
||||
sub_space = space[rank::parallel]
|
||||
# for a, b in set(small_spaces + large_spaces):
|
||||
for a, b in sub_space:
|
||||
mtik = time.perf_counter()
|
||||
all_configs = list(
|
||||
itertools.product(
|
||||
decompose_to_three_factors(a), decompose_to_three_factors(b)
|
||||
)
|
||||
)
|
||||
all_configs = list(filter(lambda x: x[0][1] <= 8 and x[1][1] <= 8, all_configs))
|
||||
all_configs = list(filter(lambda x: x[0][2] <= 8 and x[1][2] <= 8, all_configs))
|
||||
all_configs = list(
|
||||
filter(
|
||||
lambda x: x[0][1] in [1, 2, 4, 8] and x[1][1] in [1, 2, 4, 8],
|
||||
all_configs,
|
||||
)
|
||||
)
|
||||
all_configs = list(
|
||||
filter(lambda x: x[0][0] <= 16 and x[1][0] <= 16, all_configs)
|
||||
)
|
||||
all_configs = list(
|
||||
filter(
|
||||
lambda x: x[0][1] % x[1][1] == 0 or x[1][1] % x[0][1] == 0,
|
||||
all_configs,
|
||||
)
|
||||
)
|
||||
for config_id, (from_pp_mp_dp, to_pp_mp_dp) in tqdm.tqdm(
|
||||
enumerate(all_configs)
|
||||
):
|
||||
world_size = max(a, b)
|
||||
|
||||
from_topo = topology.PipeDataModelParallelTopology(
|
||||
*from_pp_mp_dp, False, False
|
||||
)
|
||||
to_topo = topology.PipeDataModelParallelTopology(*to_pp_mp_dp, False, False)
|
||||
assert world_size >= from_topo.world_size()
|
||||
assert world_size >= to_topo.world_size()
|
||||
|
||||
from realhf.search_engine.utils import load_model_config
|
||||
|
||||
mconfig = load_model_config(model_family._class, model_path)
|
||||
|
||||
cost = compute_cost(
|
||||
world_size,
|
||||
from_model_name,
|
||||
to_model_name,
|
||||
from_topo,
|
||||
to_topo,
|
||||
mconfig,
|
||||
bw=200.0,
|
||||
set_interval_cost=0.03,
|
||||
)
|
||||
res[
|
||||
hash_tuple_into_str((model_family.size, *from_pp_mp_dp, *to_pp_mp_dp))
|
||||
] = int(cost * 1000 * 1000)
|
||||
print(
|
||||
f"Time for model size {model_family.size} {a} -> {b} {rank}/{parallel}: "
|
||||
f"{time.perf_counter() - mtik:.4f}, num res entries {len(res)}"
|
||||
)
|
||||
|
||||
print(f"Rank {rank} of model {model_family} finished, res size {len(res)}.")
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
dump_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc")
|
||||
fn = f"prtc_{model_family}_n{n_nodes}_{rank}_{parallel}.pkl"
|
||||
if not os.path.exists(dump_path):
|
||||
os.makedirs(dump_path, exist_ok=True)
|
||||
with open(os.path.join(dump_path, fn), "wb") as f:
|
||||
pickle.dump(res, f)
|
||||
print(f"dumped table with {len(res)} entries to {model_family}-{rank}-{parallel}.")
|
||||
|
||||
|
||||
def dump_table_parallel(
|
||||
n_nodes: int,
|
||||
model_family_to_path: Dict[ModelFamily, str],
|
||||
parallel: int = 4,
|
||||
):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
rq = mp.Queue()
|
||||
ps = []
|
||||
for model_family, model_path in model_family_to_path.items():
|
||||
for rank in range(parallel):
|
||||
ps.append(
|
||||
mp.Process(
|
||||
target=dump_table,
|
||||
args=(n_nodes, model_family, model_path, rank, parallel),
|
||||
)
|
||||
)
|
||||
|
||||
for p in ps:
|
||||
p.start()
|
||||
|
||||
for p in ps:
|
||||
p.join()
|
||||
|
||||
|
||||
def merge_tables(
|
||||
n_nodes: int,
|
||||
model_family_to_path: Dict[ModelFamily, str],
|
||||
parallel: int = 4,
|
||||
):
|
||||
import os
|
||||
import pickle
|
||||
|
||||
res_path = os.path.join(constants.PROFILER_CACHE_PATH, "param_realloc")
|
||||
|
||||
for model_family in model_family_to_path.keys():
|
||||
prefix = f"prtc_{model_family}_n{n_nodes}"
|
||||
|
||||
r = {}
|
||||
counter = 0
|
||||
for fn in os.listdir(res_path):
|
||||
if fn.endswith(".pkl") and fn.startswith(prefix):
|
||||
counter += 1
|
||||
path = os.path.join(res_path, fn)
|
||||
with open(path, "rb") as f:
|
||||
r.update(pickle.load(f))
|
||||
os.remove(path)
|
||||
if counter < parallel:
|
||||
raise RuntimeError(
|
||||
"missing sub-tables, probably some sub-processes failed "
|
||||
"during param realloc time cost estimation."
|
||||
)
|
||||
with open(os.path.join(res_path, f"{prefix}.pkl"), "wb") as f:
|
||||
pickle.dump(r, f)
|
||||
print(f"merged parallel tables into {prefix}.pkl, total entries {len(r)}")
|
||||
|
||||
|
||||
def estimate_param_realloc_time_cost(
|
||||
n_nodes: int,
|
||||
model_family_to_path: Dict[ModelFamily, str],
|
||||
parallel: int = 4,
|
||||
):
|
||||
dump_table_parallel(n_nodes, model_family_to_path, parallel)
|
||||
merge_tables(n_nodes, model_family_to_path, parallel)
|
|
@ -1,232 +0,0 @@
|
|||
# Copyright 2025 Ant Group Inc.
|
||||
# Copyright 2024 Wei Fu & Zhiyu Mei
|
||||
# Licensed under the Apache License, Version 2.0 (the "License").
|
||||
|
||||
import argparse
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import pprint
|
||||
import re
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import realhf._C.mdm_search as mdm_search
|
||||
except ModuleNotFoundError:
|
||||
mdm_search = None
|
||||
|
||||
import realhf.base.constants as constants
|
||||
from realhf.api.cli_args import ModelTrainEvalConfig, ParallelismConfig
|
||||
from realhf.api.core.config import ModelInterfaceType
|
||||
from realhf.api.core.dfg import MFCDef
|
||||
from realhf.api.quickstart.device_mesh import DeviceMesh, RPCAllocation
|
||||
from realhf.api.quickstart.search import RPCExecution
|
||||
|
||||
|
||||
def search_rpc_allocations(
|
||||
device_mesh: DeviceMesh,
|
||||
rpcs: List[MFCDef],
|
||||
num_gen_tokens: int = 256,
|
||||
n_ppo_minibatches: int = 1,
|
||||
seq_len: int = 256,
|
||||
gradient_checkpointing: bool = True,
|
||||
use_cache: bool = False,
|
||||
) -> List[RPCAllocation]:
|
||||
from realhf.search_engine.enumerate import build_graph
|
||||
from realhf.search_engine.estimate import get_param_realloc_stats
|
||||
|
||||
from_file = os.environ.get("REAL_IS_REMOTE", "0") == "1"
|
||||
dump_dir = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"device_mapping.pkl",
|
||||
)
|
||||
log_dir = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"device_mapping",
|
||||
)
|
||||
rs_dir = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"raw_search_result",
|
||||
)
|
||||
rpc_exe_dir = os.path.join(
|
||||
constants.LOG_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"rpc_exe_info",
|
||||
)
|
||||
|
||||
if from_file or (use_cache and os.path.exists(dump_dir)):
|
||||
with open(dump_dir, "r") as f:
|
||||
s = json.load(f)
|
||||
rpc_allocs = [RPCAllocation.from_dict(d) for d in s]
|
||||
return rpc_allocs
|
||||
else:
|
||||
os.makedirs(os.path.dirname(dump_dir), exist_ok=True)
|
||||
|
||||
n_nodes = device_mesh.n_nodes
|
||||
table = {}
|
||||
for rpc in rpcs:
|
||||
print(f"Getting param realloc stats for {rpc.model_type} at {rpc.model_path}")
|
||||
t = get_param_realloc_stats(rpc.model_type, rpc.model_path, n_nodes, True)
|
||||
table.update(t)
|
||||
|
||||
rpc_exe_list = make_rpc_exe_list(
|
||||
rpcs,
|
||||
device_mesh,
|
||||
num_gen_tokens=num_gen_tokens,
|
||||
n_ppo_minibatches=n_ppo_minibatches,
|
||||
seq_len=seq_len,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
log_dir=rpc_exe_dir,
|
||||
if_print=False,
|
||||
)
|
||||
graph = build_graph(rpcs, 5, 1, if_print=False)
|
||||
model_size_dict = make_model_size_dict(rpcs, if_print=False)
|
||||
|
||||
n_nodes = device_mesh.n_nodes
|
||||
search_time = 120
|
||||
|
||||
rs: List[Dict[str, List]] = mdm_search.multi_mcmc_search(
|
||||
rpcs,
|
||||
rpc_exe_list,
|
||||
graph,
|
||||
table,
|
||||
model_size_dict,
|
||||
0.001, # beta min
|
||||
0.002, # beta max
|
||||
0.001, # beta step
|
||||
search_time, # time limit for each search
|
||||
1, # repeat
|
||||
)
|
||||
if not from_file:
|
||||
with open(rs_dir, "w") as f:
|
||||
|
||||
pprint.pprint(rs, stream=f)
|
||||
|
||||
r: Dict[str, Dict[str, Any]] = rs[-1]
|
||||
pprint.pprint(r)
|
||||
|
||||
rpc_name_to_rpcs = {rpc.name: rpc for rpc in rpcs}
|
||||
rpc_allocs = []
|
||||
for rpc_name, alloc_info in r.items():
|
||||
if rpc_name in ["end_time", "mem_cost"]:
|
||||
continue
|
||||
# rpc = rpc_dict[rpc_name]
|
||||
rpc = rpc_name_to_rpcs[rpc_name]
|
||||
parallel = ParallelismConfig(
|
||||
pipeline_parallel_size=alloc_info["num_pp"],
|
||||
data_parallel_size=alloc_info["num_dp"],
|
||||
model_parallel_size=alloc_info["num_mp"],
|
||||
use_sequence_parallel=(
|
||||
alloc_info["num_mp"] > 1
|
||||
and rpc.interface_type == ModelInterfaceType.TRAIN_STEP
|
||||
),
|
||||
)
|
||||
sub_device_mesh = DeviceMesh(
|
||||
n_nodes=device_mesh.n_nodes,
|
||||
n_gpus_per_node=device_mesh.n_gpus_per_node,
|
||||
mapping=alloc_info["device_mesh_mapping"],
|
||||
name=alloc_info["device_mesh_name"],
|
||||
global_mesh_name=device_mesh.global_mesh_name,
|
||||
)
|
||||
rpc_alloc = RPCAllocation(
|
||||
rpc=rpc,
|
||||
device_mesh=sub_device_mesh,
|
||||
parallel=parallel,
|
||||
)
|
||||
rpc_allocs.append(rpc_alloc)
|
||||
|
||||
if not from_file:
|
||||
with open(dump_dir, "w") as f:
|
||||
json.dump([rpc_alloc.to_dict() for rpc_alloc in rpc_allocs], f, indent=4)
|
||||
with open(log_dir, "w") as f:
|
||||
|
||||
pprint.pprint(rpc_allocs, stream=f)
|
||||
|
||||
return rpc_allocs
|
||||
|
||||
|
||||
def make_rpc_exe_list(
|
||||
rpcs: List[MFCDef],
|
||||
device_mesh: DeviceMesh,
|
||||
num_gen_tokens: int,
|
||||
n_ppo_minibatches: int,
|
||||
seq_len: int,
|
||||
gradient_checkpointing: bool,
|
||||
if_print: bool = False,
|
||||
log_dir: Optional[str] = None,
|
||||
) -> List[RPCExecution]:
|
||||
from realhf.search_engine.enumerate import enumerate_rpc_executions
|
||||
|
||||
rpc_exe_list = []
|
||||
log_flag = False
|
||||
for rpc in rpcs:
|
||||
# real_model_config = load_model_config(rpc)
|
||||
feasible = enumerate_rpc_executions(
|
||||
rpc,
|
||||
device_mesh,
|
||||
seq_len=seq_len,
|
||||
num_gen_tokens=num_gen_tokens,
|
||||
n_ppo_minibatches=n_ppo_minibatches,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
rpc_exe_list.extend(feasible)
|
||||
|
||||
if log_dir is not None:
|
||||
mode = "w" if not log_flag else "a"
|
||||
with open(log_dir, mode) as f:
|
||||
f.write(f"{rpc.name} feasible: {len(feasible)}\n")
|
||||
feasible.sort(key=lambda x: x.time_cost)
|
||||
# feasible = feasible[:30]
|
||||
for i, rpc_exe in enumerate(feasible):
|
||||
f.write(
|
||||
f"{i}: time_cost: {rpc_exe.time_cost} ms, {rpc_exe.time_cost} "
|
||||
f"sub_device_mesh: {rpc_exe.device_mesh}, "
|
||||
f"parallel_strategy: {rpc_exe.parallel_strategy}, "
|
||||
f"mem_cost: {rpc_exe.mem/(1024*1024*1024):02f} GB, "
|
||||
f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):02f} GB\n"
|
||||
)
|
||||
f.write("\n")
|
||||
log_flag = True
|
||||
|
||||
if if_print:
|
||||
print(f"{rpc.name} feasible: {len(feasible)}")
|
||||
feasible.sort(key=lambda x: x.time_cost)
|
||||
# feasible = feasible[:10]
|
||||
for i, rpc_exe in enumerate(feasible):
|
||||
print(
|
||||
f"{i}: time_cost: {rpc_exe.time_cost} ms, "
|
||||
f"sub_device_mesh: {rpc_exe.device_mesh}, "
|
||||
f"parallel_strategy: {rpc_exe.parallel_strategy}, "
|
||||
f"mem_cost: {rpc_exe.mem/(1024*1024*1024):02f} GB, "
|
||||
f"static_mem_cost: {rpc_exe.static_mem/(1024*1024*1024):02f} GB"
|
||||
)
|
||||
|
||||
return rpc_exe_list
|
||||
|
||||
|
||||
def make_model_size_dict(rpcs: List[MFCDef], if_print: bool = False) -> Dict[str, int]:
|
||||
model_size_dict = {}
|
||||
|
||||
for rpc in rpcs:
|
||||
if rpc.model_name.role in model_size_dict:
|
||||
continue
|
||||
# model_configs = load_model_config(rpc)
|
||||
# model_size_dict[rpc.model_name.role] = estimate_model_size(real_model_config)
|
||||
model_size_dict[rpc.model_name.role] = rpc.model_type.size
|
||||
|
||||
if if_print:
|
||||
print(
|
||||
f"model_name: {rpc.model_name.role}, "
|
||||
f"model_size: {rpc.model_type.size}"
|
||||
)
|
||||
return model_size_dict
|
|
@ -1,28 +0,0 @@
# Copyright 2025 Ant Group Inc.
# Copyright 2024 Wei Fu & Zhiyu Mei
# Licensed under the Apache License, Version 2.0 (the "License").

from realhf.api.core.model_api import ReaLModelConfig


def find_factors(n):
    factors = []
    for i in range(1, n + 1):
        if n % i == 0:
            factors.append(i)
    return factors


def make_stats_key(rpc_name, bs, seq_len):
    return f"{rpc_name}|{bs}|{seq_len}"


def parse_stats_key(key):
    rpc_name, bs, seq_len = key.split("|")
    return rpc_name, int(bs), int(seq_len)


def load_model_config(model_class: str, model_path: str) -> ReaLModelConfig:
    from realhf.impl.model.nn.real_llm_api import ReaLModel

    return getattr(ReaLModel, f"config_from_{model_class}")(model_path=model_path)
|
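# Editor's note: a hedged usage sketch (not part of the original diff) of the stats-key
# helpers removed above; the RPC name is hypothetical.
key = make_stats_key("actor_train", 128, 1024)          # -> "actor_train|128|1024"
assert parse_stats_key(key) == ("actor_train", 128, 1024)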
@ -478,6 +478,8 @@ def run_ray_worker(

    # NOTE: Importing these will initialize DeepSpeed/CUDA devices.
    # profiler.import_profiler_registers()
    if worker_type != "master_worker":
        # For master_worker, there could be errors while importing and it is not necessary.
        import realhf.impl.dataset
        import realhf.impl.model
        import realhf.system

@ -24,6 +24,17 @@ SCATTER_GROUPS = {}
logger = logging.getLogger("data_manager", "system")


def find_minimal_superset(A: List[Set[int]], B: Set[int]) -> Set[int] | None:
    min_size = float("inf")
    result = None
    for S in A:
        if B.issubset(S):
            if len(S) < min_size:
                min_size = len(S)
                result = S
    return result


class DataManager:

    def __init__(
|
||||
|
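# Editor's note: a hedged usage sketch (not part of the original diff) of
# find_minimal_superset above: among the existing gather groups, pick the smallest one
# that covers all required source ranks. The rank sets are hypothetical.
groups = [{0, 1, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7}, {2, 3}]
assert find_minimal_superset(groups, {2, 3}) == {2, 3}
assert find_minimal_superset(groups, {1, 3}) == {0, 1, 2, 3}
assert find_minimal_superset(groups, {8, 9}) is None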
@ -52,7 +63,7 @@ class DataManager:
|
|||
|
||||
mw_ranks: Dict[ModelName, List[int]] = {}
|
||||
|
||||
# Stores the dp_head (i.e., mp_rank=0, pp_rank=-1) ranks given a model_name.
|
||||
# Stores the dp_head (i.e., tp_rank=0, pp_rank=-1) ranks given a model_name.
|
||||
mw_dp_head_ranks: Dict[ModelName, List[int]] = defaultdict(list)
|
||||
|
||||
assert msid2mwid is not None
|
||||
|
@ -67,7 +78,7 @@ class DataManager:
|
|||
topo,
|
||||
msid2mwid,
|
||||
pipe=topo.get_dim("pipe") - 1,
|
||||
model=0,
|
||||
tensor=0,
|
||||
)
|
||||
dp_size = topo.get_dim("data")
|
||||
for dp_i in range(dp_size):
|
||||
|
@ -87,7 +98,8 @@ class DataManager:
|
|||
list(ranks), backend="nccl" if constants.use_cuda() else "gloo"
|
||||
)
|
||||
|
||||
scatter_ranks = tuple(sorted(set([ranks[0]] + mw_ranks[dst])))
|
||||
for rank in ranks:
|
||||
scatter_ranks = tuple(sorted(set([rank] + mw_ranks[dst])))
|
||||
SCATTER_GROUPS[scatter_ranks] = new_or_get_group(
|
||||
list(scatter_ranks),
|
||||
backend="nccl" if constants.use_cuda() else "gloo",
|
||||
|
@ -228,7 +240,20 @@ class DataManager:
|
|||
def _run_gather(
|
||||
self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
|
||||
):
|
||||
if dist.get_rank() not in step.srcs:
|
||||
# It's possible that some DP rank is not involved.
|
||||
# Create dummy data to make the gather happy.
|
||||
gather_ranks = find_minimal_superset(
|
||||
[set(k) for k in GATHER_GROUPS.keys()], set(step.srcs)
|
||||
)
|
||||
assert gather_ranks is not None, (
|
||||
set(step.srcs),
|
||||
[set(k) for k in GATHER_GROUPS.keys()],
|
||||
)
|
||||
gather_ranks = sorted(list(gather_ranks))
|
||||
|
||||
pgroup = GATHER_GROUPS[tuple(gather_ranks)]
|
||||
|
||||
if dist.get_rank() not in gather_ranks:
|
||||
return
|
||||
|
||||
maxlen = 0
|
||||
|
@ -249,11 +274,13 @@ class DataManager:
|
|||
torch.empty(
|
||||
maxlen, device=constants.current_device(), dtype=torch.float32
|
||||
)
|
||||
for _ in range(len(step.srcs))
|
||||
for _ in gather_ranks
|
||||
]
|
||||
is_valid_gather = [i in step.srcs for i in gather_ranks]
|
||||
else:
|
||||
gather_list = None
|
||||
|
||||
if dist.get_rank() in step.srcs:
|
||||
local_gather_idx = step.srcs.index(dist.get_rank())
|
||||
ids = step.ids[local_gather_idx]
|
||||
for i in ids:
|
||||
|
@ -267,18 +294,27 @@ class DataManager:
|
|||
]
|
||||
)
|
||||
data = self._pad_data(data, maxlen)
|
||||
else:
|
||||
data = torch.empty(
|
||||
maxlen, device=constants.current_device(), dtype=torch.float32
|
||||
)
|
||||
|
||||
dist.gather(
|
||||
data,
|
||||
gather_list,
|
||||
dst=step.root,
|
||||
group=GATHER_GROUPS[tuple(sorted(step.srcs))],
|
||||
group=pgroup,
|
||||
)
|
||||
|
||||
if dist.get_rank() != step.root:
|
||||
del data
|
||||
return
|
||||
|
||||
for ids, buf in zip(step.ids, gather_list):
|
||||
cnt = 0
|
||||
for is_valid, buf in zip(is_valid_gather, gather_list):
|
||||
if not is_valid:
|
||||
continue
|
||||
ids = step.ids[cnt]
|
||||
offset = 0
|
||||
for i in ids:
|
||||
for key in step.keys:
|
||||
|
@ -302,6 +338,9 @@ class DataManager:
|
|||
self.storage[i].update_(s)
|
||||
else:
|
||||
self.storage[i] = s
|
||||
cnt += 1
|
||||
assert cnt == len(step.srcs) == len(step.ids)
|
||||
del data
|
||||
|
||||
def _run_scatter(
|
||||
self, step: RedistribStep, data_infos: Dict[Hashable, SequenceSample]
|
||||
|
|
|
@ -27,7 +27,7 @@ class FunctionExecutor:
|
|||
rpcs: List[MFCDef],
|
||||
msid2mwid: Dict[ModelShardID, int],
|
||||
stream: NameResolvingRequestClient,
|
||||
buffer: AsyncIOSequenceBuffer,
|
||||
buffers: List[AsyncIOSequenceBuffer],
|
||||
model_topos: Dict[str, ProcessTopology],
|
||||
model_configs: Dict[str, None | ReaLModelConfig],
|
||||
ctrl: RPCCorountineControl,
|
||||
|
@ -58,14 +58,15 @@ class FunctionExecutor:
|
|||
model_topos=model_topos,
|
||||
model_configs=model_configs,
|
||||
ctrl=ctrl,
|
||||
buffer=buffer,
|
||||
buffers=buffers,
|
||||
redistrib_planner=self.redistrib_planner,
|
||||
summary_writer=summary_writer,
|
||||
)
|
||||
self.func_calls[rpc.name] = func_call
|
||||
|
||||
self.stream = stream
|
||||
self.buffer = buffer
|
||||
self.buffers = buffers
|
||||
self.buffer_id = 0
|
||||
|
||||
self.data_loading_dp_idx = -1
|
||||
self.shuffle_dataset = shuffle_dataset
|
||||
|
@ -111,18 +112,17 @@ class FunctionExecutor:
|
|||
|
||||
self.ctrl.ids_to_clear.clear()
|
||||
|
||||
async def load_data(self):
|
||||
buffer = self.buffer
|
||||
async def load_data(self, buffer_id: int):
|
||||
buffer = self.buffers[buffer_id]
|
||||
ctrl = self.ctrl
|
||||
|
||||
received_ids = set()
|
||||
|
||||
while self.buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
|
||||
|
||||
while buffer.size < max(rpc.n_seqs for rpc in self.rpcs):
|
||||
resps = await self.stream.call_async(
|
||||
handlers=[f"__data{dp_idx}__" for dp_idx in range(self.src_dp_size)],
|
||||
handle_type="fetch",
|
||||
datas=[None for _ in range(self.src_dp_size)],
|
||||
datas=[buffer_id for _ in range(self.src_dp_size)],
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
@ -182,10 +182,13 @@ class FunctionExecutor:
|
|||
logger.info("Waiting for the finish of the execution graph.")
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
tasks = [loop.create_task(fc.run()) for fc in self.func_calls.values()] + [
|
||||
tasks = [
|
||||
loop.create_task(fc.run(self.buffer_id)) for fc in self.func_calls.values()
|
||||
] + [
|
||||
loop.create_task(self.flush_calls()),
|
||||
loop.create_task(self.load_data()),
|
||||
loop.create_task(self.load_data(self.buffer_id)),
|
||||
loop.create_task(self.finish_traverse()),
|
||||
]
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*tasks))
|
||||
self.buffer_id = (self.buffer_id + 1) % len(self.buffers)
|
||||
|
|
|
@ -1,14 +1,111 @@
|
|||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
|
||||
from realhf.api.cli_args import SGLangConfig
|
||||
from realhf.api.core.system_api import GenerationServer as GenerationServerConfig
|
||||
from realhf.base import gpu_utils, logging, name_resolve, names, network, seeding
|
||||
from realhf.base import (
|
||||
gpu_utils,
|
||||
logging,
|
||||
name_resolve,
|
||||
names,
|
||||
network,
|
||||
pkg_version,
|
||||
seeding,
|
||||
)
|
||||
from realhf.base.cluster import spec as cluster_spec
|
||||
from realhf.system.worker_base import PollResult, Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def execute_shell_command(command: str) -> subprocess.Popen:
|
||||
"""
|
||||
Execute a shell command and return its process handle.
|
||||
"""
|
||||
# Replace newline continuations and split the command string.
|
||||
command = command.replace("\\\n", " ").replace("\\", " ")
|
||||
parts = command.split()
|
||||
return subprocess.Popen(parts, text=True, stderr=subprocess.STDOUT)
|
||||
|
||||
|
||||
def launch_server_cmd(command: str, port: int = 30000):
|
||||
"""
|
||||
Launch the server using the given command.
|
||||
If no port is specified, a free port is reserved.
|
||||
"""
|
||||
p = Path(os.path.dirname(__file__))
|
||||
patch_path = str(
|
||||
p.parent.parent
|
||||
/ "patch"
|
||||
/ "sglang"
|
||||
/ f"v{pkg_version.get_version('sglang')}.patch"
|
||||
)
|
||||
|
||||
target_path = ""
|
||||
sglang_meta = subprocess.check_output(
|
||||
"python3 -m pip show sglang", shell=True
|
||||
).decode("ascii")
|
||||
for line in sglang_meta.split("\n"):
|
||||
line = line.strip()
|
||||
if line.startswith("Editable project location: "):
|
||||
target_path = str(Path(line.split(": ")[1]).parent)
|
||||
|
||||
if target_path:
|
||||
proc = subprocess.Popen(
|
||||
["git", "apply", patch_path],
|
||||
cwd=target_path,
|
||||
stderr=sys.stdout,
|
||||
stdout=sys.stdout,
|
||||
)
|
||||
proc.wait()
|
||||
logger.info(f"Applied SGLang patch at {target_path}")
|
||||
assert port is not None
|
||||
full_command = f"{command} --port {port}"
|
||||
process = execute_shell_command(full_command)
|
||||
return process, port
|
||||
|
||||
|
||||
def terminate_process(process, port=None):
|
||||
"""
|
||||
Terminate the process and, if a port was reserved, release it.
|
||||
"""
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
|
||||
kill_process_tree(process.pid)
|
||||
|
||||
|
||||
def wait_for_server(base_url: str, timeout: int = None) -> None:
|
||||
"""Wait for the server to be ready by polling the /v1/models endpoint.
|
||||
|
||||
Args:
|
||||
base_url: The base URL of the server
|
||||
timeout: Maximum time to wait in seconds. None means wait forever.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{base_url}/v1/models",
|
||||
headers={"Authorization": "Bearer None"},
|
||||
)
|
||||
if response.status_code == 200:
|
||||
time.sleep(5)
|
||||
break
|
||||
|
||||
if timeout and time.time() - start_time > timeout:
|
||||
raise TimeoutError("Server did not become ready within timeout period")
|
||||
except requests.exceptions.RequestException:
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
PORT_CLEARANCE_PERIOD = 90
|
||||
|
||||
|
||||
class GenerationServer(Worker):
|
||||
def _configure(self, config: GenerationServerConfig):
|
||||
self.config = config
|
||||
|
@ -36,20 +133,37 @@ class GenerationServer(Worker):
|
|||
config = self.config
|
||||
|
||||
assert config.backend_type == "sglang"
|
||||
|
||||
host_ip = network.gethostip()
|
||||
host = "localhost" if not config.backend_args.enable_metrics else host_ip
|
||||
|
||||
# NOTE: Ports returned by `find_multiple_free_ports` are unique,
|
||||
# but SGLang servers still encounter conflicts.
|
||||
# Use a clearance period to work around this issue.
|
||||
servers_per_node = cluster_spec.n_gpus_per_node // self.config.tp_size
|
||||
idx_on_this_node = self.worker_index % servers_per_node
|
||||
time.sleep(idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node)
|
||||
|
||||
ports = network.find_multiple_free_ports(
|
||||
2,
|
||||
low=10000,
|
||||
high=60000,
|
||||
experiment_name=self.experiment_name,
|
||||
trial_name=self.trial_name,
|
||||
)
|
||||
server_port = ports[0]
|
||||
nccl_port = ports[1]
|
||||
|
||||
cmd = SGLangConfig.build_cmd(
|
||||
config.backend_args,
|
||||
config.model_path,
|
||||
tp_size=config.tp_size,
|
||||
server_index=self.worker_index,
|
||||
base_gpu_id=self.base_gpu_id,
|
||||
dist_init_addr=f"{host}:{nccl_port}",
|
||||
)
|
||||
from sglang.utils import launch_server_cmd, wait_for_server
|
||||
|
||||
host_ip = network.gethostip()
|
||||
host = "localhost" if not config.backend_args.enable_metrics else host_ip
|
||||
|
||||
# TODO: handle launching error and retry
|
||||
self.server_process, self.server_port = launch_server_cmd(cmd)
|
||||
self.server_process, self.server_port = launch_server_cmd(cmd, port=server_port)
|
||||
self.server_addr = f"http://{host}:{self.server_port}"
|
||||
|
||||
wait_for_server(self.server_addr)
|
||||
|
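The NOTE in the hunk above staggers SGLang server launches on one node: even though `find_multiple_free_ports` returns unique ports, concurrently starting servers can still collide, so each server sleeps for a slice of `PORT_CLEARANCE_PERIOD` before reserving ports. A minimal sketch of the stagger arithmetic (the example values are illustrative):

PORT_CLEARANCE_PERIOD = 90  # seconds, as defined earlier in this file


def startup_delay(worker_index: int, n_gpus_per_node: int, tp_size: int) -> float:
    """Delay before this generation server reserves ports and launches."""
    servers_per_node = n_gpus_per_node // tp_size
    idx_on_this_node = worker_index % servers_per_node
    # Servers on one node are spread evenly across the clearance window,
    # so no two of them bind ports at the same moment.
    return idx_on_this_node * PORT_CLEARANCE_PERIOD / servers_per_node


# Example: 8 GPUs per node with tp_size=2 -> 4 servers per node,
# delays of 0.0s, 22.5s, 45.0s and 67.5s for workers 0..3.
print([startup_delay(i, 8, 2) for i in range(4)])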
@ -80,6 +194,5 @@ class GenerationServer(Worker):
|
|||
|
||||
def _exit_hook(self, exit_status):
|
||||
if self.server_process is not None and self.config.backend_type == "sglang":
|
||||
from sglang.utils import terminate_process
|
||||
|
||||
terminate_process(self.server_process)
|
||||
|
|
|
@ -6,20 +6,46 @@ import shutil
|
|||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
|
||||
from realhf.api.core.model_api import GenReqMeta, GenRespMeta, ModelVersionReq
|
||||
from realhf.api.core.system_api import GserverManager as GserverManagerConfig
|
||||
from realhf.base import constants, logging, name_resolve, names, network, recover
|
||||
from realhf.system.worker_base import AsyncWorker, PollResult, Worker
|
||||
from realhf.system.worker_base import PollResult, Worker
|
||||
|
||||
logger = logging.getLogger("Generation Manager", "colored")
|
||||
logger = logging.getLogger("Generation Manager", "system")
|
||||
|
||||
STALENESS_WARNED = defaultdict(lambda: False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutStat:
|
||||
submit: int = 0
|
||||
accepted: int = 0
|
||||
running: int = 0
|
||||
|
||||
def inc(self):
|
||||
self.submit += 1
|
||||
self.accepted += 1
|
||||
self.running += 1
|
||||
|
||||
def accept(self):
|
||||
self.running -= 1
|
||||
|
||||
def reject(self):
|
||||
self.running -= 1
|
||||
self.accepted -= 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class AllocateRolloutInput:
|
||||
qid: str
|
||||
|
||||
|
||||
class GserverManager(Worker):
|
||||
"""This worker has the following functionalities:
|
||||
1. As a router, it schedules generation requests and returns the
|
||||
|
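The `RolloutStat` dataclass above is the single source of truth for rollout accounting: `inc()` when a rollout is admitted, `accept()` when it finishes and is kept, `reject()` when it finishes but is dropped. A short usage sketch (the capacity value is illustrative, standing in for `config.max_concurrent_rollouts`):

from dataclasses import dataclass


@dataclass
class RolloutStat:  # same fields and methods as in the hunk above
    submit: int = 0
    accepted: int = 0
    running: int = 0

    def inc(self):
        self.submit += 1
        self.accepted += 1
        self.running += 1

    def accept(self):
        self.running -= 1

    def reject(self):
        self.running -= 1
        self.accepted -= 1


stat = RolloutStat()
MAX_CONCURRENT = 2  # illustrative capacity


def try_allocate() -> bool:
    # Mirrors the capacity check performed by /allocate_rollout below.
    if stat.running < MAX_CONCURRENT:
        stat.inc()
        return True
    return False


assert try_allocate() and try_allocate()
assert not try_allocate()  # at capacity
stat.accept()              # one rollout finishes and is accepted
assert try_allocate()      # a slot is free again
stat.reject()              # another finishes but is rejected
print(stat.submit, stat.accepted, stat.running)  # 3 2 1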
@ -37,23 +63,30 @@ class GserverManager(Worker):
|
|||
|
||||
assert self.config.worker_info.worker_count == 1
|
||||
|
||||
self.async_lock = asyncio.Lock()
|
||||
self.threading_lock = threading.Lock()
|
||||
self.n_total_rollouts = 0
|
||||
self.n_running_rollouts = 0
|
||||
self.accepted_rollouts = 0
|
||||
self.rollout_stat = RolloutStat()
|
||||
|
||||
self.schedule_policy = config.schedule_policy
|
||||
|
||||
self._last_param_realloc_step = 0
|
||||
|
||||
self._qid_to_server_url = {}
|
||||
|
||||
self._server_token_usage = defaultdict(float)
|
||||
self._server_request_counts = defaultdict(int)
|
||||
|
||||
self._last_thpt_output_time = time.time()
|
||||
self._gen_tokens = 0
|
||||
|
||||
self.experiment_name = config.worker_info.experiment_name
|
||||
self.trial_name = config.worker_info.trial_name
|
||||
|
||||
# manager server
|
||||
self.server = None
|
||||
self.manager_http_server = None
|
||||
self.thread = None
|
||||
|
||||
self.server_urls = []
|
||||
|
||||
# recover info
|
||||
self.__recover_run, self.__recover_info = recover.load_recover_info()
|
||||
if self.__recover_run:
|
||||
|
@ -67,10 +100,12 @@ class GserverManager(Worker):
|
|||
name_resolve.add(name, self.__recover_info.last_step_info.global_step)
|
||||
|
||||
self._loaded_recover_weights = False
|
||||
self.n_total_rollouts = self.accepted_rollouts = (
|
||||
hist_rollouts = (
|
||||
self.config.train_batch_size
|
||||
* self.__recover_info.last_step_info.global_step
|
||||
)
|
||||
self.rollout_stat.submit = hist_rollouts
|
||||
self.rollout_stat.accepted = hist_rollouts
|
||||
|
||||
return config.worker_info
|
||||
|
||||
|
@ -84,7 +119,7 @@ class GserverManager(Worker):
|
|||
if cnt >= timeout:
|
||||
raise TimeoutError("Waiting generation servers timeout.")
|
||||
urls = name_resolve.get_subtree(name)
|
||||
assert len(set(urls)) == len(urls), urls
|
||||
assert len(set(urls)) == len(urls), (len(urls), len(set(urls)), urls)
|
||||
return urls
|
||||
|
||||
def _get_recover_ckpt_path(self, role: str):
|
||||
|
@ -146,42 +181,27 @@ class GserverManager(Worker):
|
|||
async def flush_requests_and_update_weights(
|
||||
self, server_url, new_param_path, update_weights_retries=5
|
||||
):
|
||||
# HACK: urls are designed for SGLang
|
||||
server_index = self.server_urls.index(server_url)
|
||||
async with aiohttp.ClientSession(server_url) as session:
|
||||
running_requests = None
|
||||
tik = time.perf_counter()
|
||||
while running_requests is None or running_requests > 0:
|
||||
if time.perf_counter() - tik > self.config.flush_request_timeout:
|
||||
raise RuntimeError(
|
||||
f"Waiting for flush requests failed. {running_requests} requests "
|
||||
f"remain after {self.config.flush_request_timeout} secs waiting. "
|
||||
f"Please try to reduce `new_tokens_per_chunk`."
|
||||
)
|
||||
if running_requests is not None and running_requests > 0:
|
||||
logger.info(
|
||||
f"Waiting for {running_requests} requests on gen server {server_index}... "
|
||||
f"Time taken so far: {time.perf_counter() - tik:.4f}s"
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
async with session.get(f"/metrics") as resp:
|
||||
resp.raise_for_status()
|
||||
text = await resp.text()
|
||||
for line in text.split("\n"):
|
||||
if line.startswith("sglang:num_running_reqs"):
|
||||
running_requests = float(line.split(" ")[1])
|
||||
break
|
||||
|
||||
success = False
|
||||
for _ in range(update_weights_retries):
|
||||
async with aiohttp.ClientSession(
|
||||
server_url,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.config.flush_request_timeout, sock_connect=30
|
||||
),
|
||||
) as session:
|
||||
async with session.post(
|
||||
f"/update_weights_from_disk",
|
||||
json=dict(model_path=new_param_path),
|
||||
json=dict(model_path=new_param_path, allow_interrupt=True),
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
res = await resp.json()
|
||||
success = res["success"]
|
||||
if success:
|
||||
logger.info(
|
||||
f"{res['num_paused_requests']} requests are interrupted "
|
||||
f"during updateing weights for server {server_index}: {server_url}"
|
||||
)
|
||||
return
|
||||
logger.warning(
|
||||
f"Update weights failed: {res['message']}. Retrying."
|
||||
|
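The rewritten `flush_requests_and_update_weights` above posts `/update_weights_from_disk` with `allow_interrupt=True` and retries on failure, instead of first draining `sglang:num_running_reqs` via `/metrics`. A minimal sketch of that retry loop, with a stubbed `post_update` standing in for the aiohttp call (names and values are illustrative):

import asyncio


async def post_update(server_url: str, model_path: str) -> dict:
    # Stand-in for POST {server_url}/update_weights_from_disk with
    # json={"model_path": model_path, "allow_interrupt": True}.
    return {"success": True, "num_paused_requests": 3, "message": ""}


async def update_weights(server_url: str, model_path: str, retries: int = 5):
    for _ in range(retries):
        res = await post_update(server_url, model_path)
        if res["success"]:
            print(f"{res['num_paused_requests']} requests interrupted on {server_url}")
            return
        print(f"Update weights failed: {res['message']}. Retrying.")
    raise RuntimeError(f"Update weights failed after {retries} retries on {server_url}.")


asyncio.run(update_weights("http://gen-server:30000", "/path/to/new/params"))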
@ -198,6 +218,16 @@ class GserverManager(Worker):
|
|||
self.round_robin_idx %= self.config.n_servers
|
||||
return r
|
||||
|
||||
def _least_requests_schedule(self, req_meta: GenReqMeta) -> int:
|
||||
counts = [
|
||||
self._server_request_counts[server_url] for server_url in self.server_urls
|
||||
]
|
||||
return int(np.argmin(counts))
|
||||
|
||||
def _least_token_usage_schedule(self, req_meta: GenReqMeta) -> int:
|
||||
url = min(self.server_urls, key=lambda k: self._server_token_usage[k])
|
||||
return self.server_urls.index(url)
|
||||
|
||||
def _poll(self):
|
||||
if not self.thread:
|
||||
# Find addresses of generation servers
|
||||
|
@ -228,6 +258,23 @@ class GserverManager(Worker):
|
|||
loop.run_until_complete(asyncio.gather(*tasks))
|
||||
logger.info(f"Generaion server updated weights from: {new_param_path}")
|
||||
|
||||
tasks = [
|
||||
self._get_server_token_usage(server_url) for server_url in self.server_urls
|
||||
]
|
||||
loop = asyncio.get_event_loop()
|
||||
token_usages = loop.run_until_complete(asyncio.gather(*tasks))
|
||||
with self.threading_lock:
|
||||
for server_url, token_usage in zip(self.server_urls, token_usages):
|
||||
self._server_token_usage[server_url] = token_usage
|
||||
|
||||
if time.time() - self._last_thpt_output_time > 30:
|
||||
interval = time.time() - self._last_thpt_output_time
|
||||
logger.info(
|
||||
f"Generation throughput: {self._gen_tokens / interval:.2f} tokens/s"
|
||||
)
|
||||
self._last_thpt_output_time = time.time()
|
||||
self._gen_tokens = 0
|
||||
|
||||
# clear old weights
|
||||
realloc_root = os.path.join(
|
||||
constants.PARAM_REALLOC_PATH,
|
||||
|
@ -237,9 +284,11 @@ class GserverManager(Worker):
|
|||
)
|
||||
if os.path.exists(realloc_root):
|
||||
for realloc_version in os.listdir(realloc_root):
|
||||
# Lock-free is safe here.
|
||||
# Keep one checkpoint for recovery.
|
||||
if (
|
||||
os.path.isdir(os.path.join(realloc_root, realloc_version))
|
||||
and int(realloc_version) < self._last_param_realloc_step
|
||||
and int(realloc_version) < self._last_param_realloc_step - 1
|
||||
):
|
||||
shutil.rmtree(os.path.join(realloc_root, realloc_version))
|
||||
logger.info(
|
||||
|
@ -247,29 +296,56 @@ class GserverManager(Worker):
|
|||
f"checkpoint: {os.path.join(realloc_root, realloc_version)}"
|
||||
)
|
||||
|
||||
# TODO: we may want to update server status
|
||||
# in the main thread.
|
||||
|
||||
time.sleep(1)
|
||||
time.sleep(5)
|
||||
|
||||
return PollResult(0, 0)
|
||||
|
||||
async def is_staled(self):
|
||||
global_sample_cnt = self.n_total_rollouts
|
||||
expected_version = global_sample_cnt // self.config.train_batch_size
|
||||
staled = (
|
||||
expected_version
|
||||
> self.config.max_head_offpolicyness + self._last_param_realloc_step
|
||||
async def _get_server_token_usage(self, server_url):
|
||||
async with aiohttp.ClientSession(
|
||||
server_url,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.config.flush_request_timeout, sock_connect=30
|
||||
),
|
||||
) as session:
|
||||
async with session.get("/metrics") as resp:
|
||||
resp.raise_for_status()
|
||||
text = await resp.text()
|
||||
for l in text.split("\n"):
|
||||
if l.startswith("sglang:num_used_tokens"):
|
||||
return float(l.split(" ")[1])
|
||||
raise RuntimeError(f"Failed to get token usage metrics from {server_url}")
|
||||
|
||||
async def _get_server_num_running_requests(self, server_url):
|
||||
async with aiohttp.ClientSession(
|
||||
server_url,
|
||||
timeout=aiohttp.ClientTimeout(
|
||||
total=self.config.flush_request_timeout, sock_connect=30
|
||||
),
|
||||
) as session:
|
||||
async with session.get(f"/metrics") as resp:
|
||||
resp.raise_for_status()
|
||||
text = await resp.text()
|
||||
for line in text.split("\n"):
|
||||
if line.startswith("sglang:num_running_reqs"):
|
||||
return float(line.split(" ")[1])
|
||||
raise RuntimeError(
|
||||
f"Failed to get num running requests metrics from {server_url}"
|
||||
)
|
||||
|
||||
def is_staled(self):
|
||||
global_sample_cnt = self.rollout_stat.accepted
|
||||
expected_version = global_sample_cnt // self.config.train_batch_size
|
||||
version = self._last_param_realloc_step
|
||||
staled = expected_version > self.config.max_head_offpolicyness + version
|
||||
global STALENESS_WARNED
|
||||
if staled and not STALENESS_WARNED[self._last_param_realloc_step]:
|
||||
if staled and not STALENESS_WARNED[version]:
|
||||
logger.warning(
|
||||
f"expected version ({expected_version}) = "
|
||||
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
|
||||
f"current version {self._last_param_realloc_step}, "
|
||||
f"current latest version {version}, "
|
||||
f"offpolicyness {self.config.max_head_offpolicyness}. Staled? {staled}"
|
||||
)
|
||||
STALENESS_WARNED[self._last_param_realloc_step] = True
|
||||
STALENESS_WARNED[version] = True
|
||||
return staled
|
||||
|
||||
def _run_routing_service(self):
|
||||
|
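`is_staled` above gates new rollouts on off-policyness: with `accepted` rollouts so far, the trainer is expected to be at version `accepted // train_batch_size`, and rollouts are refused once that expected version runs ahead of the currently loaded weights by more than `max_head_offpolicyness`. A worked sketch of the check (the numbers are illustrative):

def is_staled(accepted_rollouts: int, train_batch_size: int,
              current_weight_version: int, max_head_offpolicyness: int) -> bool:
    expected_version = accepted_rollouts // train_batch_size
    return expected_version > max_head_offpolicyness + current_weight_version


# With batch size 256 and weights at version 3:
#   offpolicyness 0 stops rollouts once a 4th batch worth of samples is queued,
#   offpolicyness 1 allows one extra batch in flight.
print(is_staled(accepted_rollouts=1024, train_batch_size=256,
                current_weight_version=3, max_head_offpolicyness=0))  # True
print(is_staled(accepted_rollouts=1024, train_batch_size=256,
                current_weight_version=3, max_head_offpolicyness=1))  # False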
@ -282,46 +358,69 @@ class GserverManager(Worker):
|
|||
@self.app.post("/schedule_request")
|
||||
async def schedule_request(req_meta: GenReqMeta):
|
||||
with self.threading_lock:
|
||||
async with self.async_lock:
|
||||
version = self._last_param_realloc_step
|
||||
# FIXME: We only implement a round-robin scheduler that
|
||||
# ignores server status and request metadata
|
||||
if (
|
||||
req_meta.previous_server_url
|
||||
and req_meta.previous_version == self._last_param_realloc_step
|
||||
):
|
||||
return dict(
|
||||
url=req_meta.previous_server_url,
|
||||
version=req_meta.previous_version,
|
||||
)
|
||||
|
||||
if self.schedule_policy == "round_robin":
|
||||
server_idx = self._round_robin_schedule(req_meta)
|
||||
return dict(url=self.server_urls[server_idx], version=max(0, version))
|
||||
elif self.schedule_policy == "least_token_usage":
|
||||
server_idx = self._least_token_usage_schedule(req_meta)
|
||||
elif self.schedule_policy == "least_requests":
|
||||
server_idx = self._least_requests_schedule(req_meta)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Unknown schedule policy {self.schedule_policy}"
|
||||
)
|
||||
|
||||
server_url = self.server_urls[server_idx]
|
||||
# All samples of the same qid (prompt) are routed to the same destination server.
|
||||
self._qid_to_server_url[req_meta.qid] = server_url
|
||||
self._server_request_counts[server_url] += 1
|
||||
self._server_token_usage[server_url] += (
|
||||
req_meta.prompt_len
|
||||
+ req_meta.new_token_budget * req_meta.group_size * 0.4
|
||||
)
|
||||
|
||||
version = self._last_param_realloc_step
|
||||
return dict(url=server_url, version=version)
|
||||
|
||||
@self.app.post("/get_model_version")
|
||||
async def get_model_version(req: ModelVersionReq):
|
||||
with self.threading_lock:
|
||||
async with self.async_lock:
|
||||
# FIXME: we may have different versions for different servers
|
||||
version = self._last_param_realloc_step
|
||||
return dict(version=version)
|
||||
|
||||
@self.app.get("/allocate_rollout")
|
||||
async def allocate_rollout():
|
||||
@self.app.post("/allocate_rollout")
|
||||
async def allocate_rollout(req: AllocateRolloutInput):
|
||||
with self.threading_lock:
|
||||
async with self.async_lock:
|
||||
has_capacity = (
|
||||
self.n_running_rollouts < self.config.max_concurrent_rollouts
|
||||
self.rollout_stat.running < self.config.max_concurrent_rollouts
|
||||
)
|
||||
is_staled = await self.is_staled()
|
||||
is_staled = self.is_staled()
|
||||
reason = ""
|
||||
if has_capacity and not is_staled:
|
||||
self.n_running_rollouts += 1
|
||||
self.n_total_rollouts += 1
|
||||
self.rollout_stat.inc()
|
||||
return dict(success=True, reason=reason)
|
||||
else:
|
||||
if not has_capacity:
|
||||
reason += f"capacity: {self.n_running_rollouts} >= {self.config.max_concurrent_rollouts}"
|
||||
reason += f"capacity: {self.rollout_stat.running} >= {self.config.max_concurrent_rollouts}"
|
||||
if is_staled:
|
||||
global_sample_cnt = self.n_total_rollouts
|
||||
global_sample_cnt = self.rollout_stat.accepted
|
||||
expected_version = (
|
||||
global_sample_cnt // self.config.train_batch_size
|
||||
)
|
||||
version = self._last_param_realloc_step
|
||||
reason += (
|
||||
f" and staled: expected version ({expected_version}) = "
|
||||
f"global sample cnt ({global_sample_cnt}) // batch size ({self.config.train_batch_size}), "
|
||||
f"current version {self._last_param_realloc_step}, "
|
||||
f"current latest version {version}, "
|
||||
f"offpolicyness {self.config.max_head_offpolicyness}."
|
||||
)
|
||||
return dict(success=False, reason=reason)
|
||||
|
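`/schedule_request` above supports three policies: `round_robin` cycles the servers, `least_requests` picks the server with the fewest outstanding requests, and `least_token_usage` picks the server with the lowest estimated token usage, charging the chosen server `prompt_len + new_token_budget * group_size * 0.4` per assignment (the 0.4 factor presumably approximates how much of the budget a group actually consumes). A minimal sketch of that bookkeeping, with hypothetical server URLs:

from collections import defaultdict

server_urls = ["http://gen0:30000", "http://gen1:30000"]  # illustrative
token_usage = defaultdict(float)
request_counts = defaultdict(int)


def least_token_usage(prompt_len: int, new_token_budget: int, group_size: int) -> str:
    url = min(server_urls, key=lambda u: token_usage[u])
    # Optimistically charge the server for the prompt plus an estimated
    # fraction of the group's total generation budget.
    token_usage[url] += prompt_len + new_token_budget * group_size * 0.4
    request_counts[url] += 1
    return url


print(least_token_usage(prompt_len=512, new_token_budget=1024, group_size=8))
print(least_token_usage(prompt_len=512, new_token_budget=1024, group_size=8))
# The second request lands on the other server because the first assignment
# raised gen0's estimated usage above gen1's.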
@ -329,13 +428,24 @@ class GserverManager(Worker):
|
|||
@self.app.post("/finish_rollout")
|
||||
async def finish_rollout(resp_meta: GenRespMeta):
|
||||
with self.threading_lock:
|
||||
async with self.async_lock:
|
||||
self.n_running_rollouts -= 1
|
||||
server_url = self._qid_to_server_url[resp_meta.qid]
|
||||
self._server_request_counts[server_url] -= 1
|
||||
assert (
|
||||
self._server_request_counts[server_url] >= 0
|
||||
), "server request count < 0"
|
||||
self._qid_to_server_url.pop(resp_meta.qid)
|
||||
self._gen_tokens += resp_meta.n_tokens
|
||||
if resp_meta.accepted:
|
||||
self.accepted_rollouts += 1
|
||||
self.rollout_stat.accept()
|
||||
else:
|
||||
self.rollout_stat.reject()
|
||||
return dict(success=True)
|
||||
|
||||
self.manager_addr = f"{network.gethostip()}:{network.find_free_port()}"
|
||||
port = network.find_free_port(
|
||||
experiment_name=self.experiment_name,
|
||||
trial_name=self.trial_name,
|
||||
)
|
||||
self.manager_addr = f"{network.gethostip()}:{port}"
|
||||
|
||||
config = uvicorn.Config(
|
||||
self.app,
|
||||
|
@ -343,12 +453,12 @@ class GserverManager(Worker):
|
|||
port=int(self.manager_addr.split(":")[1]),
|
||||
log_level="warning",
|
||||
)
|
||||
self.server = uvicorn.Server(config)
|
||||
self.server.run()
|
||||
self.manager_http_server = uvicorn.Server(config)
|
||||
self.manager_http_server.run()
|
||||
|
||||
def _exit_hook(self, exit_status):
|
||||
if self.server:
|
||||
self.server.should_exit = True
|
||||
if self.manager_http_server:
|
||||
self.manager_http_server.should_exit = True
|
||||
if self.thread:
|
||||
self.thread.join(timeout=3)
|
||||
logger.info("Server stopped")
|
||||
|
|
|
@ -145,6 +145,7 @@ class MasterWorker(worker_base.Worker):
|
|||
# for benchmark
|
||||
self.e2e_time_history = []
|
||||
self.__benchmark_steps = config.exp_ctrl.benchmark_steps
|
||||
self.__benchmark_n_seqs = config.exp_ctrl.benchmark_n_seqs
|
||||
|
||||
return config.worker_info
|
||||
|
||||
|
@ -210,13 +211,14 @@ class MasterWorker(worker_base.Worker):
|
|||
src_rpc_dp_size = src_rpc_topo.get_dim("data")
|
||||
|
||||
# Request training specification from data workers.
|
||||
self._dataset_size = sum(
|
||||
self.__stream.call(
|
||||
specs = self.__stream.call(
|
||||
handlers=[f"__data{i}__" for i in range(src_rpc_dp_size)],
|
||||
datas=[None for i in range(src_rpc_dp_size)],
|
||||
handle_type="spec",
|
||||
),
|
||||
)
|
||||
assert all(x["n_datasets"] == specs[0]["n_datasets"] for x in specs), specs
|
||||
self._dataset_size = sum(x["dataset_size"] for x in specs)
|
||||
self._n_datasets = specs[0]["n_datasets"]
|
||||
|
||||
self._steps_per_epoch = self._dataset_size // src_rpc.n_seqs
|
||||
|
||||
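The master-side change above replaces a single dataset-size integer with a structured "spec" reply from every data DP rank, checks that all ranks agree on `n_datasets`, and sums the per-rank sizes. A small sketch of that aggregation over hypothetical replies:

# Hypothetical "spec" replies from two data-parallel ranks.
specs = [
    {"n_datasets": 2, "dataset_size": 1500},
    {"n_datasets": 2, "dataset_size": 1480},
]

assert all(x["n_datasets"] == specs[0]["n_datasets"] for x in specs), specs
dataset_size = sum(x["dataset_size"] for x in specs)  # 2980
n_datasets = specs[0]["n_datasets"]                   # 2
print(dataset_size, n_datasets)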
|
@ -239,7 +241,7 @@ class MasterWorker(worker_base.Worker):
|
|||
src_rpc_dp_size = src_rpc_topo.get_dim("data")
|
||||
src_rpc_pp_size = src_rpc_topo.get_dim("pipe")
|
||||
for i in range(src_rpc_dp_size):
|
||||
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, model=0)
|
||||
rank = src_rpc_topo.get_rank(data=i, pipe=src_rpc_pp_size - 1, tensor=0)
|
||||
handler_routing[f"__data{i}__"] = self.config.msid2mwid[
|
||||
config_pkg.ModelShardID.from_parallelism_rank(
|
||||
model_name=src_rpc.model_name,
|
||||
|
@ -263,10 +265,13 @@ class MasterWorker(worker_base.Worker):
|
|||
|
||||
self.initialize_models()
|
||||
|
||||
self.__seqbuffer = AsyncIOSequenceBuffer(
|
||||
self.__seqbuffers = [
|
||||
AsyncIOSequenceBuffer(
|
||||
self.__model_rpcs,
|
||||
max_size=int(os.getenv("REAL_MASTER_BUFFER_SIZE", str(int(1e7)))),
|
||||
)
|
||||
for _ in range(self._n_datasets)
|
||||
]
|
||||
|
||||
# wandb init, connect to remote wandb host
|
||||
wandb.login()
|
||||
|
@ -300,7 +305,7 @@ class MasterWorker(worker_base.Worker):
|
|||
rpcs=self.__model_rpcs,
|
||||
msid2mwid=self.config.msid2mwid,
|
||||
stream=self.__stream,
|
||||
buffer=self.__seqbuffer,
|
||||
buffers=self.__seqbuffers,
|
||||
model_topos=self.__model_topos,
|
||||
model_configs=self.__model_configs,
|
||||
ctrl=self.__rpc_ctrl,
|
||||
|
@ -395,20 +400,33 @@ class MasterWorker(worker_base.Worker):
|
|||
|
||||
# Pause the worker if experiment or system-wise benchmark completes.
|
||||
if (
|
||||
(
|
||||
self.__benchmark_steps is not None
|
||||
and self.__rpc_ctrl.step_info.global_step >= self.__benchmark_steps
|
||||
) or (
|
||||
)
|
||||
or (
|
||||
self.__rpc_ctrl.step_info.global_step * self.__src_rpc.n_seqs
|
||||
>= self.__total_train_epochs * self._dataset_size
|
||||
)
|
||||
or (
|
||||
self.__benchmark_n_seqs is not None
|
||||
and self.__rpc_ctrl.step_info.global_step
|
||||
* self._ft_spec.train_batch_size
|
||||
>= self.__benchmark_n_seqs
|
||||
)
|
||||
):
|
||||
# We don't know whether it is the last step of the current epoch,
|
||||
# so we exit at the first step of the next epoch.
|
||||
if self.__benchmark_steps is not None:
|
||||
if (
|
||||
self.__benchmark_steps is not None
|
||||
or self.__benchmark_n_seqs is not None
|
||||
):
|
||||
logger.info(
|
||||
f"Finished benchmark {self.__benchmark_steps}. "
|
||||
f"Time consumption of this setup: {time_since_configure:.3f}"
|
||||
)
|
||||
logger.info(f"avg #e2e# time *{np.mean(self.e2e_time_history):.3f}*")
|
||||
# TODO: inform generation workers to exit
|
||||
return self.experiment_complete_exit()
|
||||
|
||||
return worker_base.PollResult(sample_count=1, batch_count=1)
|
||||
|
@ -439,9 +457,7 @@ class MasterWorker(worker_base.Worker):
|
|||
s += f"(global step {global_step}) finishes. "
|
||||
s += f"#End to end# execution time: *{e2e_time:.3f}*s. "
|
||||
s += f"Total time consumption: {time_since_configure:.3f}s. "
|
||||
logging.log_wandb_tensorboard(
|
||||
{"timeperf/e2e": e2e_time}, step=self.__rpc_ctrl.step_info.global_step
|
||||
)
|
||||
logging.log_wandb_tensorboard({"timeperf/e2e": e2e_time})
|
||||
if len(self.e2e_time_history) > 2:
|
||||
remaining_steps = self._steps_per_epoch - epoch_step
|
||||
remaining_epochs = self.__total_train_epochs - epoch
|
||||
|
|
|
@ -63,7 +63,7 @@ class ModelFunctionCall:
|
|||
model_topos: Dict[str, topology.ProcessTopology],
|
||||
model_configs: Dict[str, None | ReaLModelConfig],
|
||||
ctrl: RPCCorountineControl,
|
||||
buffer: AsyncIOSequenceBuffer,
|
||||
buffers: List[AsyncIOSequenceBuffer],
|
||||
redistrib_planner: RedistribPlanner,
|
||||
summary_writer: SummaryWriter | None,
|
||||
):
|
||||
|
@ -89,7 +89,7 @@ class ModelFunctionCall:
|
|||
)
|
||||
|
||||
self.rpc_ctrl = ctrl
|
||||
self.buffer = buffer
|
||||
self.buffers = buffers
|
||||
self.redistrib_planner = redistrib_planner
|
||||
|
||||
self.summary_writer = summary_writer
|
||||
|
@ -306,7 +306,7 @@ class ModelFunctionCall:
|
|||
).partitions
|
||||
return buf_indices, sample, partitions
|
||||
|
||||
async def run_step(self, buf_indices, sample):
|
||||
async def run_step(self, buf_indices, sample, buffer_id: int):
|
||||
rpc = self.rpc
|
||||
topo = self.model_topos[rpc.model_name]
|
||||
ctrl = self.rpc_ctrl
|
||||
|
@ -317,7 +317,7 @@ class ModelFunctionCall:
|
|||
]
|
||||
|
||||
dp_head_indices = [
|
||||
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, model=0)
|
||||
topo.get_rank(data=i, pipe=topo.get_dim("pipe") - 1, tensor=0)
|
||||
for i in range(self.dp_size)
|
||||
]
|
||||
|
||||
|
@ -348,11 +348,6 @@ class ModelFunctionCall:
|
|||
if i not in dests:
|
||||
dests[i] = []
|
||||
|
||||
# NOTE: The data loaded from the dataset may be unevenly distributed across DP ranks.
|
||||
# Only bcast works in this case.
|
||||
if rpc.is_src:
|
||||
pattern = "bcast"
|
||||
else:
|
||||
pattern = "gather-scatter"
|
||||
data_transfer_plan = self.redistrib_planner.derive_plan(
|
||||
dests,
|
||||
|
@ -362,14 +357,14 @@ class ModelFunctionCall:
|
|||
blogger.info(f"Data transfer plan for `{rpc.name}`: {data_transfer_plan}.")
|
||||
|
||||
# Update storage tracker for transferred data.
|
||||
if rpc.is_src:
|
||||
if pattern == "bcast":
|
||||
# NOTE: since the data we loaded may be unevenly distributed across DP ranks,
|
||||
# we should change the owner of the data to the src RPC.
|
||||
for i in range(topo.world_size()):
|
||||
h = ModelShardID.from_parallelism_rank(
|
||||
model_name=rpc.model_name, topo=topo, parallelism_rank=i
|
||||
)
|
||||
is_dp_head = h.mp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
|
||||
is_dp_head = h.tp_rank == 0 and h.pp_rank == topo.get_dim("pipe") - 1
|
||||
gpu_id = self.msid2mwid[h]
|
||||
for key in rpc.input_keys:
|
||||
await self.redistrib_planner.storage_tracker.add_data(
|
||||
|
@ -414,13 +409,13 @@ class ModelFunctionCall:
|
|||
responses, time_records = list(zip(*[responses[i] for i in dp_head_indices]))
|
||||
|
||||
# If the returned data is a SequenceSample, it is the data returned by
|
||||
# model function calls. The data shoulbe be amended into buffer.
|
||||
# model function calls. The data should be amended into buffer.
|
||||
# Otherwise, it's the train statistics and should be reduced and logged.
|
||||
if isinstance(responses[-1], data_api.SequenceSample):
|
||||
# Update storage tracker for generated data.
|
||||
for dp_rank, x in enumerate(responses):
|
||||
pp_size = topo.get_dim("pipe")
|
||||
ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, model=0)
|
||||
ranks = topo.filter_match(data=dp_rank, pipe=pp_size - 1, tensor=0)
|
||||
for rank in ranks:
|
||||
h = config_pkg.ModelShardID.from_parallelism_rank(
|
||||
model_name=rpc.model_name, topo=topo, parallelism_rank=rank
|
||||
|
@ -434,8 +429,14 @@ class ModelFunctionCall:
|
|||
is_owner=True,
|
||||
)
|
||||
res = data_api.SequenceSample.gather(responses)
|
||||
else:
|
||||
elif isinstance(responses[0], dict):
|
||||
res = data_api.gather_stat(responses)
|
||||
else:
|
||||
assert isinstance(responses[0], list)
|
||||
res = [
|
||||
data_api.gather_stat([r[i] for r in responses])
|
||||
for i in range(len(responses[0]))
|
||||
]
|
||||
|
||||
if rpc.log_return_value:
|
||||
if isinstance(res, dict):
|
||||
|
@ -447,6 +448,17 @@ class ModelFunctionCall:
|
|||
step=ctrl.step_info.global_step,
|
||||
summary_writer=self.summary_writer,
|
||||
)
|
||||
elif isinstance(res, list):
|
||||
for j, r in enumerate(res):
|
||||
logger.info(
|
||||
f"RPC name {rpc.name} returns ({j}/{len(res)})\n{data_api.tabulate_stats(r)}"
|
||||
)
|
||||
offset = len(res) * ctrl.step_info.global_step
|
||||
logging.log_wandb_tensorboard(
|
||||
r,
|
||||
step=offset + j,
|
||||
summary_writer=self.summary_writer,
|
||||
)
|
||||
else:
|
||||
logger.info(f"RPC name {rpc.name} returns\n{res}")
|
||||
|
||||
|
@ -456,7 +468,6 @@ class ModelFunctionCall:
|
|||
time_stats = stats_tracker.export()
|
||||
logging.log_wandb_tensorboard(
|
||||
time_stats,
|
||||
step=ctrl.step_info.global_step,
|
||||
summary_writer=self.summary_writer,
|
||||
)
|
||||
|
||||
|
@ -475,7 +486,7 @@ class ModelFunctionCall:
|
|||
await ctrl.train_count.put(1)
|
||||
else:
|
||||
logger.info(f"Amending RPC {rpc.name} output keys: {res.keys}")
|
||||
await self.buffer.amend_batch(buf_indices, res.unpack())
|
||||
await self.buffers[buffer_id].amend_batch(buf_indices, res.unpack())
|
||||
|
||||
# Wait for all side-effect requests to finish.
|
||||
# Side-effect or empty requests are required for data transfer
|
||||
|
@ -483,20 +494,20 @@ class ModelFunctionCall:
|
|||
# Wait for them after the main request to log the correct MFC time.
|
||||
await self.stream.gather_async(other_req_ids)
|
||||
|
||||
async def run(self):
|
||||
async def run(self, buffer_id: int):
|
||||
rpc = self.rpc
|
||||
topo = self.model_topos[rpc.model_name]
|
||||
|
||||
logger.info(
|
||||
f"Running Model RPC, interface_type=#{rpc.interface_type}# "
|
||||
f"(dp,mp,pp) = *({topo.get_dim('data')},{topo.get_dim('model')},{topo.get_dim('pipe')})*"
|
||||
f"(dp,tp,pp) = *({topo.get_dim('data')},{topo.get_dim('tensor')},{topo.get_dim('pipe')})*"
|
||||
)
|
||||
|
||||
consumed = 0
|
||||
while True:
|
||||
buf_indices, sample = await self.buffer.get_batch_for_rpc(rpc)
|
||||
buf_indices, sample = await self.buffers[buffer_id].get_batch_for_rpc(rpc)
|
||||
|
||||
await self.run_step(buf_indices, sample)
|
||||
await self.run_step(buf_indices, sample, buffer_id)
|
||||
consumed += sample.bs
|
||||
|
||||
# Ensure that parent RPCs will not be over-consumed.
|
||||
|
|
|
@ -153,7 +153,7 @@ class ModelWorker(worker_base.Worker):
|
|||
]
|
||||
for s in self.config.shards:
|
||||
_pp_size = s.id.topo.get_dim("pipe")
|
||||
if not (s.id.mp_rank == 0 and s.id.pp_rank == _pp_size - 1):
|
||||
if not (s.id.tp_rank == 0 and s.id.pp_rank == _pp_size - 1):
|
||||
continue
|
||||
if src_rpc.model_name == s.id.model_name:
|
||||
self.__has_dataset = True
|
||||
|
@ -195,8 +195,8 @@ class ModelWorker(worker_base.Worker):
|
|||
return None
|
||||
|
||||
@property
|
||||
def _mp_rank(self) -> int:
|
||||
return constants.model_parallel_rank()
|
||||
def _tp_rank(self) -> int:
|
||||
return constants.tensor_parallel_rank()
|
||||
|
||||
@property
|
||||
def _pp_rank(self) -> int:
|
||||
|
@ -211,8 +211,8 @@ class ModelWorker(worker_base.Worker):
|
|||
return constants.pipe_parallel_world_size()
|
||||
|
||||
@property
|
||||
def _mp_size(self) -> int:
|
||||
return constants.model_parallel_world_size()
|
||||
def _tp_size(self) -> int:
|
||||
return constants.tensor_parallel_world_size()
|
||||
|
||||
@property
|
||||
def _dp_size(self) -> int:
|
||||
|
@ -220,7 +220,7 @@ class ModelWorker(worker_base.Worker):
|
|||
|
||||
@property
|
||||
def _is_dp_head(self) -> bool:
|
||||
return self._mp_rank == 0 and self._pp_rank == self._pp_size - 1
|
||||
return self._tp_rank == 0 and self._pp_rank == self._pp_size - 1
|
||||
|
||||
@property
|
||||
def _model(self) -> model_api.Model:
|
||||
|
@ -302,6 +302,7 @@ class ModelWorker(worker_base.Worker):
|
|||
constants.set_grid(model_name_, grid)
|
||||
|
||||
# Set up training dataset for source RPCs.
|
||||
self.__datasets = []
|
||||
if self.__has_dataset:
|
||||
datasets = [
|
||||
data_api.make_dataset(
|
||||
|
@ -321,31 +322,34 @@ class ModelWorker(worker_base.Worker):
|
|||
)
|
||||
for d in self.config.datasets
|
||||
]
|
||||
if len(self.config.datasets) == 1:
|
||||
self.__dataset = datasets[0]
|
||||
else:
|
||||
self.__dataset = torch.utils.data.ConcatDataset(datasets)
|
||||
self.__datasets = datasets
|
||||
|
||||
self.__dataloaders: List[
|
||||
torch.utils.data.DataLoader[data_api.SequenceSample]
|
||||
] = []
|
||||
for i, d in enumerate(self.__datasets):
|
||||
g = torch.Generator()
|
||||
g.manual_seed(seeding.get_seed())
|
||||
g.manual_seed(
|
||||
self.config.base_seed + seeding._seed_from_key(f"__dataloader{i}__")
|
||||
)
|
||||
dataloader_kwargs = dict(
|
||||
shuffle=self.config.shuffle_dataset,
|
||||
generator=g,
|
||||
)
|
||||
if not isinstance(self.__dataset, PullerStreamDataset):
|
||||
if not isinstance(d, PullerStreamDataset):
|
||||
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
|
||||
# NOTE: This is *NOT* the actual batch size for training.
|
||||
# It is just a proper size to load data to workers.
|
||||
dataloader_kwargs["batch_size"] = 10240
|
||||
else:
|
||||
dataloader_kwargs["batch_size"] = None
|
||||
self.__dataloader = torch.utils.data.DataLoader(
|
||||
self.__dataset, **dataloader_kwargs
|
||||
self.__dataloaders.append(
|
||||
torch.utils.data.DataLoader(d, **dataloader_kwargs)
|
||||
)
|
||||
|
||||
self.dataset_size = len(self.__dataset)
|
||||
self.dataset_size = sum(len(d) for d in self.__datasets)
|
||||
|
||||
self.__data_generator = enumerate(self.__dataloader)
|
||||
self.__data_generators = [enumerate(d) for d in self.__dataloaders]
|
||||
|
||||
self.__models: Dict[ModelName, model_api.Model] = dict()
|
||||
self.__model_is_handle: Dict[ModelName, bool] = dict()
|
||||
|
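Each per-dataset dataloader built above gets its own deterministic generator seed, derived from the experiment's `base_seed` plus a key-specific offset, so different dataloaders shuffle independently but reproducibly. A minimal sketch of that derivation, with a hypothetical `seed_from_key` helper in place of `seeding._seed_from_key`:

import hashlib

import torch


def seed_from_key(key: str) -> int:
    # Hypothetical stand-in: map a string key to a stable 31-bit offset.
    return int(hashlib.sha256(key.encode()).hexdigest(), 16) % (2**31)


base_seed = 1
generators = []
for i in range(2):  # one generator per dataset
    g = torch.Generator()
    g.manual_seed(base_seed + seed_from_key(f"__dataloader{i}__"))
    generators.append(g)

# The two generators start from different, but reproducible, states.
print(generators[0].initial_seed() != generators[1].initial_seed())  # True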
@ -377,25 +381,26 @@ class ModelWorker(worker_base.Worker):
|
|||
)
|
||||
|
||||
# Recover indices for dynamic dataset
|
||||
for i, d in enumerate(self.__datasets):
|
||||
if (
|
||||
s.id.model_name == self.src_rpc.model_name
|
||||
and self.__has_dataset
|
||||
and hasattr(self.__dataset, "filter")
|
||||
and hasattr(d, "filter")
|
||||
):
|
||||
dataset_indices_path = os.path.join(
|
||||
constants.MODEL_SAVE_ROOT,
|
||||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"dataset_indices",
|
||||
f"{self._dp_rank}.npy",
|
||||
f"{self._dp_rank}_{i}.npy",
|
||||
)
|
||||
if os.path.exists(dataset_indices_path):
|
||||
indices = np.load(dataset_indices_path).tolist()
|
||||
logger.info(
|
||||
f"DP rank {self._dp_rank} updating dataset indices upon recover, "
|
||||
f"size {len(self.__dataset.active_indices)} -> {len(indices)}"
|
||||
f"size {len(d.active_indices)} -> {len(indices)}"
|
||||
)
|
||||
self.__dataset.active_indices = indices
|
||||
d.active_indices = indices
|
||||
|
||||
if constants.parallelism_rank() == 0:
|
||||
self.logger.info(
|
||||
|
@ -537,9 +542,13 @@ class ModelWorker(worker_base.Worker):
|
|||
cache = []
|
||||
while True:
|
||||
try:
|
||||
request, data, handled, res, time_record = (
|
||||
self.__request_queue.get_nowait()
|
||||
)
|
||||
(
|
||||
request,
|
||||
data,
|
||||
handled,
|
||||
res,
|
||||
time_record,
|
||||
) = self.__request_queue.get_nowait()
|
||||
request: request_reply_stream.Payload
|
||||
if not handled:
|
||||
while len(request.pre_hooks) > 0:
|
||||
|
@ -582,9 +591,13 @@ class ModelWorker(worker_base.Worker):
|
|||
elif request.handle_name == "fetch":
|
||||
dp_rank = int(re.search(r"__data(\d+)__", request.handler).group(1))
|
||||
assert self.__has_dataset
|
||||
assert isinstance(request.data, int), request.data
|
||||
dataset_id = request.data
|
||||
# Fetch.
|
||||
try:
|
||||
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
|
||||
self.__dataset_batch_counter, cur_sample = next(
|
||||
self.__data_generators[dataset_id]
|
||||
)
|
||||
except StopIteration:
|
||||
# Upon the first fetch request, filter dataset and create dataloader.
|
||||
eval_scores_path = os.path.join(
|
||||
|
@ -598,39 +611,43 @@ class ModelWorker(worker_base.Worker):
|
|||
constants.experiment_name(),
|
||||
constants.trial_name(),
|
||||
"dataset_indices",
|
||||
f"{dp_rank}.npy",
|
||||
f"{dp_rank}_{dataset_id}.npy",
|
||||
)
|
||||
os.makedirs(os.path.dirname(dataset_indices_path), exist_ok=True)
|
||||
if hasattr(self.__dataset, "filter") and os.path.exists(
|
||||
if hasattr(self.__datasets[dataset_id], "filter") and os.path.exists(
|
||||
eval_scores_path
|
||||
):
|
||||
# Don't filter dataset on the first poll after recover.
|
||||
with open(eval_scores_path, "r", encoding="utf-8") as f:
|
||||
dataset_eval_scores = json.load(f)
|
||||
self.__dataset.filter(dataset_eval_scores)
|
||||
self.__datasets[dataset_id].filter(dataset_eval_scores)
|
||||
# Save the dataset indices after filtering
|
||||
np.save(
|
||||
dataset_indices_path,
|
||||
self.__dataset.active_indices,
|
||||
self.__datasets[dataset_id].active_indices,
|
||||
)
|
||||
g = torch.Generator()
|
||||
g = g.set_state(self.__dataloader.generator.get_state())
|
||||
g = g.set_state(self.__dataloaders[dataset_id].generator.get_state())
|
||||
dataloader_kwargs = dict(
|
||||
shuffle=self.config.shuffle_dataset,
|
||||
generator=g,
|
||||
)
|
||||
if not isinstance(self.__dataset, PullerStreamDataset):
|
||||
if not isinstance(self.__datasets[dataset_id], PullerStreamDataset):
|
||||
dataloader_kwargs["collate_fn"] = data_api.SequenceSample.gather
|
||||
# NOTE: This is *NOT* the actual batch size for training.
|
||||
# It is just a proper size to load data to workers.
|
||||
dataloader_kwargs["batch_size"] = 10240
|
||||
else:
|
||||
dataloader_kwargs["batch_size"] = None
|
||||
self.__dataloader = torch.utils.data.DataLoader(
|
||||
self.__dataset, **dataloader_kwargs
|
||||
self.__dataloaders[dataset_id] = torch.utils.data.DataLoader(
|
||||
self.__datasets[dataset_id], **dataloader_kwargs
|
||||
)
|
||||
self.__data_generators[dataset_id] = enumerate(
|
||||
self.__dataloaders[dataset_id]
|
||||
)
|
||||
self.__dataset_batch_counter, cur_sample = next(
|
||||
self.__data_generators[dataset_id]
|
||||
)
|
||||
self.__data_generator = enumerate(self.__dataloader)
|
||||
self.__dataset_batch_counter, cur_sample = next(self.__data_generator)
|
||||
|
||||
if isinstance(cur_sample, data_api.SequenceSample):
|
||||
samples = cur_sample.unpack()
|
||||
|
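When a per-dataset dataloader is exhausted above, the worker optionally filters the dataset by evaluation scores and then rebuilds the dataloader with a fresh `torch.Generator` whose state is copied from the old one, so shuffling continues rather than restarting from the initial seed. A minimal sketch of that rebuild on a toy dataset (not the project's real data pipeline):

import torch
from torch.utils.data import DataLoader

dataset = list(range(10))  # toy stand-in for the real dataset
g = torch.Generator()
g.manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)
it = enumerate(loader)

try:
    while True:
        next(it)
except StopIteration:
    # Carry the RNG state over into the rebuilt dataloader.
    new_g = torch.Generator()
    new_g.set_state(loader.generator.get_state())
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=new_g)
    it = enumerate(loader)

print(next(it)[1])  # first batch after the rebuild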
@ -663,7 +680,10 @@ class ModelWorker(worker_base.Worker):
|
|||
)
|
||||
elif request.handle_name == "spec":
|
||||
# Raw dataset without filtering.
|
||||
res = self.dataset_size
|
||||
res = {
|
||||
"n_datasets": len(self.__datasets),
|
||||
"dataset_size": self.dataset_size,
|
||||
}
|
||||
elif request.handle_name == "clear_data_cache":
|
||||
with cuda_tmarked("clear_data_cache", CUDATimeMarkType.misc):
|
||||
ids = request.data
|
||||
|
@ -772,8 +792,10 @@ class ModelWorker(worker_base.Worker):
|
|||
if hook == "evaluate":
|
||||
assert request.handle_name == "train_step", request.handle_name
|
||||
assert isinstance(ret, dict), ret
|
||||
assert isinstance(res, dict), res
|
||||
if isinstance(res, dict):
|
||||
res.update(ret)
|
||||
else:
|
||||
res[0].update(ret)
|
||||
time_record[
|
||||
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/post-{hook}"
|
||||
] += (time.perf_counter() - tik)
|
||||
|
@ -803,13 +825,7 @@ class ModelWorker(worker_base.Worker):
|
|||
with constants.model_scope(model_name):
|
||||
dist.barrier(group=constants.cpu_parallelism_group())
|
||||
if constants.parallelism_rank() == 0:
|
||||
name_resolve.add(
|
||||
name,
|
||||
str(global_step),
|
||||
delete_on_exit=False,
|
||||
keepalive_ttl=30,
|
||||
replace=True,
|
||||
)
|
||||
name_resolve.add(name, str(global_step), replace=True)
|
||||
time_record[
|
||||
f"timeperf/{request.handler.model_name.role}_{request.handle_name}/param-sync-save"
|
||||
] += (time.perf_counter() - tik)
|
||||
|
@ -867,7 +883,7 @@ class ModelWorker(worker_base.Worker):
|
|||
if len(self.__performance_recorder) == 0:
|
||||
self.__performance_recorder["info"] = {
|
||||
"pipeline_size": self._pp_size,
|
||||
"model_size": self._mp_size,
|
||||
"model_size": self._tp_size,
|
||||
"data_size": self._dp_size,
|
||||
"rank": constants.parallelism_rank(),
|
||||
"sequence_parallel_enabled": constants.sequence_parallel(),
|
||||
|
@ -1374,9 +1390,13 @@ class ModelWorker(worker_base.Worker):
|
|||
rescheduled_requests = []
|
||||
other_requests = []
|
||||
for _ in range(self.__request_queue.qsize()):
|
||||
request, data, handled, res, time_record = (
|
||||
self.__request_queue.get_nowait()
|
||||
)
|
||||
(
|
||||
request,
|
||||
data,
|
||||
handled,
|
||||
res,
|
||||
time_record,
|
||||
) = self.__request_queue.get_nowait()
|
||||
if request.handle_name not in ["inference", "generate", "train_step"]:
|
||||
other_requests.append((request, data, handled, res, time_record))
|
||||
else:
|
||||
|
@ -1399,9 +1419,13 @@ class ModelWorker(worker_base.Worker):
|
|||
# we can correctly log the time consumption in the master worker.
|
||||
while True:
|
||||
try:
|
||||
request, data, handled, res, time_record = (
|
||||
self.__request_queue.get_nowait()
|
||||
)
|
||||
(
|
||||
request,
|
||||
data,
|
||||
handled,
|
||||
res,
|
||||
time_record,
|
||||
) = self.__request_queue.get_nowait()
|
||||
self.handle_blocking_request(
|
||||
request, data, handled, res, time_record
|
||||
)
|
||||
|
|
|
@ -103,9 +103,14 @@ class PartialRolloutManager:
|
|||
):
|
||||
from realhf.impl.model.backend.sglang import SGLangAPIClient
|
||||
|
||||
max_new_tokens = min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk)
|
||||
max_new_tokens = min(
|
||||
max_new_tokens,
|
||||
raw_gconfig.max_new_tokens - len(input_ids) + len(prompt_ids),
|
||||
)
|
||||
gconfig = raw_gconfig.new(
|
||||
n=1,
|
||||
max_new_tokens=min(raw_gconfig.max_new_tokens, self.new_tokens_per_chunk),
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
assert self.tokenizer.pad_token_id is not None
|
||||
assert self.tokenizer.eos_token_id is not None
|
||||
|
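The change above caps each generation chunk by both the per-chunk budget and the remaining overall budget: a request that has already produced `len(input_ids) - len(prompt_ids)` tokens may only generate the smaller of `new_tokens_per_chunk` and what is left of `max_new_tokens`. A worked sketch of the budget arithmetic (values are illustrative):

def chunk_budget(max_new_tokens: int, new_tokens_per_chunk: int,
                 prompt_len: int, current_len: int) -> int:
    generated_so_far = current_len - prompt_len
    remaining = max_new_tokens - generated_so_far
    return min(new_tokens_per_chunk, remaining)


# Overall budget 1024, chunks of 256, prompt of 100 tokens:
print(chunk_budget(1024, 256, prompt_len=100, current_len=100))  # 256 (first chunk)
print(chunk_budget(1024, 256, prompt_len=100, current_len=996))  # 128 (last partial chunk)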
@ -130,6 +135,7 @@ class PartialRolloutManager:
|
|||
group_idx=group_idx,
|
||||
raw_gconfig=raw_gconfig,
|
||||
server_url=url,
|
||||
version=cur_server_version,
|
||||
),
|
||||
),
|
||||
stream=False,
|
||||
|
@ -190,6 +196,7 @@ class PartialRolloutManager:
|
|||
s: APIGenerateOutput = await task
|
||||
group_idx = s.metadata["group_idx"]
|
||||
raw_gconfig = s.metadata["raw_gconfig"]
|
||||
previous_version = s.metadata["version"]
|
||||
|
||||
assert s.group_size == 1
|
||||
no_eos = s.no_eos[0]
|
||||
|
@ -202,20 +209,27 @@ class PartialRolloutManager:
|
|||
if no_eos and gen_len < raw_gconfig.max_new_tokens:
|
||||
# Unfinished request due to chunked generation.
|
||||
# Send it back to continue.
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"http://{self.gserver_manager_addr}/get_model_version",
|
||||
json=dict(server_url=s.metadata["server_url"]),
|
||||
timeout=ClientTimeout(total=self.timeout, sock_connect=30),
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
cur_version = (await resp.json())["version"]
|
||||
req_meta = GenReqMeta(
|
||||
qid=s.qid,
|
||||
prompt_len=s.prompt_len,
|
||||
group_size=raw_gconfig.n,
|
||||
new_token_budget=raw_gconfig.max_new_tokens,
|
||||
predicted_new_tokens=None,
|
||||
previous_server_url=s.metadata["server_url"],
|
||||
previous_version=previous_version,
|
||||
)
|
||||
info = await self._schedule_request(req_meta)
|
||||
cur_version = info["version"]
|
||||
server_url = info["url"]
|
||||
|
||||
if len(s.output_logprobs) > 0:
|
||||
prev_logprobs = s.prev_logprobs + s.output_logprobs[0]
|
||||
else:
|
||||
prev_logprobs = s.prev_logprobs
|
||||
if prev_logprobs is None:
|
||||
prev_logprobs = []
|
||||
await self._issue_generation(
|
||||
s.metadata["server_url"],
|
||||
server_url,
|
||||
s.qid,
|
||||
group_idx,
|
||||
s.prompt_ids,
|
||||
|
@ -240,9 +254,10 @@ class PartialRolloutManager:
|
|||
try:
|
||||
qid, prompt_token_ids, gconfig = self.request_queue.get_nowait()
|
||||
req_meta = GenReqMeta(
|
||||
qid=qid,
|
||||
prompt_len=len(prompt_token_ids),
|
||||
group_size=gconfig.n,
|
||||
new_token_budget=self.new_tokens_per_chunk,
|
||||
new_token_budget=gconfig.max_new_tokens,
|
||||
predicted_new_tokens=None,
|
||||
)
|
||||
dst_server_info = await self._schedule_request(req_meta)
|
||||
|
|
|
@ -171,7 +171,9 @@ class NameResolvingZmqPuller(ZMQJsonPuller):
|
|||
name = names.push_pull_stream(
|
||||
experiment_name, trial_name, stream_name=f"puller{puller_index}"
|
||||
)
|
||||
host, port = network.gethostip(), network.find_free_port()
|
||||
host, port = network.gethostip(), network.find_free_port(
|
||||
experiment_name=experiment_name, trial_name=trial_name
|
||||
)
|
||||
addr = f"{host}:{port}"
|
||||
name_resolve.add(name, addr)
|
||||
super().__init__(host, port, **kwargs)
|
||||
|
|
|
@ -189,10 +189,11 @@ class RolloutWorker(AsyncWorker):
|
|||
assert data_id not in self.rollout_tasks
|
||||
return cur_sample
|
||||
|
||||
async def allocate_new_rollout(self) -> bool:
|
||||
async def allocate_new_rollout(self, qid) -> bool:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
async with session.post(
|
||||
f"http://{self.gserver_manager_addr}/allocate_rollout",
|
||||
json=dict(qid=qid),
|
||||
timeout=ClientTimeout(
|
||||
total=self.config.rollout_request_timeout, sock_connect=30
|
||||
),
|
||||
|
@ -231,10 +232,10 @@ class RolloutWorker(AsyncWorker):
|
|||
self._cur_data = self.load_next_data()
|
||||
|
||||
if self._cur_data is not None:
|
||||
can_rollout = await self.allocate_new_rollout()
|
||||
if can_rollout:
|
||||
data = self._cur_data
|
||||
qid = data.ids[0]
|
||||
can_rollout = await self.allocate_new_rollout(qid)
|
||||
if can_rollout:
|
||||
self.act_queues[qid] = asyncio.Queue(1024)
|
||||
|
||||
task = asyncio.create_task(self.rollout_task(qid, data))
|
||||
|
@ -265,7 +266,11 @@ class RolloutWorker(AsyncWorker):
|
|||
accepted = True
|
||||
self.push_stream.push([traj.as_json_compatible() for traj in trajs])
|
||||
|
||||
info = dict(qid=qid, accepted=accepted)
|
||||
n_tokens = 0
|
||||
for traj in trajs:
|
||||
seqlens = [sum(datapack.flat2d(ss)) for ss in traj.seqlens.values()]
|
||||
n_tokens += max(seqlens)
|
||||
info = dict(qid=qid, accepted=accepted, n_tokens=n_tokens)
|
||||
async with aiohttp.ClientSession(
|
||||
f"http://{self.gserver_manager_addr}"
|
||||
) as session:
|
||||
|
|
|
@ -36,7 +36,6 @@ ray
 redis
 scipy
 seaborn
-setuptools>=61.0
 tqdm
 networkx==3.3
 matplotlib
@ -59,3 +58,5 @@ protobuf<3.21
 rich
 orjson>=3.10.16
 flask
+setuptools>=62.3.0,<75.9
+func_timeout
setup.py (24 changed lines)
@ -246,30 +246,6 @@ if not no_ext and _is_cuda():
    ext_modules.append(interval_op_cuda)

-if not no_ext:
-    search_extension = setuptools.Extension(
-        name="realhf._C.mdm_search",
-        sources=[
-            "csrc/search/search.cpp",
-            "csrc/search/rpc.cpp",
-            "csrc/search/device_mesh.cpp",
-            "csrc/search/simulate.cpp",
-        ],
-        language="c++",
-        extra_compile_args=[
-            "-O3",
-            "-Wall",
-            "-shared",
-            "-std=c++11",
-            "-fPIC",
-            "-std=c++17",
-        ],
-        include_dirs=[
-            os.path.join(os.path.abspath(os.path.dirname(__file__)), "csrc", "search"),
-            get_pybind11_include_path(),
-        ],
-    )
-    ext_modules.append(search_extension)

interval_extension = setuptools.Extension(
    name="realhf._C.interval_op",
    sources=[
Some files were not shown because too many files have changed in this diff.